29
29
30
30
namespace onnxruntime {
31
31
32
- #if !defined(ORT_MINIMAL_BUILD)
33
- #define BATCHNORM_INCLUDE_TRAINING_SUPPORT
34
- #endif
35
-
36
32
template <typename T>
37
33
class BatchNorm : public OpKernel {
38
34
public:
@@ -51,7 +47,7 @@ class BatchNorm : public OpKernel {
51
47
}
52
48
53
49
if (is_train_) {
54
- #if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
50
+ #ifdef ENABLE_TRAINING_OPS
55
51
momentum_ = op_kernel_info.GetAttrOrDefault <float >(" momentum" , 0 .9f );
56
52
ORT_ENFORCE (is_spatial_, " Training mode only supports spatial BN" );
57
53
#else
@@ -88,7 +84,7 @@ class BatchNorm : public OpKernel {
88
84
// calculate sample_size (including all channels)
89
85
size_t sample_size_incl_all_channels = sample_size * C;
90
86
91
- #if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
87
+ #ifdef ENABLE_TRAINING_OPS
92
88
AllocatorPtr alloc;
93
89
ORT_RETURN_IF_ERROR (p_op_kernel_context->GetTempSpaceAllocator (&alloc));
94
90
@@ -115,7 +111,7 @@ class BatchNorm : public OpKernel {
115
111
ConstEigenVectorArrayMap<T> scale_arr (scale->Data <T>(), is_spatial_ ? C : sample_size_incl_all_channels);
116
112
ConstEigenVectorArrayMap<T> bias_arr (B->Data <T>(), is_spatial_ ? C : sample_size_incl_all_channels);
117
113
118
- #if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
114
+ #ifdef ENABLE_TRAINING_OPS
119
115
// Note that we only support spatial BN for training
120
116
if (is_train_) {
121
117
EigenVectorArrayMap<T> saved_mean_arr (saved_mean->MutableData <T>(), C);
@@ -166,7 +162,7 @@ class BatchNorm : public OpKernel {
166
162
ConstEigenVectorArrayMap<T> var_arr (var->Data <T>(), is_spatial_ ? C : sample_size_incl_all_channels);
167
163
inv_std = (var_arr + epsilon_).sqrt ().inverse ();
168
164
} else {
169
- #if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
165
+ #ifdef ENABLE_TRAINING_OPS
170
166
EigenVectorArrayMap<T> saved_inv_std_arr (saved_inv_std->MutableData <T>(), C);
171
167
saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse ().sqrt ();
172
168
inv_std = saved_inv_std_arr;
@@ -175,7 +171,7 @@ class BatchNorm : public OpKernel {
175
171
176
172
// If we're training, do batch normalization based on computation from this batch
177
173
ConstEigenVectorArrayMap<T> mean_arr (
178
- #if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT)
174
+ #ifdef ENABLE_TRAINING_OPS
179
175
!is_train_ ? mean->Data <T>() : saved_mean->Data <T>(),
180
176
#else
181
177
mean->Data <T>(),
0 commit comments