Skip to content

Commit 54153c7

Browse files
authored Aug 11, 2023
Batchnorm training mode support in a minimal build (#17103)
1 parent 344c41f commit 54153c7

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed
 

Diff for: onnxruntime/core/providers/cpu/nn/batch_norm.h

+5-9
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@
2929

3030
namespace onnxruntime {
3131

32-
#if !defined(ORT_MINIMAL_BUILD)
33-
#define BATCHNORM_INCLUDE_TRAINING_SUPPORT
34-
#endif
35-
3632
template <typename T>
3733
class BatchNorm : public OpKernel {
3834
public:
@@ -51,7 +47,7 @@ class BatchNorm : public OpKernel {
5147
}
5248

5349
if (is_train_) {
54-
#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
50+
#ifdef ENABLE_TRAINING_OPS
5551
momentum_ = op_kernel_info.GetAttrOrDefault<float>("momentum", 0.9f);
5652
ORT_ENFORCE(is_spatial_, "Training mode only supports spatial BN");
5753
#else
@@ -88,7 +84,7 @@ class BatchNorm : public OpKernel {
8884
// calculate sample_size (including all channels)
8985
size_t sample_size_incl_all_channels = sample_size * C;
9086

91-
#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
87+
#ifdef ENABLE_TRAINING_OPS
9288
AllocatorPtr alloc;
9389
ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc));
9490

@@ -115,7 +111,7 @@ class BatchNorm : public OpKernel {
115111
ConstEigenVectorArrayMap<T> scale_arr(scale->Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
116112
ConstEigenVectorArrayMap<T> bias_arr(B->Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
117113

118-
#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
114+
#ifdef ENABLE_TRAINING_OPS
119115
// Note that we only support spatial BN for training
120116
if (is_train_) {
121117
EigenVectorArrayMap<T> saved_mean_arr(saved_mean->MutableData<T>(), C);
@@ -166,7 +162,7 @@ class BatchNorm : public OpKernel {
166162
ConstEigenVectorArrayMap<T> var_arr(var->Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
167163
inv_std = (var_arr + epsilon_).sqrt().inverse();
168164
} else {
169-
#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
165+
#ifdef ENABLE_TRAINING_OPS
170166
EigenVectorArrayMap<T> saved_inv_std_arr(saved_inv_std->MutableData<T>(), C);
171167
saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt();
172168
inv_std = saved_inv_std_arr;
@@ -175,7 +171,7 @@ class BatchNorm : public OpKernel {
175171

176172
// If we're training, do batch normalization based on computation from this batch
177173
ConstEigenVectorArrayMap<T> mean_arr(
178-
#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
174+
#ifdef ENABLE_TRAINING_OPS
179175
!is_train_ ? mean->Data<T>() : saved_mean->Data<T>(),
180176
#else
181177
mean->Data<T>(),

Diff for: onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Licensed under the MIT License.
33

44
#include "core/framework/tensor.h"
5-
#include "core/providers/cpu/nn/batch_norm.h" // for BATCHNORM_INCLUDE_TRAINING_SUPPORT
65
#include "core/session/inference_session.h"
76
#include "test/common/dnnl_op_test_utils.h"
87
#include "test/providers/provider_test_utils.h"
@@ -847,7 +846,7 @@ TEST(BatchNormTest, BatchNorm2d_bfloat16) {
847846
#endif // USE_DNNL
848847

849848
// TODO fix flaky test for CUDA
850-
#ifdef BATCHNORM_INCLUDE_TRAINING_SUPPORT
849+
#ifdef ENABLE_TRAINING_OPS
851850
TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
852851
// TODO: Unskip when fixed #41968513
853852
if (DefaultDmlExecutionProvider().get() != nullptr) {
@@ -937,7 +936,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) {
937936
{kCudaExecutionProvider, kRocmExecutionProvider,
938937
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
939938
}
940-
#endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT
939+
#endif // ENABLE_TRAINING_OPS
941940

942941
} // namespace test
943942
} // namespace onnxruntime

0 commit comments

Comments (0)
Please sign in to comment.