# Improve Shape Inference for GQA #24143
## Conversation
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline,ONNX Runtime Web CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline

/azp run Linux QNN CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline,Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline,Linux DNNL CI Pipeline,Linux MIGraphX CI Pipeline,Linux ROCm CI Pipeline

Azure Pipelines successfully started running 8 pipeline(s).

Azure Pipelines successfully started running 10 pipeline(s).
### Description

For the GroupQueryAttention op, if the input `total_sequence_length` is a constant, we can infer the shape of the outputs `present_key`/`present_value` as `(batch_size, kv_num_heads, present_sequence_length, head_size)`.

https://github.com/microsoft/onnxruntime/blob/5ed900e9712ce2f02e40c15b945d18453d1960d8/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h#L185

From the CPU EP we know that `present_sequence_length = max(past_sequence_length, total_sequence_length)`, and that `batch_size`, `kv_num_heads`, and `head_size` match those of `past_key`/`past_value`.

This inference is particularly important for the WebNN EP: WebNN only supports GQA when `present_sequence_length == past_sequence_length`, and it requires static shapes for graph compilation.
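To make the rule concrete, here is a minimal Python sketch of the inference logic described above (the actual change lives in onnxruntime's C++ shape-inference code; the function name and example shapes below are purely illustrative):

```python
def infer_present_kv_shape(past_key_shape, total_sequence_length):
    """Sketch of the present_key/present_value shape inference.

    past_key_shape: (batch_size, kv_num_heads, past_sequence_length, head_size).
    total_sequence_length: usable only when it is a constant in the graph.
    """
    batch_size, kv_num_heads, past_sequence_length, head_size = past_key_shape
    # Per the CPU EP, the present KV cache holds
    # max(past_sequence_length, total_sequence_length) tokens.
    present_sequence_length = max(past_sequence_length, total_sequence_length)
    return (batch_size, kv_num_heads, present_sequence_length, head_size)

# Example: a preallocated KV cache of length 1024 and a total sequence of
# 12 tokens leaves the cache length unchanged, so
# present_sequence_length == past_sequence_length (the case WebNN supports).
assert infer_present_kv_shape((1, 8, 1024, 64), 12) == (1, 8, 1024, 64)
```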
### Motivation and Context