-
Notifications
You must be signed in to change notification settings - Fork 3.2k
[WebGPU] Optimize GEMM with vec4 #24478
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WebGPU] Optimize GEMM with vec4 #24478
Conversation
Is there a way to reuse the implementation of |
Thanks for the callout. That's what I'm gonna do next. |
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline |
Azure Pipelines successfully started running 5 pipeline(s). |
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline |
Azure Pipelines successfully started running 5 pipeline(s). |
…ize_vec4
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks.
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline,ONNX Runtime Web CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline |
/azp run Linux QNN CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline,Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline,Linux DNNL CI Pipeline,Linux MIGraphX CI Pipeline,Linux ROCm CI Pipeline |
Azure Pipelines successfully started running 5 pipeline(s). |
Azure Pipelines successfully started running 6 pipeline(s), but failed to run 1 pipeline(s). |
Description
In this PR, we use vec4 to optimize GEMM when colums of A and B can be divided by 4, or use previous shader.
I will add u32/vec2 implementation in the future, and we will only keep one shader at that time.
Perf comparison
I run customized model only include GEMM(M = N = K = 1024) with nodejs on M2/M3 Max. Roughly 20% increase.