Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Allow EpContext models with input/output models completely in buffers
  • Loading branch information
adrianlizarraga committed Apr 18, 2025
commit 984d24b0d8618678ab0c9a48057e0b58cbcd00d0
18 changes: 11 additions & 7 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -800,11 +800,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
return std::make_pair(false, static_cast<const Node*>(nullptr));
};

bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr &&
ep_context_gen_options.output_model_buffer_size_ptr != nullptr &&
ep_context_gen_options.output_model_buffer_allocator != nullptr;

std::filesystem::path context_cache_path;
ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path,
graph.ModelPath(),
context_cache_path,
ep_context_gen_options.overwrite_existing_output_file));
if (!saving_to_buffer || !graph.ModelPath().empty()) {
ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path,
graph.ModelPath(),
context_cache_path,
ep_context_gen_options.overwrite_existing_output_file));
}

Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(),
graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path
Expand Down Expand Up @@ -864,9 +870,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers

ModelSavingOptions model_saving_options{ini_size_threshold};

if (ep_context_gen_options.output_model_buffer_ptr != nullptr &&
ep_context_gen_options.output_model_buffer_size_ptr != nullptr &&
ep_context_gen_options.output_model_buffer_allocator != nullptr) {
if (saving_to_buffer) {
ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve());
// TODO(adrianlizarraga): Investigate if we can make this more memory efficient.
// May be able to use allocator to directly allocate the ModelProto to avoid a copy.
Expand Down
30 changes: 14 additions & 16 deletions onnxruntime/core/session/model_compilation_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ Status ModelCompilationOptions::ResetOutputModelSettings() {
return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "");
}

Status ModelCompilationOptions::CheckInputModelSettings() const {
Status ModelCompilationOptions::Check() const {
ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable);
const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
const bool explicit_writes_to_file = !ep_context_gen_options.output_model_file_path.empty();
const bool writes_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr;
const bool comes_from_file = !input_model_path_.empty();
const bool comes_from_memory = input_model_data_ != nullptr;

Expand All @@ -160,14 +164,15 @@ Status ModelCompilationOptions::CheckInputModelSettings() const {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0");
}

return Status::OK();
}

Status ModelCompilationOptions::CheckOutputModelSettings() const {
const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;

const bool explicit_writes_to_file = !ep_context_gen_options.output_model_file_path.empty();
const bool writes_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr;
if (comes_from_memory && writes_to_buffer && !ep_context_gen_options.embed_ep_context_in_model) {
// TODO(adrianlizarraga): We may want to support this in the future. That is, both input/output models
// are in buffers but the context cache binary is dumped to a file. Would need to allow user to specify
// a custom path for the context cache binary.
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"EPContext embed mode must be true (enabled) when both the "
"input and output models are stored in buffers. "
"Please call ModelCompilationOptions_SetEpContextEmbedMode(true).");
}

if (!explicit_writes_to_file && !writes_to_buffer) {
// User did not specify an output file or an output buffer. We default to generating an output file
Expand All @@ -192,12 +197,5 @@ Status ModelCompilationOptions::CheckOutputModelSettings() const {

return Status::OK();
}

Status ModelCompilationOptions::Check() const {
ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable);
ORT_RETURN_IF_ERROR(CheckInputModelSettings());
ORT_RETURN_IF_ERROR(CheckOutputModelSettings());
return Status::OK();
}
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)
2 changes: 0 additions & 2 deletions onnxruntime/core/session/model_compilation_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ class ModelCompilationOptions {
private:
void ResetInputModelSettings();
Status ResetOutputModelSettings();
Status CheckInputModelSettings() const;
Status CheckOutputModelSettings() const;

const OrtEnv& env_;
OrtSessionOptions session_options_;
Expand Down
17 changes: 11 additions & 6 deletions onnxruntime/core/session/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,18 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options,

// If ep.context_enable is set, then ep.context_file_path is expected, otherwise ORT don't know where to generate the _ctx.onnx file
if (options && model_path == nullptr) {
auto ep_context_enable = options->value.config_options.GetConfigEntry(kOrtSessionOptionEpContextEnable);
auto ep_context_file_path = options->value.config_options.GetConfigEntry(kOrtSessionOptionEpContextFilePath);
if (ep_context_enable.has_value() && ep_context_enable.value() == "1" && (!ep_context_file_path.has_value() || (ep_context_file_path.has_value() && ep_context_file_path.value().empty()))) {
EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions();

// This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case
// the user used the older SessionOptions' configuration entries to generate a compiled model.
if (ep_ctx_gen_options.enable &&
ep_ctx_gen_options.output_model_file_path.empty() &&
ep_ctx_gen_options.output_model_buffer_ptr == nullptr) {
return OrtApis::CreateStatus(ORT_FAIL,
"CreateSessionFromArray is called with ep.context_enable enabled but an \
empty ep.context_file_path. The system does not know where to generate the \
EP context model. Please specify a valid ep.context_file_path.");
"Inference session was configured with EPContext model generation enabled but "
"without a valid location (e.g., file or buffer) for the output model. "
"Please specify a valid ep.context_file_path via SessionOption configs "
"or use the OrtCompileApi to compile a model to a file or buffer.");
}
}

Expand Down
42 changes: 42 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,48 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) {
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
}

// Test using the CompileModel() API with settings:
// - input model from buffer
// - save output model to buffer
// - EPContext nodes in output model use embedded binary blobs.
// No file paths are involved anywhere: both the input and the output model live
// entirely in memory, which requires EPContext embed mode to be enabled.
TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInBuffers_Embedded) {
  // Create a test model (2 partitions: QNN-supported + unsupported ops) and
  // serialize it to an in-memory buffer.
  TestModel test_model;
  CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model);
  std::string model_data = test_model.Serialize();

  // Initialize session options with QNN EP on the HTP backend.
  Ort::SessionOptions so;
  ProviderOptions provider_options;
  provider_options["backend_type"] = "htp";
  provider_options["offload_graph_io_quantization"] = "0";
  so.AppendExecutionProvider("QNN", provider_options);

  // Create model compilation options from the session options.
  // The input model comes from the serialized buffer (no input file path).
  Ort::ModelCompilationOptions compile_options(*ort_env, so);
  compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());

  // Direct the compiled output model into a caller-provided buffer. Embed mode
  // must be true here: with both models in buffers there is no location to dump
  // a separate context-binary file.
  Ort::AllocatorWithDefaultOptions allocator;
  void* output_model_buffer = nullptr;
  size_t output_model_buffer_size = 0;
  compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
  compile_options.SetEpContextEmbedMode(true);

  // Compile the model.
  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
  ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();

  // Make sure the compiled model was saved to the buffer.
  ASSERT_TRUE(output_model_buffer != nullptr);
  ASSERT_TRUE(output_model_buffer_size > 0);

  // Check that the compiled model has the expected number of EPContext nodes
  // (2 EPContext nodes, 2 non-EPContext nodes).
  CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
}

// Test using the CompileModel() API with settings:
// - input model from file
// - save output model to a buffer
Expand Down
Loading