Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e03631e

Browse files
authoredMar 24, 2025··
[webgpu] add option to perserve device and enable in unittest (#24115)
### Description This PR introduced a new WebGPU EP option `preserveDevice`. Before this change, a WebGPU device will be destroyed when no inference session uses it. The destroy of a WebGPU device will cleanup both buffer cache and shader cache. After this option is introduced, when the option is ON (default value is OFF), the device will no longer be destroyed and will be always keep alive. This is helpful in 2 scenarios: - A server that will be always on - unittest so that bugs of incorrect shader cache may be detected. (thanks to @qjia7 for the suggestion)
1 parent 5244d68 commit e03631e

File tree

7 files changed

+37
-6
lines changed

7 files changed

+37
-6
lines changed
 

‎onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con
100100

101101
program
102102
.CacheHint(interleaved_)
103-
.AddInputs({{input, ProgramTensorMetadataDependency::Rank},
103+
.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
104104
{position_ids, ProgramTensorMetadataDependency::Rank},
105105
{cos_cache, ProgramTensorMetadataDependency::Rank},
106106
{sin_cache, ProgramTensorMetadataDependency::Rank}})

‎onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,13 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
247247
program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank});
248248
}
249249

250-
program.CacheHint(is_input_empty)
250+
// TODO: the ReduceKernel class is designed to use `keepdims_`, `noop_with_empty_axes_` and input axes as uniform variables,
251+
// but the current implementation does not work without them in cache key.
252+
// This is a temporary workaround to make it work. We should fix this in the future.
253+
program.CacheHint(keepdims_,
254+
noop_with_empty_axes_,
255+
select_last_index_,
256+
absl::StrJoin(input_axes, ","))
251257
.AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank})
252258
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
253259
.AddUniformVariables({{static_cast<uint32_t>(output_size)},

‎onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
768768
auto it = contexts_.find(context_id);
769769
if (it == contexts_.end()) {
770770
GSL_SUPPRESS(r.11)
771-
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode));
771+
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode, config.preserve_device));
772772
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
773773
} else if (context_id != 0) {
774774
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
@@ -794,7 +794,7 @@ void WebGpuContextFactory::ReleaseContext(int context_id) {
794794
auto it = contexts_.find(context_id);
795795
ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found.");
796796

797-
if (--it->second.ref_count == 0) {
797+
if (--it->second.ref_count == 0 && !it->second.context->preserve_device_) {
798798
contexts_.erase(it);
799799
}
800800
}

‎onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ struct WebGpuContextConfig {
3232
WGPUDevice device;
3333
const void* dawn_proc_table;
3434
ValidationMode validation_mode;
35+
bool preserve_device;
3536
};
3637

3738
struct WebGpuBufferCacheConfig {
@@ -152,8 +153,8 @@ class WebGpuContext final {
152153
AtPasses
153154
};
154155

155-
WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode)
156-
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
156+
WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode, bool preserve_device)
157+
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None}, preserve_device_{preserve_device} {}
157158
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);
158159

159160
std::vector<const char*> GetEnabledAdapterToggles() const;
@@ -229,6 +230,7 @@ class WebGpuContext final {
229230

230231
uint64_t gpu_timestamp_offset_ = 0;
231232
bool is_profiling_ = false;
233+
bool preserve_device_;
232234

233235
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
234236
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;

‎onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,19 +143,33 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
143143
}
144144
}
145145

146+
std::string preserve_device_str;
147+
bool preserve_device = false;
148+
if (config_options.TryGetConfigEntry(kPreserveDevice, preserve_device_str)) {
149+
if (preserve_device_str == kPreserveDevice_ON) {
150+
preserve_device = true;
151+
} else if (preserve_device_str == kPreserveDevice_OFF) {
152+
preserve_device = false;
153+
} else {
154+
ORT_THROW("Invalid preserve device: ", preserve_device_str);
155+
}
156+
}
157+
146158
webgpu::WebGpuContextConfig context_config{
147159
context_id,
148160
reinterpret_cast<WGPUInstance>(webgpu_instance),
149161
reinterpret_cast<WGPUDevice>(webgpu_device),
150162
reinterpret_cast<const void*>(dawn_proc_table),
151163
validation_mode,
164+
preserve_device,
152165
};
153166

154167
LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << context_id;
155168
LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << webgpu_instance;
156169
LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << webgpu_device;
157170
LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << dawn_proc_table;
158171
LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << validation_mode;
172+
LOGS_DEFAULT(VERBOSE) << "WebGPU EP PreserveDevice: " << preserve_device;
159173

160174
//
161175
// STEP.3 - prepare parameters for WebGPU context initialization.

‎onnxruntime/core/providers/webgpu/webgpu_provider_options.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ constexpr const char* kValidationMode = "WebGPU:validationMode";
3030
constexpr const char* kForceCpuNodeNames = "WebGPU:forceCpuNodeNames";
3131
constexpr const char* kEnablePIXCapture = "WebGPU:enablePIXCapture";
3232

33+
constexpr const char* kPreserveDevice = "WebGPU:preserveDevice";
34+
3335
// The following are the possible values for the provider options.
3436

3537
constexpr const char* kDawnBackendType_D3D12 = "D3D12";
@@ -44,6 +46,9 @@ constexpr const char* kEnableGraphCapture_OFF = "0";
4446
constexpr const char* kEnablePIXCapture_ON = "1";
4547
constexpr const char* kEnablePIXCapture_OFF = "0";
4648

49+
constexpr const char* kPreserveDevice_ON = "1";
50+
constexpr const char* kPreserveDevice_OFF = "0";
51+
4752
constexpr const char* kBufferCacheMode_Disabled = "disabled";
4853
constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
4954
constexpr const char* kBufferCacheMode_Simple = "simple";

‎onnxruntime/test/util/default_providers.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider() {
303303
ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode,
304304
webgpu::options::kBufferCacheMode_Disabled)
305305
.IsOK());
306+
// Disable device auto collect
307+
ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kPreserveDevice,
308+
webgpu::options::kPreserveDevice_ON)
309+
.IsOK());
306310
return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider();
307311
#else
308312
return nullptr;

0 commit comments

Comments
 (0)
Please sign in to comment.