Allow EpContext models with input/output models completely in buffers #24463
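This change allows ONNX Runtime to generate an EPContext (compiled) model when both the input model and the output model live entirely in memory buffers, so no file paths are involved at any point. As a hedged illustration only (not part of this PR's diff), here is a minimal sketch of the newly supported flow using the OrtCompileApi C++ wrappers that the tests below exercise; the surrounding function, `env`, and `model_data` are placeholders, and headers are omitted:

// Sketch: compile an in-memory ONNX model directly to an in-memory EPContext model.
// `model_data` holds a serialized ONNX model; no ep.context_file_path is required.
void CompileFullyInBuffers(Ort::Env& env, const std::string& model_data) {
  Ort::SessionOptions session_options;  // append an EPContext-capable EP (e.g., QNN) here

  Ort::AllocatorWithDefaultOptions allocator;
  void* output_model_buffer = nullptr;
  size_t output_model_buffer_size = 0;

  Ort::ModelCompilationOptions compile_options(env, session_options);
  compile_options.SetInputModelFromBuffer(model_data.data(), model_data.size());
  compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);

  Ort::Status status = Ort::CompileModel(env, compile_options);
  if (status.IsOK()) {
    // The compiled EPContext model is now in output_model_buffer; free it with the same allocator.
    allocator.Free(output_model_buffer);
  }
}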

Merged · 5 commits · Apr 18, 2025
18 changes: 11 additions & 7 deletions onnxruntime/core/framework/graph_partitioner.cc
@@ -800,11 +800,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
     return std::make_pair(false, static_cast<const Node*>(nullptr));
   };
 
+  bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr &&
+                          ep_context_gen_options.output_model_buffer_size_ptr != nullptr &&
+                          ep_context_gen_options.output_model_buffer_allocator != nullptr;
+
   std::filesystem::path context_cache_path;
-  ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path,
-                                                graph.ModelPath(),
-                                                context_cache_path,
-                                                ep_context_gen_options.overwrite_existing_output_file));
+  if (!saving_to_buffer || !graph.ModelPath().empty()) {
+    ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path,
+                                                  graph.ModelPath(),
+                                                  context_cache_path,
+                                                  ep_context_gen_options.overwrite_existing_output_file));
+  }
 
   Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(),
                          graph.GetModel().ModelPath(),  // use source model path so that external initializers can find the data file path
@@ -864,9 +870,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers

   ModelSavingOptions model_saving_options{ini_size_threshold};
 
-  if (ep_context_gen_options.output_model_buffer_ptr != nullptr &&
-      ep_context_gen_options.output_model_buffer_size_ptr != nullptr &&
-      ep_context_gen_options.output_model_buffer_allocator != nullptr) {
+  if (saving_to_buffer) {
     ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve());
     // TODO(adrianlizarraga): Investigate if we can make this more memory efficient.
     // May be able to use allocator to directly allocate the ModelProto to avoid a copy.
@@ -247,6 +247,13 @@ Status CreateEPContextNodes(Model* model,
   } else {
     context_bin_path = context_model_path;
   }
+
+  if (context_bin_path.empty()) {
+    // Context bin path is empty, so just use the graph name (e.g., "QNNExecutionProvider_QNN_13728744673520368385_2_0").
+    // This happens if both the input model and output model are stored in buffers (i.e., there are no paths).
+    context_bin_path = ToPathString(graph_name);
+  }
+
   context_bin_path = context_bin_path + ToPathString("_qnn.bin");
   context_cache_name = std::filesystem::path(context_bin_path).filename().string();

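For illustration only (not code from this PR): a tiny standalone sketch of how the fallback above composes the context-binary name when neither model has a path, using the example graph name from the comment:

#include <filesystem>
#include <string>

int main() {
  // With no input/output model paths, the graph name seeds the context bin name.
  std::string graph_name = "QNNExecutionProvider_QNN_13728744673520368385_2_0";
  std::filesystem::path context_bin_path = graph_name;
  context_bin_path += "_qnn.bin";  // same suffix appended in the diff above
  // filename() == "QNNExecutionProvider_QNN_13728744673520368385_2_0_qnn.bin"
  std::string context_cache_name = context_bin_path.filename().string();
  return 0;
}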
17 changes: 11 additions & 6 deletions onnxruntime/core/session/utils.cc
@@ -52,13 +52,18 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options,

   // If ep.context_enable is set, then ep.context_file_path is expected, otherwise ORT don't know where to generate the _ctx.onnx file
   if (options && model_path == nullptr) {
-    auto ep_context_enable = options->value.config_options.GetConfigEntry(kOrtSessionOptionEpContextEnable);
-    auto ep_context_file_path = options->value.config_options.GetConfigEntry(kOrtSessionOptionEpContextFilePath);
-    if (ep_context_enable.has_value() && ep_context_enable.value() == "1" && (!ep_context_file_path.has_value() || (ep_context_file_path.has_value() && ep_context_file_path.value().empty()))) {
+    EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions();
+
+    // This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case
+    // the user used the older SessionOptions' configuration entries to generate a compiled model.
+    if (ep_ctx_gen_options.enable &&
+        ep_ctx_gen_options.output_model_file_path.empty() &&
+        ep_ctx_gen_options.output_model_buffer_ptr == nullptr) {
       return OrtApis::CreateStatus(ORT_FAIL,
-                                   "CreateSessionFromArray is called with ep.context_enable enabled but an \
-empty ep.context_file_path. The system does not know where to generate the \
-EP context model. Please specify a valid ep.context_file_path.");
+                                   "Inference session was configured with EPContext model generation enabled but "
+                                   "without a valid location (e.g., file or buffer) for the output model. "
+                                   "Please specify a valid ep.context_file_path via SessionOption configs "
+                                   "or use the OrtCompileApi to compile a model to a file or buffer.");
     }
   }

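For context, a hedged sketch (not code from this PR) of the older SessionOptions-driven configuration that this check guards. The string keys below are the public values of the kOrtSessionOptionEpContextEnable / kOrtSessionOptionEpContextFilePath constants referenced in the diff; the wrapper function is a placeholder:

// Legacy flow: EPContext generation enabled via SessionOptions config entries.
// When the model is loaded from a buffer (CreateSessionFromArray), ORT cannot derive
// an output location, so ep.context_file_path must be set explicitly.
Ort::SessionOptions MakeLegacyEpContextOptions() {
  Ort::SessionOptions session_options;
  session_options.AddConfigEntry("ep.context_enable", "1");
  session_options.AddConfigEntry("ep.context_file_path", "model_ctx.onnx");
  return session_options;
}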
74 changes: 72 additions & 2 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -380,6 +380,75 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) {

   // Check that the compiled model has the expected number of EPContext nodes.
   CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
+  allocator.Free(output_model_buffer);
 }
 
+// Test using the CompileModel() API with settings:
+//   - input model from buffer
+//   - save output model to buffer
+//   - test enabling AND disabling embed mode for context binary in EPContext node attributes
+TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInBuffers) {
+  // Create a test model and serialize it to a buffer.
+  TestModel test_model;
+  CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model);
+  std::string model_data = test_model.Serialize();
+
+  // Initialize session options with QNN EP
+  Ort::SessionOptions session_options;
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+  session_options.AppendExecutionProvider("QNN", provider_options);
+
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  // Test embed mode enabled.
+  {
+    void* output_model_buffer = nullptr;
+    size_t output_model_buffer_size = 0;
+
+    // Create model compilation options from the session options.
+    Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+    compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
+    compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
+    compile_options.SetEpContextEmbedMode(true);
+
+    // Compile the model.
+    Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+    ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+
+    // Make sure the compiled model was saved to the buffer.
+    ASSERT_TRUE(output_model_buffer != nullptr);
+    ASSERT_TRUE(output_model_buffer_size > 0);
+
+    // Check that the compiled model has the expected number of EPContext nodes.
+    CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
+    allocator.Free(output_model_buffer);
+  }
+
+  // Test embed mode disabled.
+  {
+    void* output_model_buffer = nullptr;
+    size_t output_model_buffer_size = 0;
+
+    // Create model compilation options from the session options.
+    Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+    compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
+    compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
+    compile_options.SetEpContextEmbedMode(false);
+
+    // Compile the model.
+    Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+    ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+
+    // Make sure the compiled model was saved to the buffer.
+    ASSERT_TRUE(output_model_buffer != nullptr);
+    ASSERT_TRUE(output_model_buffer_size > 0);
+
+    // Check that the compiled model has the expected number of EPContext nodes.
+    CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
+    allocator.Free(output_model_buffer);
+  }
+}
+
 // Test using the CompileModel() API with settings:
@@ -429,6 +498,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu

   // Check that the compiled model has the expected number of EPContext nodes.
   CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
+  allocator.Free(output_model_buffer);
 }
 
 // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary
@@ -1510,7 +1580,7 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) {
   ORT_CATCH(const std::exception& e) {
     ORT_HANDLE_EXCEPTION([&e]() {
       std::string e_message1(std::string(e.what()));
-      ASSERT_TRUE(e_message1.find("Please specify a valid ep.context_file_path.") != std::string::npos);
+      ASSERT_TRUE(e_message1.find("Please specify a valid ep.context_file_path") != std::string::npos);
     });
   }

@@ -1521,7 +1591,7 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) {
   ORT_CATCH(const std::exception& ex) {
     ORT_HANDLE_EXCEPTION([&ex]() {
       std::string e_message2(std::string(ex.what()));
-      ASSERT_TRUE(e_message2.find("Please specify a valid ep.context_file_path.") != std::string::npos);
+      ASSERT_TRUE(e_message2.find("Please specify a valid ep.context_file_path") != std::string::npos);
     });
   }
 }