Skip to content

Add Python bindings for the global thread pool functionality #24238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/__init__.py
Original file line number Diff line number Diff line change
@@ -55,6 +55,7 @@
register_execution_provider_library, # noqa: F401
set_default_logger_severity, # noqa: F401
set_default_logger_verbosity, # noqa: F401
set_global_thread_pool_sizes, # noqa: F401
set_seed, # noqa: F401
unregister_execution_provider_library, # noqa: F401
)
1 change: 1 addition & 0 deletions onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
@@ -486,6 +486,7 @@ def __init__(
raise e

def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
C.ensure_env_initialized()
available_providers = C.get_available_providers()

# Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
80 changes: 72 additions & 8 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/provider_bridge_ort.h"
#include "core/session/onnxruntime_cxx_api.h"

#include "core/session/lora_adapters.h"

@@ -75,6 +76,7 @@ const OrtDevice::DeviceType OrtDevice::GPU;

#include <iterator>
#include <algorithm>
#include <utility>

namespace onnxruntime {
namespace python {
@@ -1614,6 +1616,15 @@ static void LogDeprecationWarning(
#endif

void addGlobalMethods(py::module& m) {
// Python: set_global_thread_pool_sizes(intra_op_num_threads=0, inter_op_num_threads=0).
// Copies the two requested sizes into a fresh OrtThreadingOptions and forwards it to
// SetGlobalThreadingOptions, which throws once the environment has been initialized.
m.def(
    "set_global_thread_pool_sizes",
    [](int intra_op_num_threads, int inter_op_num_threads) {
      OrtThreadingOptions threading_options;
      threading_options.intra_op_thread_pool_params.thread_pool_size = intra_op_num_threads;
      threading_options.inter_op_thread_pool_params.thread_pool_size = inter_op_num_threads;
      SetGlobalThreadingOptions(threading_options);
    },
    py::arg("intra_op_num_threads") = 0,
    py::arg("inter_op_num_threads") = 0,
    "Set the number of threads used by the global thread pools for intra and inter op parallelism.");
m.def("ensure_env_initialized", []() { GetEnv(); }, "Ensure the onnxruntime environment is initialized.");
m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance.");
m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer.");
m.def(
@@ -2070,7 +2081,7 @@ for model inference.)pbdoc");
ORT_THROW("OrtEpDevices are not supported in this build");
#endif
},
R"pbdoc(Adds the execution provider that is responsible for the selected OrtEpDevice instances. All OrtEpDevice instances
R"pbdoc(Adds the execution provider that is responsible for the selected OrtEpDevice instances. All OrtEpDevice instances
must refer to the same execution provider.)pbdoc")
.def(
// Equivalent to the C API's SessionOptionsSetEpSelectionPolicy.
@@ -2209,6 +2220,13 @@ Serialized model format will default to ONNX unless:
},
R"pbdoc(VLOG level if DEBUG build and session_log_severity_level is 0.
Applies to session load, initialization, etc. Default is 0.)pbdoc")
.def_property(
"use_per_session_threads",
[](const PySessionOptions* options) -> bool { return options->value.use_per_session_threads; },
[](PySessionOptions* options, bool use_per_session_threads) -> void {
options->value.use_per_session_threads = use_per_session_threads;
},
R"pbdoc(Whether to use per-session thread pool. Default is True.)pbdoc")
.def_property(
"intra_op_num_threads",
[](const PySessionOptions* options) -> int { return options->value.intra_op_param.thread_pool_size; },
@@ -2486,6 +2504,14 @@ including arg name, arg type (contains both type and shape).)pbdoc")
auto env = GetEnv();
std::unique_ptr<PyInferenceSession> sess;

// Session-level thread settings only conflict with the process-wide pools when a global
// thread pool has actually been configured, so both checks are gated on that. Without the
// gate, every session that sets intra/inter op thread counts would log a misleading
// "global thread pool" warning even when per-session threads are in use.
if (CheckIfUsingGlobalThreadPool()) {
  // A session cannot own its thread pools while the global pools are active.
  if (so.value.use_per_session_threads) {
    ORT_THROW("use_per_session_threads must be false when using a global thread pool");
  }

  // Non-zero per-session sizes are not an error, but the caller should know they have no effect.
  if (so.value.intra_op_param.thread_pool_size != 0 || so.value.inter_op_param.thread_pool_size != 0) {
    LOGS_DEFAULT(WARNING) << "session options intra_op_param.thread_pool_size and inter_op_param.thread_pool_size are ignored when using a global thread pool";
  }
}

// separate creation of the session from model loading unless we have to read the config from the model.
// in a minimal build we only support load via Load(...) and not at session creation time
if (load_config_from_model) {
@@ -2857,8 +2883,6 @@ bool CreateInferencePybindStateModule(py::module& m) {

import_array1(false);

auto env = GetEnv();

addGlobalMethods(m);
addObjectMethods(m, RegisterExecutionProviders);
addOrtValueMethods(m);
@@ -2914,6 +2938,15 @@ namespace {
// For all the related details and why it is needed see "Modern C++ design" by A. Alexandrescu Chapter 6.
class EnvInitializer {
public:
// Stores the threading options to use for the process-wide (global) thread pools and
// switches the environment away from per-session threads.
// Throws via ORT_THROW when the environment singleton has already been constructed,
// because the options are only consumed by the constructor.
static void SetGlobalThreadingOptions(const OrtThreadingOptions& new_tp_options) {
  const bool env_already_created = EnvInitializer::initialized;
  if (env_already_created) {
    ORT_THROW("Cannot set global threading options after the environment has been initialized.");
  }

  EnvInitializer::tp_options = new_tp_options;
  EnvInitializer::use_per_session_threads = false;
}

static std::shared_ptr<onnxruntime::Environment> SharedInstance() {
// Guard against attempts to resurrect the singleton
if (EnvInitializer::destroyed) {
@@ -2923,16 +2956,33 @@ class EnvInitializer {
return env_holder.Get();
}

// Returns the current value of the use_per_session_threads flag. The flag starts out true
// and is set to false by SetGlobalThreadingOptions.
static bool GetUsePerSessionThreads() {
return use_per_session_threads;
}

private:
EnvInitializer() {
std::unique_ptr<Environment> env_ptr;
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
std::make_unique<CLogSink>(),
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
env_ptr));

// create logging manager here
std::unique_ptr<LoggingManager> lm = std::make_unique<LoggingManager>(
std::make_unique<CLogSink>(),
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id);

if (EnvInitializer::use_per_session_threads) {
OrtPybindThrowIfError(Environment::Create(std::move(lm),
env_ptr));
} else {
OrtPybindThrowIfError(Environment::Create(std::move(lm),
env_ptr,
&EnvInitializer::tp_options,
true));
}

session_env_ = std::shared_ptr<Environment>(env_ptr.release());
initialized = true;
destroyed = false;
}

@@ -2946,12 +2996,26 @@ class EnvInitializer {

std::shared_ptr<Environment> session_env_;

// Threading options passed to Environment::Create when global thread pools are in use.
static OrtThreadingOptions tp_options;
// True (the default) when sessions own their thread pools; set to false by SetGlobalThreadingOptions.
static bool use_per_session_threads;
// Set to true by the constructor; SetGlobalThreadingOptions throws once this is true.
static bool initialized;
// Checked by SharedInstance to guard against resurrecting the singleton; reset to false by the constructor.
static bool destroyed;
};

// Out-of-class definitions of EnvInitializer's static state.
// tp_options is default-constructed until SetGlobalThreadingOptions overwrites it.
OrtThreadingOptions EnvInitializer::tp_options;
// Per-session thread pools are the default mode.
bool EnvInitializer::use_per_session_threads = true;
// Neither constructed nor destroyed at program start.
bool EnvInitializer::initialized = false;
bool EnvInitializer::destroyed = false;
} // namespace

// Public entry point that forwards to EnvInitializer::SetGlobalThreadingOptions.
// That call throws if the environment has already been initialized.
void SetGlobalThreadingOptions(const OrtThreadingOptions& tp_options) {
EnvInitializer::SetGlobalThreadingOptions(tp_options);
}

// Returns true when global thread pools have been requested, i.e. when
// EnvInitializer's use_per_session_threads flag is false.
bool CheckIfUsingGlobalThreadPool() {
return !EnvInitializer::GetUsePerSessionThreads();
}

// Returns the shared Environment obtained from EnvInitializer::SharedInstance().
// The Python binding "ensure_env_initialized" calls this to make sure the environment exists.
std::shared_ptr<onnxruntime::Environment> GetEnv() {
return EnvInitializer::SharedInstance();
}
3 changes: 3 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state_common.h
Original file line number Diff line number Diff line change
@@ -421,6 +421,9 @@ class SessionObjectInitializer {
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif

// Stores the threading options used for the global thread pools; must be called before
// the environment is initialized, otherwise the implementation throws.
void SetGlobalThreadingOptions(const OrtThreadingOptions& tp_options);
// Returns true when global thread pools have been requested instead of per-session threads.
bool CheckIfUsingGlobalThreadPool();
// Returns the shared onnxruntime Environment.
std::shared_ptr<Environment> GetEnv();

// Initialize an InferenceSession.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# pylint: disable=C0115,W0212,C0103,C0114
import unittest

import numpy as np
from helper import get_name

import onnxruntime as onnxrt


class TestGlobalThreadPool(unittest.TestCase):
    """Smoke test for running an inference session on the global thread pools."""

    def test_global_threadpool(self):
        """Set the global pool sizes, then load and run a model with per-session threads off."""
        # The global pool sizes are configured before any session is created.
        onnxrt.set_global_thread_pool_sizes(2, 2)

        options = onnxrt.SessionOptions()
        options.execution_mode = onnxrt.ExecutionMode.ORT_PARALLEL
        options.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL
        # Session creation raises if this stays True while the global pools are in use.
        options.use_per_session_threads = False

        model_path = get_name("mnist.onnx")
        session = onnxrt.InferenceSession(model_path, options, providers=onnxrt.get_available_providers())

        # The test passes as long as the run completes without raising.
        feeds = {"Input3": np.ones([1, 1, 28, 28], np.float32)}
        session.run(None, feeds)


if __name__ == "__main__":
unittest.main()
2 changes: 0 additions & 2 deletions orttraining/orttraining/python/orttraining_python_module.cc
Original file line number Diff line number Diff line change
@@ -310,8 +310,6 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
m.doc() = "pybind11 stateful interface to ORTTraining";
RegisterExceptions(m);

// Instantiate singletons
GetTrainingEnv();
addGlobalMethods(m);
addObjectMethods(m, ORTTrainingRegisterExecutionProviders);
addOrtValueMethods(m);
3 changes: 3 additions & 0 deletions tools/ci_build/build.py
Original file line number Diff line number Diff line change
@@ -1720,6 +1720,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
[sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path
)

log.info("Testing Global Thread Pool feature")
run_subprocess([sys.executable, "onnxruntime_test_python_global_threadpool.py"], cwd=cwd, dll_path=dll_path)

log.info("Testing AutoEP feature")
run_subprocess([sys.executable, "onnxruntime_test_python_autoep.py"], cwd=cwd, dll_path=dll_path)

Loading