Use enable_training_ops instead of enable_training_core
baijumeswani committed Aug 10, 2023
commit 84b97e49cb00787a70d9430569b83147157e2812
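The rename only touches preprocessor guards: code wrapped in the macro is compiled in or out depending on which symbol the build defines. A minimal sketch of the pattern (illustrative only, not code from this PR; the macro would typically arrive via a compiler definition such as -DENABLE_TRAINING_OPS):

```cpp
#include <cstdio>

int main() {
  // Compiled only when the build defines ENABLE_TRAINING_OPS
  // (e.g. g++ -DENABLE_TRAINING_OPS gate.cc). Guarding on the wrong
  // macro name would silently drop the training path from the build.
#ifdef ENABLE_TRAINING_OPS
  std::puts("training ops compiled in");
#else
  std::puts("inference-only build");
#endif
  return 0;
}
```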
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/cpu/nn/batch_norm.h
@@ -47,7 +47,7 @@ class BatchNorm : public OpKernel {
     }
 
     if (is_train_) {
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING_OPS
       momentum_ = op_kernel_info.GetAttrOrDefault<float>("momentum", 0.9f);
       ORT_ENFORCE(is_spatial_, "Training mode only supports spatial BN");
 #else
@@ -84,7 +84,7 @@ class BatchNorm : public OpKernel {
     // calculate sample_size (including all channels)
     size_t sample_size_incl_all_channels = sample_size * C;
 
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING_OPS
     AllocatorPtr alloc;
     ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc));
 
@@ -111,7 +111,7 @@ class BatchNorm : public OpKernel {
     ConstEigenVectorArrayMap<T> scale_arr(scale->Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
     ConstEigenVectorArrayMap<T> bias_arr(B->Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
 
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING_OPS
     // Note that we only support spatial BN for training
     if (is_train_) {
       EigenVectorArrayMap<T> saved_mean_arr(saved_mean->MutableData<T>(), C);
@@ -162,7 +162,7 @@ class BatchNorm : public OpKernel {
       ConstEigenVectorArrayMap<T> var_arr(var->Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
       inv_std = (var_arr + epsilon_).sqrt().inverse();
     } else {
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING_OPS
       EigenVectorArrayMap<T> saved_inv_std_arr(saved_inv_std->MutableData<T>(), C);
       saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt();
       inv_std = saved_inv_std_arr;
@@ -171,7 +171,7 @@ class BatchNorm : public OpKernel {
 
     // If we're training, do batch normalization based on computation from this batch
     ConstEigenVectorArrayMap<T> mean_arr(
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING_OPS
         !is_train_ ? mean->Data<T>() : saved_mean->Data<T>(),
 #else
         mean->Data<T>(),
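For reference, the training branch being re-gated here computes batch statistics and normalizes with the inverse standard deviation, 1/sqrt(var + epsilon), instead of reading the stored running statistics. A rough standalone sketch of that arithmetic for a single channel (plain C++ in place of the kernel's Eigen maps; the function name and signature are illustrative, not onnxruntime API):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch of spatial batch norm in training mode for one channel: compute
// the batch mean and (biased) variance, normalize with 1/sqrt(var + eps),
// then apply scale and bias. The kernel above does the equivalent per
// channel with Eigen array maps, and additionally blends the batch
// statistics into the running mean/variance using the momentum attribute.
void batch_norm_train_1ch(std::vector<float>& x, float scale, float bias,
                          float epsilon, float& saved_mean,
                          float& saved_inv_std) {
  const std::size_t n = x.size();
  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);

  float var = 0.0f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(n);  // population variance, matching BN

  saved_mean = mean;
  saved_inv_std = 1.0f / std::sqrt(var + epsilon);
  for (float& v : x) v = scale * (v - mean) * saved_inv_std + bias;
}
```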
4 changes: 2 additions & 2 deletions onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
@@ -846,7 +846,7 @@ TEST(BatchNormTest, BatchNorm2d_bfloat16) {
 #endif  // USE_DNNL
 
 // TODO fix flaky test for CUDA
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING_OPS
 TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
@@ -936,7 +936,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) {
                            {kCudaExecutionProvider, kRocmExecutionProvider,
                             kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
 }
-#endif  // ENABLE_TRAINING_CORE
+#endif  // ENABLE_TRAINING_OPS
 
 }  // namespace test
 }  // namespace onnxruntime
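As a design note, gating the tests on the same macro as the kernel keeps the two in lockstep: whenever the training path is compiled, its tests are too. The training-mode tests exercise, among other things, the running-statistics update, which per the ONNX BatchNormalization spec (as I understand it) blends the old running value with the batch statistic via the momentum attribute. A hedged sketch of that check with hypothetical values, not the actual test body:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical check of the training-mode running-statistics update:
//   running_new = momentum * running_old + (1 - momentum) * batch_stat
int main() {
  const float momentum = 0.9f;  // ONNX default used by the kernel above
  const float running_mean_old = 0.0f;
  const float batch_mean = 1.0f;
  const float running_mean_new =
      momentum * running_mean_old + (1.0f - momentum) * batch_mean;
  assert(std::fabs(running_mean_new - 0.1f) < 1e-6f);
  return 0;
}
```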