From 0f8c0dc982a89326ab0adc6eb20c40d38c0d2853 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Wed, 26 Sep 2018 00:31:13 -0500 Subject: [PATCH] Only update BN moving stats during training --- nnvm/src/top/nn/nn.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index 52ecd67719ec..59479a234729 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -358,17 +358,17 @@ axis to be the last item in the input shape. inputs[3], mean, param.momentum); new_var = MakeMomentumNode(n->attrs.name + "_var_mom", inputs[4], var_unbiased, param.momentum); + + new_mean = MakeNode("_assign", n->attrs.name + "_mean_update", + {inputs[3], new_mean}); + new_var = MakeNode("_assign", n->attrs.name + "_var_update", + {inputs[4], new_var}); } else { mean = inputs[3]; var = inputs[4]; new_mean = inputs[3]; new_var = inputs[4]; } - - new_mean = MakeNode("_assign", n->attrs.name + "_mean_update", - {inputs[3], new_mean}); - new_var = MakeNode("_assign", n->attrs.name + "_var_update", - {inputs[4], new_var}); mean = compiler::ExpandBiasToMatchAxis(mean, in_dim, 1, ax); var = compiler::ExpandBiasToMatchAxis(var, in_dim, 1, ax);