take EP context options from session
gedoensmax committed May 1, 2025
commit 72772ea8030cc6ed1818d2e9ea4e7972ecf7fd91
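For context, a minimal usage sketch of what this commit enables: EP context settings for the NV TensorRT RTX EP are now read from the standard session-option config keys rather than from EP-specific provider options (mirroring the updated tests in nv_basic_test.cc below). This is an illustrative sketch only; the model paths "model.onnx" and "model_ctx.onnx" are placeholders.

#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // EP context options come from the session configuration, not from
  // "NvTensorRtRtx" provider options.
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");                 // "0" or "1"
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "model_ctx.onnx");  // placeholder output path
  so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "1");              // 0 or 1
  so.AppendExecutionProvider("NvTensorRtRtx", {});
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);  // placeholder model path
  return 0;
}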
73 changes: 34 additions & 39 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -1068,45 +1068,40 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
}
};

// Get environment variables
if (info.has_trt_options) {
max_partition_iterations_ = info.max_partition_iterations;
min_subgraph_size_ = info.min_subgraph_size;
max_workspace_size_ = info.max_workspace_size;
dump_subgraphs_ = info.dump_subgraphs;
weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
onnx_model_folder_path_ = info.onnx_model_folder_path;
onnx_model_bytestream_ = info.onnx_bytestream;
onnx_model_bytestream_size_ = info.onnx_bytestream_size;
if ((onnx_model_bytestream_ != nullptr && onnx_model_bytestream_size_ == 0) ||
(onnx_model_bytestream_ == nullptr && onnx_model_bytestream_size_ != 0)) {
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"When providing either 'trt_onnx_bytestream_size' or "
"'trt_onnx_bytestream' both have to be provided"));
}
detailed_build_log_ = info.detailed_build_log;
dump_ep_context_model_ = info.dump_ep_context_model;
ep_context_file_path_ = info.ep_context_file_path;
ep_context_embed_mode_ = info.ep_context_embed_mode;
enable_engine_cache_for_ep_context_model();
cache_prefix_ = info.engine_cache_prefix;
// use a more global cache if given
engine_decryption_enable_ = info.engine_decryption_enable;
if (engine_decryption_enable_) {
engine_decryption_lib_path_ = info.engine_decryption_lib_path;
}
force_sequential_engine_build_ = info.force_sequential_engine_build;
context_memory_sharing_enable_ = info.context_memory_sharing_enable;
sparsity_enable_ = info.sparsity_enable;
auxiliary_streams_ = info.auxiliary_streams;
profile_min_shapes = info.profile_min_shapes;
profile_max_shapes = info.profile_max_shapes;
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
op_types_to_exclude_ = info.op_types_to_exclude;
} else {
LOGS_DEFAULT(INFO) << "[Nv EP] Options were not specified";
}
max_partition_iterations_ = info.max_partition_iterations;
min_subgraph_size_ = info.min_subgraph_size;
max_workspace_size_ = info.max_workspace_size;
dump_subgraphs_ = info.dump_subgraphs;
weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
onnx_model_folder_path_ = info.onnx_model_folder_path;
onnx_model_bytestream_ = info.onnx_bytestream;
onnx_model_bytestream_size_ = info.onnx_bytestream_size;
if ((onnx_model_bytestream_ != nullptr && onnx_model_bytestream_size_ == 0) ||
(onnx_model_bytestream_ == nullptr && onnx_model_bytestream_size_ != 0)) {
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"When providing either 'trt_onnx_bytestream_size' or "
"'trt_onnx_bytestream' both have to be provided"));
}
detailed_build_log_ = info.detailed_build_log;
dump_ep_context_model_ = info.dump_ep_context_model;
ep_context_file_path_ = info.ep_context_file_path;
ep_context_embed_mode_ = info.ep_context_embed_mode;
enable_engine_cache_for_ep_context_model();
cache_prefix_ = info.engine_cache_prefix;
// use a more global cache if given
engine_decryption_enable_ = info.engine_decryption_enable;
if (engine_decryption_enable_) {
engine_decryption_lib_path_ = info.engine_decryption_lib_path;
}
force_sequential_engine_build_ = info.force_sequential_engine_build;
context_memory_sharing_enable_ = info.context_memory_sharing_enable;
sparsity_enable_ = info.sparsity_enable;
auxiliary_streams_ = info.auxiliary_streams;
profile_min_shapes = info.profile_min_shapes;
profile_max_shapes = info.profile_max_shapes;
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
op_types_to_exclude_ = info.op_types_to_exclude;

// Validate setting
if (max_partition_iterations_ <= 0) {
@@ -4,13 +4,15 @@
#include "core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h"
#include "core/providers/nv_tensorrt_rtx/nv_provider_options.h"

#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/common/make_string.h"
#include "core/common/parse_string.h"
#include "core/framework/provider_options_utils.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options,
const ConfigOptions& session_options) {
NvExecutionProviderInfo info{};
void* user_compute_stream = nullptr;
void* onnx_bytestream = nullptr;
@@ -58,6 +60,25 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi
info.user_compute_stream = user_compute_stream;
info.has_user_compute_stream = (user_compute_stream != nullptr);
info.onnx_bytestream = onnx_bytestream;

// EP context settings
const auto embed_enable = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0");
if (embed_enable == "0") {
info.dump_ep_context_model = false;
} else if (embed_enable == "1") {
info.dump_ep_context_model = true;
} else {
ORT_THROW("Invalid ", kOrtSessionOptionEpContextEnable, " must 0 or 1");
}
info.ep_context_file_path = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");

const auto embed_mode = std::stoi(session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"));
if (0 <= embed_mode || embed_mode < 2) {
info.ep_context_embed_mode = embed_mode;
} else {
ORT_THROW("Invalid ", kOrtSessionOptionEpContextEmbedMode, " must 0 or 1");
}

return info;
}

@@ -8,8 +8,9 @@
#include "core/framework/ortdevice.h"
#include "core/framework/provider_options.h"
#include "core/framework/framework_provider_common.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/library_handles.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/providers/shared_library/provider_api.h"

#define TRT_DEFAULT_OPTIMIZER_LEVEL 3

@@ -62,7 +63,8 @@ struct NvExecutionProviderInfo {
bool engine_hw_compatible{false};
std::string op_types_to_exclude{""};

static NvExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static NvExecutionProviderInfo FromProviderOptions(const ProviderOptions& options,
const ConfigOptions& session_options);
static ProviderOptions ToProviderOptions(const NvExecutionProviderInfo& info);
std::vector<OrtCustomOpDomain*> custom_op_domain_list;
};
21 changes: 17 additions & 4 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
@@ -79,7 +79,7 @@ std::unique_ptr<IExecutionProvider> NvProviderFactory::CreateProvider(const OrtS
provider_options[key.substr(key_prefix.size())] = value;
}
}
NvExecutionProviderInfo info = onnxruntime::NvExecutionProviderInfo::FromProviderOptions(provider_options);
NvExecutionProviderInfo info = onnxruntime::NvExecutionProviderInfo::FromProviderOptions(provider_options, config_options);

auto ep = std::make_unique<NvExecutionProvider>(info);
ep->SetLogger(reinterpret_cast<const logging::Logger*>(&session_logger));
@@ -96,9 +96,22 @@ struct Nv_Provider : Provider {
return std::make_shared<NvProviderFactory>(info);
}

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* options) {
const ProviderOptions* provider_options = reinterpret_cast<const ProviderOptions*>(options);
NvExecutionProviderInfo info = onnxruntime::NvExecutionProviderInfo::FromProviderOptions(*provider_options);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* param) {
if (param == nullptr) {
LOGS_DEFAULT(ERROR) << "[NV EP] Passed NULL options to CreateExecutionProviderFactory()";
return nullptr;
}

std::array<const void*, 2> pointers_array = *reinterpret_cast<const std::array<const void*, 2>*>(param);
const ProviderOptions* provider_options = reinterpret_cast<const ProviderOptions*>(pointers_array[0]);
const ConfigOptions* config_options = reinterpret_cast<const ConfigOptions*>(pointers_array[1]);

if (provider_options == nullptr) {
LOGS_DEFAULT(ERROR) << "[NV EP] Passed NULL ProviderOptions to CreateExecutionProviderFactory()";
return nullptr;
}

NvExecutionProviderInfo info = onnxruntime::NvExecutionProviderInfo::FromProviderOptions(*provider_options, *config_options);
return std::make_shared<NvProviderFactory>(info);
}

@@ -12,6 +12,7 @@ namespace onnxruntime {
// defined in provider_bridge_ort.cc
struct NvProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& provider_options);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& provider_options_map,
const SessionOptions* session_options);
};
} // namespace onnxruntime
11 changes: 9 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
@@ -2049,8 +2049,15 @@ std::shared_ptr<IExecutionProviderFactory> NvProviderFactoryCreator::Create(int
}

std::shared_ptr<IExecutionProviderFactory> NvProviderFactoryCreator::Create(
const ProviderOptions& provider_options) try {
return s_library_nv.Get().CreateExecutionProviderFactory(&provider_options);
const ProviderOptions& provider_options,, const SessionOptions* session_options) try {
const ConfigOptions* config_options = nullptr;
if (session_options != nullptr) {
config_options = &session_options->config_options;
}

std::array<const void*, 2> configs_array = {&provider_options, config_options};
const void* arg = reinterpret_cast<const void*>(&configs_array);
return s_library_nv.Get().CreateExecutionProviderFactory(arg);
} catch (const std::exception& exception) {
// Will get an exception when fail to load EP library.
LOGS_DEFAULT(ERROR) << exception.what();
2 changes: 1 addition & 1 deletion onnxruntime/core/session/provider_registration.cc
@@ -301,7 +301,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
}
case EpID::NvTensorRtRtx: {
#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE)
auto factory = onnxruntime::NvProviderFactoryCreator::Create(provider_options);
auto factory = onnxruntime::NvProviderFactoryCreator::Create(provider_options, &(options->value));
if (factory) {
options->provider_factories.push_back(factory);
} else {
51 changes: 31 additions & 20 deletions onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
@@ -9,6 +9,8 @@
#include "test/common/trt_op_test_utils.h"

#include <onnxruntime_cxx_api.h>
#include <onnxruntime_run_options_config_keys.h>
#include <onnxruntime_session_options_config_keys.h>
#include <string>
#include <thread>
#include <filesystem>
@@ -27,6 +29,12 @@ std::string WideToUTF8(const std::wstring& wstr) {
return converter.to_bytes(wstr);
}

void clearFileIfExists(const PathString& path) {
if (std::filesystem::exists(path)) {
std::filesystem::remove(path);
}
}

template <typename T>
void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64_t>& expected_dims,
const std::vector<T>& expected_values) {
@@ -74,7 +82,7 @@ void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64
* /
* "O"
*/
void CreateBaseModel(const PathString& model_name,
static void CreateBaseModel(const PathString& model_name,
std::string graph_name,
std::vector<int> dims,
bool add_fast_gelu = false) {
@@ -143,7 +151,7 @@ void CreateBaseModel(const PathString& model_name,
status = onnxruntime::Model::Save(model, model_name);
}

Ort::IoBinding generate_io_binding(Ort::Session& session, std::map<std::string, std::vector<int64_t>> shape_overwrites = {}) {
static Ort::IoBinding generate_io_binding(Ort::Session& session, std::map<std::string, std::vector<int64_t>> shape_overwrites = {}) {
Ort::IoBinding binding(session);
auto allocator = Ort::AllocatorWithDefaultOptions();
for (int input_idx = 0; input_idx < int(session.GetInputCount()); ++input_idx) {
@@ -178,6 +186,7 @@ Ort::IoBinding generate_io_binding(Ort::Session& session, std::map<std::string,
TEST(NvExecutionProviderTest, ContextEmbedAndReload) {
PathString model_name = ORT_TSTR("nv_execution_provider_test.onnx");
PathString model_name_ctx = ORT_TSTR("nv_execution_provider_test_ctx.onnx");
clearFileIfExists(model_name_ctx);
std::string graph_name = "test";
std::vector<int> dims = {1, 3, 2};

@@ -192,9 +201,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) {
auto start = std::chrono::high_resolution_clock::now();
Ort::SessionOptions so;
Ort::RunOptions run_options;
so.AddConfigEntry("ep.context_enable", "1");
so.AddConfigEntry("ep.context_file_path", WideToUTF8(model_name_ctx).c_str());
so.AppendExecutionProvider("NvTensorRtRtx", {});
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, WideToUTF8(model_name_ctx).c_str());
so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
Ort::Session session_object(env, model_name.c_str(), so);
auto stop = std::chrono::high_resolution_clock::now();
std::cout << "Session creation AOT: " << std::chrono::duration_cast<std::chrono::milliseconds>((stop - start)).count() << " ms" << std::endl;
@@ -208,9 +217,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) {
auto start = std::chrono::high_resolution_clock::now();
Ort::SessionOptions so;
Ort::RunOptions run_options;
so.AddConfigEntry("ep.context_enable", "1");
so.AppendExecutionProvider("NvTensorRtRtx", {});
Ort::Session session_object(env, model_name.c_str(), so);
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
Ort::Session session_object(env, model_name_ctx.c_str(), so);
auto stop = std::chrono::high_resolution_clock::now();
std::cout << "Session creation JIT: " << std::chrono::duration_cast<std::chrono::milliseconds>((stop - start)).count() << " ms" << std::endl;

@@ -222,6 +231,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) {
TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) {
PathString model_name = ORT_TSTR("nv_execution_provider_dyn_test.onnx");
PathString model_name_ctx = ORT_TSTR("nv_execution_provider_dyn_test_ctx.onnx");
clearFileIfExists(model_name_ctx);
std::string graph_name = "test";
std::vector<int> dims = {1, -1, -1};

@@ -236,9 +246,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) {
auto start = std::chrono::high_resolution_clock::now();
Ort::SessionOptions so;
Ort::RunOptions run_options;
so.AddConfigEntry("ep.context_enable", "1");
so.AddConfigEntry("ep.context_file_path", WideToUTF8(model_name_ctx).c_str());
so.AppendExecutionProvider("NvTensorRtRtx", {});
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, WideToUTF8(model_name_ctx).c_str());
so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
Ort::Session session_object(env, model_name.c_str(), so);
auto stop = std::chrono::high_resolution_clock::now();
std::cout << "Session creation AOT: " << std::chrono::duration_cast<std::chrono::milliseconds>((stop - start)).count() << " ms" << std::endl;
@@ -252,9 +262,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) {
auto start = std::chrono::high_resolution_clock::now();
Ort::SessionOptions so;
Ort::RunOptions run_options;
so.AddConfigEntry("ep.context_enable", "1");
so.AppendExecutionProvider("NvTensorRtRtx", {});
Ort::Session session_object(env, model_name.c_str(), so);
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
Ort::Session session_object(env, model_name_ctx.c_str(), so);
auto stop = std::chrono::high_resolution_clock::now();
std::cout << "Session creation JIT: " << std::chrono::duration_cast<std::chrono::milliseconds>((stop - start)).count() << " ms" << std::endl;

@@ -269,6 +279,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) {
TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) {
PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx");
PathString model_name_ctx = ORT_TSTR("nv_execution_provider_data_dyn_test_ctx.onnx");
clearFileIfExists(model_name_ctx);
std::string graph_name = "test";
std::vector<int> dims = {1, -1, -1};

@@ -283,9 +294,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) {
auto start = std::chrono::high_resolution_clock::now();
Ort::SessionOptions so;
Ort::RunOptions run_options;
so.AddConfigEntry("ep.context_enable", "1");
so.AddConfigEntry("ep.context_file_path", WideToUTF8(model_name_ctx).c_str());
so.AppendExecutionProvider("NvTensorRtRtx", {});
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, WideToUTF8(model_name_ctx).c_str());
so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
Ort::Session session_object(env, model_name.c_str(), so);
auto stop = std::chrono::high_resolution_clock::now();
std::cout << "Session creation AOT: " << std::chrono::duration_cast<std::chrono::milliseconds>((stop - start)).count() << " ms" << std::endl;
@@ -299,9 +310,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) {
auto start = std::chrono::high_resolution_clock::now();
Ort::SessionOptions so;
Ort::RunOptions run_options;
so.AddConfigEntry("ep.context_enable", "1");
so.AppendExecutionProvider("NvTensorRtRtx", {});
Ort::Session session_object(env, model_name.c_str(), so);
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
Ort::Session session_object(env, model_name_ctx.c_str(), so);
auto stop = std::chrono::high_resolution_clock::now();
std::cout << "Session creation JIT: " << std::chrono::duration_cast<std::chrono::milliseconds>((stop - start)).count() << " ms" << std::endl;
