@@ -47,7 +47,7 @@ class BatchNorm : public OpKernel {
47
47
}
48
48
49
49
if (is_train_) {
50
- #ifdef ENABLE_TRAINING_CORE
50
+ #ifdef ENABLE_TRAINING_OPS
51
51
momentum_ = op_kernel_info.GetAttrOrDefault <float >(" momentum" , 0 .9f );
52
52
ORT_ENFORCE (is_spatial_, " Training mode only supports spatial BN" );
53
53
#else
@@ -84,7 +84,7 @@ class BatchNorm : public OpKernel {
84
84
// calculate sample_size (including all channels)
85
85
size_t sample_size_incl_all_channels = sample_size * C;
86
86
87
- #ifdef ENABLE_TRAINING_CORE
87
+ #ifdef ENABLE_TRAINING_OPS
88
88
AllocatorPtr alloc;
89
89
ORT_RETURN_IF_ERROR (p_op_kernel_context->GetTempSpaceAllocator (&alloc));
90
90
@@ -111,7 +111,7 @@ class BatchNorm : public OpKernel {
111
111
ConstEigenVectorArrayMap<T> scale_arr (scale->Data <T>(), is_spatial_ ? C : sample_size_incl_all_channels);
112
112
ConstEigenVectorArrayMap<T> bias_arr (B->Data <T>(), is_spatial_ ? C : sample_size_incl_all_channels);
113
113
114
- #ifdef ENABLE_TRAINING_CORE
114
+ #ifdef ENABLE_TRAINING_OPS
115
115
// Note that we only support spatial BN for training
116
116
if (is_train_) {
117
117
EigenVectorArrayMap<T> saved_mean_arr (saved_mean->MutableData <T>(), C);
@@ -162,7 +162,7 @@ class BatchNorm : public OpKernel {
162
162
ConstEigenVectorArrayMap<T> var_arr (var->Data <T>(), is_spatial_ ? C : sample_size_incl_all_channels);
163
163
inv_std = (var_arr + epsilon_).sqrt ().inverse ();
164
164
} else {
165
- #ifdef ENABLE_TRAINING_CORE
165
+ #ifdef ENABLE_TRAINING_OPS
166
166
EigenVectorArrayMap<T> saved_inv_std_arr (saved_inv_std->MutableData <T>(), C);
167
167
saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse ().sqrt ();
168
168
inv_std = saved_inv_std_arr;
@@ -171,7 +171,7 @@ class BatchNorm : public OpKernel {
171
171
172
172
// If we're training, do batch normalization based on computation from this batch
173
173
ConstEigenVectorArrayMap<T> mean_arr (
174
- #ifdef ENABLE_TRAINING_CORE
174
+ #ifdef ENABLE_TRAINING_OPS
175
175
!is_train_ ? mean->Data <T>() : saved_mean->Data <T>(),
176
176
#else
177
177
mean->Data <T>(),
0 commit comments