Improve Shape Inference for GQA #24143


Merged: 2 commits into microsoft:main on Mar 28, 2025

Conversation

peishenyan (Contributor)

Description

For the GroupQueryAttention op, if the input total_sequence_length is a constant, we can infer the shape of the present_key/present_value outputs as (batch_size, kv_num_heads, present_sequence_length, head_size).

int present_sequence_length = std::max(total_sequence_length, past_sequence_length);

(from https://github.com/microsoft/onnxruntime/blob/5ed900e9712ce2f02e40c15b945d18453d1960d8/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h#L185)

From the CPU EP we know that present_sequence_length = max(past_sequence_length, total_sequence_length), and that batch_size, kv_num_heads, and head_size are the same as for past_key/past_value.

This inference is important for the WebNN EP, because WebNN only supports GQA when present_sequence_length == past_sequence_length and requires static shapes for graph compilation. A sketch of the inference rule is shown below.
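As a minimal sketch of the rule this PR implements (not the actual ONNX Runtime shape-inference code; the helper name and fixed-rank signature below are assumptions for illustration), the present_key/present_value shape can be computed from the past_key/past_value shape and a constant total_sequence_length like this:

```cpp
#include <algorithm>
#include <array>
#include <cstdint>

// Hypothetical helper illustrating the inference rule. Inputs:
//   past_kv_shape: (batch_size, kv_num_heads, past_sequence_length, head_size)
//   total_sequence_length: a constant value read from the graph.
// Returns the inferred shape of present_key/present_value.
std::array<int64_t, 4> InferPresentKvShape(
    const std::array<int64_t, 4>& past_kv_shape,
    int64_t total_sequence_length) {
  const int64_t past_sequence_length = past_kv_shape[2];
  // Mirrors the CPU EP rule: present = max(past, total).
  const int64_t present_sequence_length =
      std::max(past_sequence_length, total_sequence_length);
  // batch_size, kv_num_heads, and head_size carry over from past_key/value.
  return {past_kv_shape[0], past_kv_shape[1],
          present_sequence_length, past_kv_shape[3]};
}
```

For example, with past_kv_shape = (1, 8, 128, 64) and total_sequence_length = 160, the inferred present shape is (1, 8, 160, 64); and whenever total_sequence_length <= past_sequence_length, the present shape equals the past shape, which is the case WebNN supports.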


@tianleiwu (Contributor)

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline,ONNX Runtime Web CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline

@tianleiwu (Contributor)

/azp run Linux QNN CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline,Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline,Linux DNNL CI Pipeline,Linux MIGraphX CI Pipeline,Linux ROCm CI Pipeline


Azure Pipelines successfully started running 8 pipeline(s).


Azure Pipelines successfully started running 10 pipeline(s).

@tianleiwu tianleiwu merged commit c756e0a into microsoft:main Mar 28, 2025
95 of 101 checks passed
quic-zhaoxul pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Apr 17, 2025