Skip to content

[js/web] Add Wasm Relaxed SIMD support to wasm backend #22794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -193,6 +193,7 @@ option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on
option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD "Enable WebAssembly Relaxed SIMD" OFF)

# Enable bitcode for iOS
option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF)
5 changes: 4 additions & 1 deletion cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
@@ -35,7 +35,10 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
set(CMAKE_CXX_FLAGS_DEBUG "-g2")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
if (onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD)
string(APPEND CMAKE_C_FLAGS " -msimd128 -mrelaxed-simd")
string(APPEND CMAKE_CXX_FLAGS " -msimd128 -mrelaxed-simd")
elseif (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
string(APPEND CMAKE_C_FLAGS " -msimd128")
string(APPEND CMAKE_CXX_FLAGS " -msimd128")
endif()
7 changes: 6 additions & 1 deletion cmake/external/xnnpack.cmake
Original file line number Diff line number Diff line change
@@ -143,7 +143,12 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/scalar.c)
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasm.c)

if(onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
if(onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD)
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c)
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmrelaxedsimd.c)
target_compile_options(XNNPACK PRIVATE "-msimd128")
target_compile_options(XNNPACK PRIVATE "-mrelaxed-simd")
elseif(onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c)
target_compile_options(XNNPACK PRIVATE "-msimd128")
endif()
6 changes: 6 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
@@ -287,6 +287,12 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
${mlas_platform_srcs}
${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp
)
if (onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD)
set(mlas_platform_srcs
${mlas_platform_srcs}
${MLAS_SRC_DIR}/qgemm_kernel_wasmrelaxedsimd.cpp
)
endif()
else()
file(GLOB_RECURSE mlas_platform_srcs
"${MLAS_SRC_DIR}/scalar/*.cpp"
5 changes: 4 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
@@ -222,8 +222,11 @@ function(AddTest)
else()
set(TEST_NODE_FLAGS)

if (onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD)
message(WARNING "Use system `node` to test Wasm relaxed SIMD. Please make sure to install node v21 or newer.")
set(NODE_EXECUTABLE node)
# prefer Node from emsdk so the version is more deterministic
if (DEFINED ENV{EMSDK_NODE})
elseif (DEFINED ENV{EMSDK_NODE})
set(NODE_EXECUTABLE $ENV{EMSDK_NODE})
else()
message(WARNING "EMSDK_NODE environment variable was not set. Falling back to system `node`.")
4 changes: 3 additions & 1 deletion cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
@@ -485,7 +485,9 @@ jsepDownload:_pp_")

list(APPEND target_name_list "wasm")

if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
if (onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD)
list(APPEND target_name_list "relaxedsimd")
elseif (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
list(APPEND target_name_list "simd")
endif()

5 changes: 4 additions & 1 deletion onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
@@ -63,7 +63,10 @@ Module Name:
#endif
#if defined(__wasm__)
#define MLAS_TARGET_WASM
#if defined(__wasm_simd128__)
#if defined(__wasm_relaxed_simd__)
#define MLAS_TARGET_WASM_RELAXED_SIMD
#define MLAS_TARGET_WASM_SIMD
#elif defined(__wasm_simd128__)
#define MLAS_TARGET_WASM_SIMD
#else
#define MLAS_TARGET_WASM_SCALAR
5 changes: 5 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
@@ -996,9 +996,14 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10;

#if defined(MLAS_TARGET_WASM_RELAXED_SIMD)
extern bool HasUSDot();
#endif

//
// Symmetric quantized qgemm dispatch structure
//
8 changes: 8 additions & 0 deletions onnxruntime/core/mlas/lib/qgemm.h
Original file line number Diff line number Diff line change
@@ -886,6 +886,14 @@ MlasGemmQuantGetDispatch(
if(BIsSigned || !AIsSigned) {
GemmQuantDispatch = &MlasGemmU8X8DispatchNeon;
}
#elif defined(MLAS_TARGET_WASM_RELAXED_SIMD)
if (!AIsSigned) {
if (HasUSDot()) {
GemmQuantDispatch = &MlasGemmU8X8DispatchWasmRelaxedSimd;
} else {
GemmQuantDispatch = &MlasGemmU8X8DispatchWasmSimd;
}
}
#elif defined(MLAS_TARGET_WASM_SIMD)
if (!AIsSigned) {
GemmQuantDispatch = &MlasGemmU8X8DispatchWasmSimd;
Loading