diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index af8f25af4955..3e36559c0a7c 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -640,6 +640,9 @@ then set ``gamma`` to 1 and its gradient to 0.
 NNVM_REGISTER_OP(_backward_BatchNorm)
 .set_num_inputs(8)
 .set_num_outputs(3)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const nnvm::NodeAttrs& attrs) {
+  return std::vector<uint32_t>{6, 7};  // moving_mean, moving_var
+})
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
 #if MXNET_USE_MKLDNN == 1
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index cf6bc362eb47..60fd526e16c7 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -758,6 +758,32 @@ def transpose(shape):
     assert (layer(x).shape==ceil_out_shape)
 
 
+@with_seed()
+def test_batchnorm_backward_synchronization():
+    """
+    Tests whether synchronization of the BatchNorm running variables is done correctly.
+    If it is not, the test sometimes fails, depending on the timing.
+    """
+    ctx = mx.test_utils.default_context()
+
+    for variable in ['running_var', 'running_mean']:
+        for _ in range(20):
+            layer = nn.BatchNorm()
+            layer.initialize(ctx=ctx)
+            for _ in range(3):
+                data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 10, 10), ctx=ctx)
+                with mx.autograd.record():
+                    out = layer(data)
+                out.backward()
+
+            # check that consecutive reads give the same value
+            var1 = getattr(layer, variable).data().asnumpy()
+            for _ in range(10):
+                var2 = getattr(layer, variable).data().asnumpy()
+                if (var1 != var2).any():
+                    raise AssertionError("Two consecutive reads of " + variable + " give different results")
+
+
 @with_seed()
 def test_batchnorm():
     layer = nn.BatchNorm(in_channels=10)
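
Note (not part of the patch): once `_backward_BatchNorm` declares inputs 6 and 7 (`moving_mean`, `moving_var`) through `FMutateInputs`, the dependency engine treats the backward pass as a writer of those arrays, so any subsequent read is ordered after the in-place update. The following minimal sketch mirrors the new unit test and assumes an MXNet build that includes this change; the layer and parameter names are the ones used in the test above.

# Illustrative only, not part of the diff. Assumes MXNet with the Gluon API.
import mxnet as mx
from mxnet.gluon import nn

layer = nn.BatchNorm()
layer.initialize()

data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 10, 10))
with mx.autograd.record():
    out = layer(data)
out.backward()  # updates moving_mean/moving_var in place

# With FMutateInputs declared on _backward_BatchNorm, these reads are ordered
# after the backward pass's write, so repeated reads must agree.
first = layer.running_mean.data().asnumpy()
second = layer.running_mean.data().asnumpy()
assert (first == second).all()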