From 01a667a4420a62c673aa8d8a6b09f08cf68222e8 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Tue, 13 Mar 2018 17:34:35 +0000 Subject: [PATCH 01/24] Reduce #inputs/outputs of batchnorm backward. --- src/operator/nn/batch_norm.cc | 45 +++++++++++++++++++++++++++++++---- src/operator/nn/batch_norm.cu | 12 ++++++---- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index c8b5d58156e5..f2e8ddca873f 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -470,8 +470,6 @@ static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs, DispatchMode *dispatch_mode, std::vector *in_attrs, std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 11); - CHECK_EQ(out_attrs->size(), 5); DispatchMode wanted_mode; #if MXNET_USE_MKLDNN == 1 if (dev_mask == mshadow::cpu::kDevMask) @@ -486,6 +484,45 @@ static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs, dispatch_mode, wanted_mode); } +std::vector BatchNormGrad(const nnvm::NodePtr& n, + const std::vector& ograds) { + std::vector out_data(n->num_outputs()); + for (uint32_t i = 0; i < out_data.size(); ++i) { + out_data[i] = nnvm::NodeEntry{n, i, 0}; + } + std::vector heads; + heads.push_back(ograds[0]); + heads.push_back(out_data[batchnorm::kMean]); + heads.push_back(out_data[batchnorm::kVar]); + heads.push_back(n->inputs[batchnorm::kData]); + heads.push_back(n->inputs[batchnorm::kGamma]); + + // add all the auxiliary data + //for (uint32_t i = 0; i < prop.aux_states.size(); ++i) { + // inputs.emplace_back(ptr->inputs[i + prop.arguments.size()]); + //} + nnvm::NodePtr gnode = nnvm::Node::Create(); + gnode->inputs = std::move(heads); + gnode->control_deps.emplace_back(n); + gnode->attrs = n->attrs; + gnode->attrs.op = nnvm::Op::Get("_backward_BatchNorm"); + gnode->attrs.name = n->attrs.name + "_backward"; + // The input of batchnorm + std::vector in_grad(5); + for (uint32_t i = 0; i < 3; ++i) { + in_grad[i] = nnvm::NodeEntry{gnode, i, 0}; + } + // attach no gradient node to forbid gradient on aux_state + nnvm::NodePtr ng = nnvm::Node::Create(); + ng->attrs.op = Op::Get("_NoGradient"); + ng->attrs.name = "NoGradient"; + // the aux state of batchnorm + for (uint32_t i = 0; i < 2; ++i) { + in_grad[i + 3] = nnvm::NodeEntry{ng, 0, 0}; + } + return in_grad; +} + NNVM_REGISTER_OP(BatchNorm) .describe(R"code(Batch normalization. @@ -559,7 +596,7 @@ then set ``gamma`` to 1 and its gradient to 0. #if MXNET_USE_MKLDNN == 1 .set_attr("FComputeEx", BatchNormComputeExCPU) #endif -.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_BatchNorm"}) +.set_attr("FGradient", BatchNormGrad) #if MXNET_USE_MKLDNN == 1 .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; @@ -583,7 +620,7 @@ then set ``gamma`` to 1 and its gradient to 0. 
}); NNVM_REGISTER_OP(_backward_BatchNorm) -.set_num_outputs(5) +.set_num_outputs(3) .set_attr("TIsBackward", true) .set_attr("FInferStorageType", backward_BatchNormStorageType) #if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index b8657fc4d367..facfae167b99 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -690,12 +690,16 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 11U); + CHECK_EQ(inputs.size(), 5U); BatchNormParam param = nnvm::get(attrs.parsed); std::vector out_grad(1, inputs[0]); - std::vector in_data(inputs.begin() + 3, inputs.begin() + 6); - std::vector aux_states(inputs.begin() + 6, inputs.begin() + 8); - std::vector out_data(inputs.begin() + 8, inputs.end()); + std::vector out_data(3); + out_data[batchnorm::kMean] = inputs[1]; + out_data[batchnorm::kVar] = inputs[2]; + std::vector in_data(3); + in_data[batchnorm::kData] = inputs[3]; + in_data[batchnorm::kGamma] = inputs[4]; + std::vector aux_states(2); std::vector in_grad(outputs.begin(), outputs.begin() + 3); int dtype = inputs[0].type_flag_; TShape shape = inputs[0].shape_; From c616839cdc232822d569a95fc530a45a6d218f5b Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 14 Mar 2018 00:28:36 +0000 Subject: [PATCH 02/24] Pass more arrays to BN. --- src/operator/nn/batch_norm.cc | 3 +++ src/operator/nn/batch_norm.cu | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index f2e8ddca873f..444354bd9bd6 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -496,6 +496,9 @@ std::vector BatchNormGrad(const nnvm::NodePtr& n, heads.push_back(out_data[batchnorm::kVar]); heads.push_back(n->inputs[batchnorm::kData]); heads.push_back(n->inputs[batchnorm::kGamma]); + heads.push_back(n->inputs[batchnorm::kBeta]); + heads.push_back(n->inputs[batchnorm::kInMovingMean]); + heads.push_back(n->inputs[batchnorm::kInMovingVar]); // add all the auxiliary data //for (uint32_t i = 0; i < prop.aux_states.size(); ++i) { diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index facfae167b99..917a39f750e6 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -690,7 +690,7 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 5U); + CHECK_EQ(inputs.size(), 8U); BatchNormParam param = nnvm::get(attrs.parsed); std::vector out_grad(1, inputs[0]); std::vector out_data(3); From aaea301ae67a2cda365ccbef3fceec0e518d698c Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 14 Mar 2018 18:04:03 +0000 Subject: [PATCH 03/24] Make std::vector thread local. 
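This commit reuses per-thread scratch vectors for the TBlob lists in the GPU backward path instead of rebuilding them on every call. A minimal, self-contained sketch of the pattern follows; it is illustrative only, and Blob, BackwardImpl and GradCompute are placeholder names rather than the real TBlob / BatchNormBackward / BatchNormGradCompute signatures.

    #include <vector>

    struct Blob {};  // placeholder for mxnet::TBlob

    // Placeholder for the real backward implementation.
    void BackwardImpl(const std::vector<Blob>& out_grad,
                      const std::vector<Blob>& in_data) {}

    void GradCompute(const std::vector<Blob>& inputs) {
      // One scratch vector per thread, constructed once and reused on every
      // call, so the hot path does no vector allocation after the first use.
      static thread_local std::vector<Blob> out_grad(1);
      static thread_local std::vector<Blob> in_data(2);
      out_grad[0] = inputs[0];
      in_data[0] = inputs[3];
      in_data[1] = inputs[4];
      BackwardImpl(out_grad, in_data);
    }

The trade-off is that the thread-local vectors keep their last contents alive between calls; patch 07 revisits this for the NDArray-holding variant in the MKLDNN path.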
--- src/operator/nn/batch_norm.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 917a39f750e6..42c1766bae83 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -692,15 +692,16 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 8U); BatchNormParam param = nnvm::get(attrs.parsed); - std::vector out_grad(1, inputs[0]); - std::vector out_data(3); + static thread_local std::vector out_grad(1); + static thread_local std::vector out_data(3); + static thread_local std::vector in_data(3); + static thread_local std::vector aux_states(2); + out_grad[0] = inputs[0]; out_data[batchnorm::kMean] = inputs[1]; out_data[batchnorm::kVar] = inputs[2]; - std::vector in_data(3); in_data[batchnorm::kData] = inputs[3]; in_data[batchnorm::kGamma] = inputs[4]; - std::vector aux_states(2); - std::vector in_grad(outputs.begin(), outputs.begin() + 3); + std::vector &in_grad = outputs; int dtype = inputs[0].type_flag_; TShape shape = inputs[0].shape_; From 515367976eb48801555a29d9c44727760ea139c7 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 14 Mar 2018 18:27:17 +0000 Subject: [PATCH 04/24] Set inputs of BN backward for other cases. --- src/operator/nn/batch_norm.cc | 20 ++++++++++++------- src/operator/nn/batch_norm.cu | 4 ++++ .../nn/mkldnn/mkldnn_batch_norm-inl.h | 2 +- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 444354bd9bd6..c3ae65a60a1f 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -424,13 +424,19 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, // MKLDNN batchnorm only works well on the special MKLDNN layout. 
if (SupportMKLDNNBN(inputs[0], param) && (inputs[in_data_start].IsMKLDNNData() || inputs[0].IsMKLDNNData())) { - std::vector out_grad(inputs.begin(), inputs.begin() + num_out_grads); - std::vector in_data(inputs.begin() + in_data_start, - inputs.begin() + aux_states_start); - std::vector aux_states(inputs.begin() + aux_states_start, - inputs.begin() + out_data_start); - std::vector out_data(inputs.begin() + out_data_start, inputs.end()); - std::vector in_grad(outputs.begin(), outputs.begin() + 3); + static thread_local std::vector out_grad(1); + static thread_local std::vector out_data(3); + static thread_local std::vector in_data(3); + static thread_local std::vector aux_states(2); + out_grad[0] = inputs[0]; + out_data[batchnorm::kMean] = inputs[1]; + out_data[batchnorm::kVar] = inputs[2]; + in_data[batchnorm::kData] = inputs[3]; + in_data[batchnorm::kGamma] = inputs[4]; + in_data[batchnorm::kBeta] = inputs[5]; + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; + std::vector &in_grad = outputs; if (inputs[0].dtype() == mshadow::kFloat32) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 42c1766bae83..09e0f8789a46 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -714,12 +714,16 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, req, in_grad, aux_states); }) } else { + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, { BatchNormBackward(ctx, param, out_grad, in_data, out_data, req, in_grad, aux_states); }) } #else + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, { BatchNormBackward(ctx, param, out_grad, in_data, out_data, req, in_grad, aux_states); diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index a685ebfb4abe..d927b1218ee8 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -302,7 +302,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, const std::vector &in_grad, const std::vector &aux_states) { TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); - CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U); + CHECK_EQ(out_grad.size(), 1U); CHECK_EQ(in_data.size(), 3U); CHECK_EQ(out_data.size(), 3U); CHECK_EQ(in_grad.size(), 3U); From 3b3c9593a02ec17d4e5809d6e26e80daac1c973d Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 14 Mar 2018 20:27:06 +0000 Subject: [PATCH 05/24] Fix for other cases. --- src/operator/nn/batch_norm-inl.h | 26 ++++++++++++++------------ src/operator/nn/batch_norm.cc | 2 +- src/operator/nn/batch_norm.cu | 2 +- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 48638de20ccb..f36b3933b00e 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -261,19 +261,21 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 11U); + CHECK_EQ(inputs.size(), 8U); const BatchNormParam& param = nnvm::get(attrs.parsed); - int num_out_grads = param.output_mean_var ? 
3U : 1U; - int in_data_start = 3; - int aux_states_start = in_data_start + batchnorm::kInMovingMean; - int out_data_start = in_data_start + batchnorm::kInMovingVar + 1; - std::vector out_grad(inputs.begin(), inputs.begin() + num_out_grads); - std::vector in_data(inputs.begin() + in_data_start, - inputs.begin() + aux_states_start); - std::vector aux_states(inputs.begin() + aux_states_start, - inputs.begin() + out_data_start); - std::vector out_data(inputs.begin() + out_data_start, inputs.end()); - std::vector in_grad(outputs.begin(), outputs.begin() + 3); + static thread_local std::vector out_grad(1); + static thread_local std::vector out_data(3); + static thread_local std::vector in_data(3); + static thread_local std::vector aux_states(2); + out_grad[0] = inputs[0]; + out_data[batchnorm::kMean] = inputs[1]; + out_data[batchnorm::kVar] = inputs[2]; + in_data[batchnorm::kData] = inputs[3]; + in_data[batchnorm::kGamma] = inputs[4]; + in_data[batchnorm::kBeta] = inputs[5]; + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; + const std::vector &in_grad = outputs; MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, { BatchNormBackward(ctx, param, out_grad, in_data, out_data, req, diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index c3ae65a60a1f..d0b663ce55f7 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -436,7 +436,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, in_data[batchnorm::kBeta] = inputs[5]; aux_states[batchnorm::kMovingMean] = inputs[6]; aux_states[batchnorm::kMovingVar] = inputs[7]; - std::vector &in_grad = outputs; + const std::vector &in_grad = outputs; if (inputs[0].dtype() == mshadow::kFloat32) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 09e0f8789a46..6c1b63d70f28 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -701,7 +701,7 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, out_data[batchnorm::kVar] = inputs[2]; in_data[batchnorm::kData] = inputs[3]; in_data[batchnorm::kGamma] = inputs[4]; - std::vector &in_grad = outputs; + const std::vector &in_grad = outputs; int dtype = inputs[0].type_flag_; TShape shape = inputs[0].shape_; From 583751ea2bc0b96a108801faf7880172f92fc022 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 14 Mar 2018 20:44:57 +0000 Subject: [PATCH 06/24] remove commented code. --- src/operator/nn/batch_norm.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index d0b663ce55f7..514045e43604 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -506,10 +506,6 @@ std::vector BatchNormGrad(const nnvm::NodePtr& n, heads.push_back(n->inputs[batchnorm::kInMovingMean]); heads.push_back(n->inputs[batchnorm::kInMovingVar]); - // add all the auxiliary data - //for (uint32_t i = 0; i < prop.aux_states.size(); ++i) { - // inputs.emplace_back(ptr->inputs[i + prop.arguments.size()]); - //} nnvm::NodePtr gnode = nnvm::Node::Create(); gnode->inputs = std::move(heads); gnode->control_deps.emplace_back(n); From 1e71a10fde537870572c45c600b0b6eab748421e Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 14 Mar 2018 20:50:14 +0000 Subject: [PATCH 07/24] fix a potential mem leak. 
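Unlike the TBlob scratch vectors in the GPU path, the vectors in BatchNormGradComputeExCPU hold NDArray handles, which share ownership of their underlying buffers; leaving them in static thread_local storage keeps the previous call's arrays alive until the next call overwrites the slots or the thread exits. Switching back to function-local vectors drops those references on return. Below is a small stand-alone illustration of the retention problem, using std::shared_ptr<int> as a stand-in for an NDArray handle (an assumption made only for the example):

    #include <memory>
    #include <vector>

    using Handle = std::shared_ptr<int>;  // stand-in for a ref-counted NDArray

    void Leaky(const Handle& in) {
      static thread_local std::vector<Handle> scratch(1);
      scratch[0] = in;   // the reference survives after this call returns
    }

    void Clean(const Handle& in) {
      std::vector<Handle> scratch(1);
      scratch[0] = in;   // released when scratch goes out of scope
    }

    int main() {
      auto buf = std::make_shared<int>(42);
      Leaky(buf);
      Clean(buf);
      // buf.use_count() is still 2 here: main's copy plus the thread_local slot.
      return 0;
    }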
--- src/operator/nn/batch_norm.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 514045e43604..84def36c4ee7 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -424,10 +424,10 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, // MKLDNN batchnorm only works well on the special MKLDNN layout. if (SupportMKLDNNBN(inputs[0], param) && (inputs[in_data_start].IsMKLDNNData() || inputs[0].IsMKLDNNData())) { - static thread_local std::vector out_grad(1); - static thread_local std::vector out_data(3); - static thread_local std::vector in_data(3); - static thread_local std::vector aux_states(2); + std::vector out_grad(1); + std::vector out_data(3); + std::vector in_data(3); + std::vector aux_states(2); out_grad[0] = inputs[0]; out_data[batchnorm::kMean] = inputs[1]; out_data[batchnorm::kVar] = inputs[2]; From 4f27d57d980e9a024f16b3931116e93329d92f71 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 14 Mar 2018 21:24:01 +0000 Subject: [PATCH 08/24] Fix a compile error in mkldnn. --- src/operator/nn/batch_norm.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 84def36c4ee7..572f69625d80 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -415,15 +415,11 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &outputs) { CHECK_EQ(inputs.size(), 11U); const BatchNormParam ¶m = nnvm::get(attrs.parsed); - int num_out_grads = param.output_mean_var ? 3U : 1U; - int in_data_start = 3; - int aux_states_start = in_data_start + batchnorm::kInMovingMean; - int out_data_start = in_data_start + batchnorm::kInMovingVar + 1; TShape shape = inputs[0].shape(); // MKLDNN batchnorm only works well on the special MKLDNN layout. if (SupportMKLDNNBN(inputs[0], param) - && (inputs[in_data_start].IsMKLDNNData() || inputs[0].IsMKLDNNData())) { + && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) { std::vector out_grad(1); std::vector out_data(3); std::vector in_data(3); From 457f51dbed3b35debc7221c1deace90eddce89fb Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 14 Mar 2018 22:54:13 +0000 Subject: [PATCH 09/24] Fix an error. --- src/operator/nn/batch_norm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 572f69625d80..51f866701115 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -413,7 +413,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - CHECK_EQ(inputs.size(), 11U); + CHECK_EQ(inputs.size(), 8U); const BatchNormParam ¶m = nnvm::get(attrs.parsed); TShape shape = inputs[0].shape(); From 38a116bc998bdec8a5e97676bc391d0ccc33139d Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 15 Mar 2018 19:16:47 +0000 Subject: [PATCH 10/24] reserve space for std::vector. 
--- src/operator/nn/batch_norm.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 84def36c4ee7..386366d1c43c 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -497,6 +497,7 @@ std::vector BatchNormGrad(const nnvm::NodePtr& n, out_data[i] = nnvm::NodeEntry{n, i, 0}; } std::vector heads; + heads.reserve(8); heads.push_back(ograds[0]); heads.push_back(out_data[batchnorm::kMean]); heads.push_back(out_data[batchnorm::kVar]); From 746f604801eef720a94bd67f80f15c7a2271f1bc Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 15 Mar 2018 21:10:48 +0000 Subject: [PATCH 11/24] Fix alignment. --- src/operator/nn/batch_norm.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 386366d1c43c..5cc92b7cd99e 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -518,14 +518,15 @@ std::vector BatchNormGrad(const nnvm::NodePtr& n, for (uint32_t i = 0; i < 3; ++i) { in_grad[i] = nnvm::NodeEntry{gnode, i, 0}; } + // attach no gradient node to forbid gradient on aux_state - nnvm::NodePtr ng = nnvm::Node::Create(); - ng->attrs.op = Op::Get("_NoGradient"); - ng->attrs.name = "NoGradient"; + nnvm::NodePtr ng = nnvm::Node::Create(); + ng->attrs.op = Op::Get("_NoGradient"); + ng->attrs.name = "NoGradient"; // the aux state of batchnorm - for (uint32_t i = 0; i < 2; ++i) { - in_grad[i + 3] = nnvm::NodeEntry{ng, 0, 0}; - } + for (uint32_t i = 0; i < 2; ++i) { + in_grad[i + 3] = nnvm::NodeEntry{ng, 0, 0}; + } return in_grad; } From 9db50cadfdd96d1d7765e92b92c7b3c179db1338 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 15 Mar 2018 21:11:10 +0000 Subject: [PATCH 12/24] Fix cpp unit test. --- tests/cpp/include/test_core_op.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h index 63f5c91911ed..7dc05fda2cc6 100644 --- a/tests/cpp/include/test_core_op.h +++ b/tests/cpp/include/test_core_op.h @@ -141,8 +141,9 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer static auto gradient = nnvm::Op::GetAttr("FGradient"); nnvm::FGradient grad_fun = gradient.get(op_, nullptr); if (grad_fun) { - std::vector out_grads; - std::vector entries = grad_fun(MakeNode(), out_grads); + auto n = MakeNode(); + std::vector out_grads(n->num_outputs()); + std::vector entries = grad_fun(n, out_grads); CHECK_GE(entries.size(), 1U); res.reserve(entries.size()); for (const nnvm::NodeEntry& node_entry : entries) { @@ -467,7 +468,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer input_shapes_ = input_shapes; // BWD Output shapes output_shapes = backward_for_op->input_shapes_; - CHECK_EQ(output_shapes.size(), inferred_num_outputs); + output_shapes.resize(inferred_num_outputs); } else { output_shapes = input_shapes; output_shapes.resize(inferred_num_outputs); From 1a7916a70cb3ad98ac132ec434d09c6941f63b57 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 15 Mar 2018 21:47:20 +0000 Subject: [PATCH 13/24] Fix BN CPP unit tests. 
--- tests/cpp/operator/batchnorm_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index 4b08d985de3e..f19fccd8a580 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -77,10 +77,10 @@ enum ForwardOutputs { * \brief Backward */ enum BackwardInputs { - /* out_grad */ bwd_out_grad_Grad, bwd_out_grad_Mean, bwd_out_grad_Var, + /* out_grad */ bwd_out_grad_Grad, + /* out_data */ bwd_out_data_Mean, bwd_out_data_Var /* in_data */ bwd_in_data_Data, bwd_in_data_Gamma, bwd_in_data_Beta, /* aux_states */ bwd_aux_states_MovingMean, bwd_aux_states_MovingVar, - /* in_grad */ bwd_out_data_Data, bwd_out_data_Mean, bwd_out_data_Var }; enum BackwardOutputs { /* in_grad */ bwd_in_grad_Data /* Original input data */, From 975beb70bca94fd98f3717fad56dbb28a39f8c75 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 15 Mar 2018 21:58:34 +0000 Subject: [PATCH 14/24] Fix a compile error. --- tests/cpp/operator/batchnorm_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index f19fccd8a580..da3397579321 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -78,9 +78,9 @@ enum ForwardOutputs { */ enum BackwardInputs { /* out_grad */ bwd_out_grad_Grad, - /* out_data */ bwd_out_data_Mean, bwd_out_data_Var + /* out_data */ bwd_out_data_Mean, bwd_out_data_Var, /* in_data */ bwd_in_data_Data, bwd_in_data_Gamma, bwd_in_data_Beta, - /* aux_states */ bwd_aux_states_MovingMean, bwd_aux_states_MovingVar, + /* aux_states */ bwd_aux_states_MovingMean, bwd_aux_states_MovingVar }; enum BackwardOutputs { /* in_grad */ bwd_in_grad_Data /* Original input data */, From 048604d01cc4b45c39dde3cf649d0753486386df Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 16 Mar 2018 00:10:29 +0000 Subject: [PATCH 15/24] Fix compilation error. --- tests/cpp/operator/batchnorm_test.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index da3397579321..2f9de742a35a 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -250,17 +250,12 @@ class BNOperatorExecutor : public test::op::CoreOpExecutor { test::try_fill(ctx().run_ctx, &GetBlob(bwd_aux_states_MovingMean), 0); test::try_fill(ctx().run_ctx, &GetBlob(bwd_aux_states_MovingVar), 1); - val = -.101; - test::patternFill(ctx().run_ctx, &GetBlob(bwd_out_data_Data), [&val]() -> double { - return val += 1; }); test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_data_Mean), 0.0); test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_data_Var), 1.0); val = -.001; test::patternFill(ctx().run_ctx, &GetBlob(bwd_out_grad_Grad), [&val]() -> double { return val += 0.01; }); - test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_grad_Mean), 0.0); - test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_grad_Var), 1.0); } const bool hasWeightAndBias_; // This will cause forward pass validation to fail From 8fc682da82b86592364a400ed344abf084b3a34f Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 16 Mar 2018 18:43:02 +0000 Subject: [PATCH 16/24] Move Op signature. 
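The signature helpers previously lived in mkldnn_base-inl.h as MKLDNNOpSignature and MKLDNNParamOpSign; this commit moves them into operator_common.h as OpSignature, ParamOpSign and OpHash (with the mkldnn::memory overload guarded by MXNET_USE_MKLDNN), so the same key type can also serve caches outside MKLDNN. The following commits key the CuDNN convolution and deconvolution caches with it. A rough sketch of how such a signature is meant to be used as an unordered_map key; Sig, SigHash and CachedOp below are simplified stand-ins, not the real classes:

    #include <cstddef>
    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    // Simplified stand-ins for OpSignature / OpHash.
    struct Sig {
      std::vector<int64_t> eles;
      uint64_t hash = 0;
      void AddSign(int64_t v) { hash = hash * 2 + v; eles.push_back(v); }
      bool operator==(const Sig& o) const { return hash == o.hash && eles == o.eles; }
    };
    struct SigHash {
      std::size_t operator()(const Sig& s) const { return static_cast<std::size_t>(s.hash); }
    };

    struct CachedOp {};  // expensive-to-build state in the real code

    CachedOp& GetOp(int64_t shape0, int64_t shape1, int dev_id) {
      static thread_local std::unordered_map<Sig, CachedOp, SigHash> cache;
      Sig key;
      key.AddSign(shape0);
      key.AddSign(shape1);
      key.AddSign(dev_id);
      return cache[key];  // constructed once per distinct signature per thread
    }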
--- src/operator/nn/mkldnn/mkldnn_act.cc | 4 +- src/operator/nn/mkldnn/mkldnn_base-inl.h | 105 ----------------- .../nn/mkldnn/mkldnn_batch_norm-inl.h | 4 +- src/operator/nn/mkldnn/mkldnn_convolution.cc | 4 +- .../nn/mkldnn/mkldnn_deconvolution.cc | 4 +- src/operator/nn/mkldnn/mkldnn_pooling-inl.h | 2 +- src/operator/nn/mkldnn/mkldnn_pooling.cc | 2 +- src/operator/operator_common.h | 109 ++++++++++++++++++ 8 files changed, 119 insertions(+), 115 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index 71fdf4ca585b..8c19850ced38 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -93,7 +93,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); } -typedef MKLDNNParamOpSign MKLDNNActSignature; +typedef ParamOpSign MKLDNNActSignature; class MKLDNNActForward { std::shared_ptr fwd; @@ -137,7 +137,7 @@ class MKLDNNActForward { static MKLDNNActForward &GetActForward(const ActivationParam& param, const OpContext &ctx, const NDArray &in_data, const mkldnn::memory &in_mem) { - static thread_local std::unordered_map fwds; + static thread_local std::unordered_map fwds; MKLDNNActSignature key(param); key.AddSign(ctx.is_train); key.AddSign(param.act_type); diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 1c583e1f671e..362f5fbde727 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -296,111 +296,6 @@ class MKLDNNStream { } }; -class MKLDNNOpSignature { - std::vector eles; - uint64_t hash; - - public: - MKLDNNOpSignature() { - hash = 0; - } - - explicit MKLDNNOpSignature(uint64_t hash) { - this->hash = hash; - } - - /* - * We provide different methods to add signature to an op. - * For operations, such as convolutin and fully connected, which determines - * the optimal data layout for the op, we only need to use the shape and data - * type to sign the op. For other operations, such as activation, which uses - * whatever layout in the input array, we have to use the shape, the data type - * and the layout to sign the op. 
- */ - - void AddSign(const mkldnn::memory &mem) { - auto desc = mem.get_primitive_desc().desc(); - hash = hash * 2 + desc.data.format; - eles.push_back(desc.data.format); - hash = hash * 2 + desc.data.data_type; - eles.push_back(desc.data.data_type); - for (int i = 0; i < desc.data.ndims; i++) { - hash = hash * 2 + desc.data.dims[i]; - eles.push_back(desc.data.dims[i]); - } - } - - void AddSign(const std::vector &arrs) { - for (auto &arr : arrs) { - AddSign(arr); - } - } - - void AddSign(const NDArray &arr) { - if (arr.IsMKLDNNData()) { - AddSign(*(arr.GetMKLDNNData())); - } else { - hash = hash * 2 + arr.dtype(); - eles.push_back(arr.dtype()); - AddSign(arr.shape()); - } - } - - void AddSign(const TShape &shape) { - for (size_t i = 0; i < shape.ndim(); i++) { - hash = hash * 2 + shape[i]; - eles.push_back(shape[i]); - } - } - - void AddSign(int val) { - hash = hash * 2 + val; - eles.push_back(val); - } - - bool operator==(const MKLDNNOpSignature &sign) const { - if (hash != sign.hash) - return false; - if (eles.size() != sign.eles.size()) - return false; - for (size_t i = 0; i < eles.size(); i++) - if (eles[i] != sign.eles[i]) - return false; - return true; - } - - uint64_t GetHash() const { - return hash; - } -}; - -struct MKLDNNOpHash { - size_t operator()(const MKLDNNOpSignature &sign) const { - return sign.GetHash(); - } -}; - -template -class MKLDNNParamOpSign: public MKLDNNOpSignature { - const ParamType param; - - static size_t hash(const ParamType ¶m) { - std::hash fn; - return fn(param); - } - - public: - explicit MKLDNNParamOpSign(const ParamType &_param): MKLDNNOpSignature( - hash(_param)), param(_param) { - } - - bool operator==(const MKLDNNParamOpSign &sign) const { - const MKLDNNOpSignature &this_upper = *this; - const MKLDNNOpSignature &other_upper = sign; - return this_upper == other_upper && param == sign.param; - } -}; - enum OutDataOp { Noop, CopyBack, diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index d927b1218ee8..16f9874bd5c8 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -98,7 +98,7 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem, return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, eps, flags)); } -typedef MKLDNNParamOpSign MKLDNNBNSignature; +typedef ParamOpSign MKLDNNBNSignature; class MKLDNNBNForward { std::shared_ptr data_m; @@ -184,7 +184,7 @@ template static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, const OpContext &ctx, const NDArray &in_data, unsigned flags) { - static thread_local std::unordered_map fwds; + static thread_local std::unordered_map fwds; MKLDNNBNSignature key(param); key.AddSign(ctx.is_train); key.AddSign(in_data); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 76efc244fc42..453221f9b377 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -226,13 +226,13 @@ class MKLDNNConvForward { } }; -typedef MKLDNNParamOpSign MKLDNNConvSignature; +typedef ParamOpSign MKLDNNConvSignature; static inline MKLDNNConvForward &GetConvFwd( const nnvm::NodeAttrs& attrs, bool is_train, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { - static thread_local std::unordered_map fwds; + static thread_local std::unordered_map fwds; const ConvolutionParam& param = nnvm::get(attrs.parsed); MKLDNNConvSignature key(param); 
key.AddSign(is_train); diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index a0d3df7bb477..8e30a8f81376 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -289,14 +289,14 @@ static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param, } } -typedef MKLDNNParamOpSign MKLDNNDeconvSignature; +typedef ParamOpSign MKLDNNDeconvSignature; static inline MKLDNNDeconvForward &GetDeconvFwd( const nnvm::NodeAttrs& attrs, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { static thread_local - std::unordered_map fwds; + std::unordered_map fwds; const DeconvolutionParam& param = nnvm::get(attrs.parsed); MKLDNNDeconvSignature key(param); // Here we can sign the conv op with NDArray because conv primitive will diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 61895b4d4423..2097d57ba92f 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -104,7 +104,7 @@ inline bool MKLDNNRequireWorkspace(const PoolingParam ¶m) { return param.pool_type != pool_enum::kAvgPooling; } -typedef MKLDNNParamOpSign MKLDNNPoolingSignature; +typedef ParamOpSign MKLDNNPoolingSignature; void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &in_data, const OpReqType req, const NDArray &out_data, const NDArray *workspace); diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index 86f13145eaa5..1aeb7d48dc35 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -188,7 +188,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, const NDArray &output) { static thread_local std::unordered_map pooling_fwds; + OpHash> pooling_fwds; bool with_workspace = is_train && MKLDNNRequireWorkspace(param); MKLDNNPoolingSignature key(param); diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 10581d14ba72..e87e3c7ae78b 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -489,6 +489,115 @@ inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs, LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs); } +class OpSignature { + std::vector eles; + uint64_t hash; + + public: + OpSignature() { + hash = 0; + } + + explicit OpSignature(uint64_t hash) { + this->hash = hash; + } + + /* + * We provide different methods to add signature to an op. + * For operations, such as convolutin and fully connected, which determines + * the optimal data layout for the op, we only need to use the shape and data + * type to sign the op. For other operations, such as activation, which uses + * whatever layout in the input array, we have to use the shape, the data type + * and the layout to sign the op. 
+ */ + + void AddSign(const mkldnn::memory &mem) { + auto desc = mem.get_primitive_desc().desc(); + hash = hash * 2 + desc.data.format; + eles.push_back(desc.data.format); + hash = hash * 2 + desc.data.data_type; + eles.push_back(desc.data.data_type); + for (int i = 0; i < desc.data.ndims; i++) { + hash = hash * 2 + desc.data.dims[i]; + eles.push_back(desc.data.dims[i]); + } + } + + void AddSign(const std::vector &arrs) { + for (auto &arr : arrs) { + AddSign(arr); + } + } + + void AddSign(const NDArray &arr) { +#if MXNET_USE_MKLDNN == 1 + if (arr.IsMKLDNNData()) { + AddSign(*(arr.GetMKLDNNData())); + } else { +#endif + hash = hash * 2 + arr.dtype(); + eles.push_back(arr.dtype()); + AddSign(arr.shape()); +#if MXNET_USE_MKLDNN == 1 + } +#endif + } + + void AddSign(const TShape &shape) { + for (size_t i = 0; i < shape.ndim(); i++) { + hash = hash * 2 + shape[i]; + eles.push_back(shape[i]); + } + } + + void AddSign(int val) { + hash = hash * 2 + val; + eles.push_back(val); + } + + bool operator==(const OpSignature &sign) const { + if (hash != sign.hash) + return false; + if (eles.size() != sign.eles.size()) + return false; + for (size_t i = 0; i < eles.size(); i++) + if (eles[i] != sign.eles[i]) + return false; + return true; + } + + uint64_t GetHash() const { + return hash; + } +}; + +struct OpHash { + size_t operator()(const OpSignature &sign) const { + return sign.GetHash(); + } +}; + +template +class ParamOpSign: public OpSignature { + const ParamType param; + + static size_t hash(const ParamType ¶m) { + std::hash fn; + return fn(param); + } + + public: + explicit ParamOpSign(const ParamType &_param): OpSignature( + hash(_param)), param(_param) { + } + + bool operator==(const ParamOpSign &sign) const { + const OpSignature &this_upper = *this; + const OpSignature &other_upper = sign; + return this_upper == other_upper && param == sign.param; + } +}; + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ From b0271a4bd5ff9e92e84ddc6b129e748b1755a1f1 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 16 Mar 2018 18:43:34 +0000 Subject: [PATCH 17/24] Cache CuDNN conv op. 
--- src/operator/nn/convolution-inl.h | 2 ++ src/operator/nn/convolution.cu | 29 ++++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h index d0dd7dd27a60..c98a010774d7 100644 --- a/src/operator/nn/convolution-inl.h +++ b/src/operator/nn/convolution-inl.h @@ -124,6 +124,8 @@ struct ConvolutionParam : public dmlc::Parameter { } }; +typedef ParamOpSign ConvSignature; + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index d7f9e564a603..d02e790454d1 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -41,13 +41,32 @@ static CuDNNConvolutionOp &GetCuDNNConvOp(const ConvolutionParam& param, const std::vector& in_shape, const std::vector& out_shape, const Context& ctx) { #if DMLC_CXX11_THREAD_LOCAL - static thread_local CuDNNConvolutionOp op; + static thread_local std::unordered_map >, + OpHash> ops; #else - static MX_THREAD_LOCAL CuDNNConvolutionOp op; + static MX_THREAD_LOCAL std::unordered_map >, + OpHash> ops; #endif - op.Init(param, forward_compute_type, backward_compute_type, - in_shape, out_shape, ctx); - return op; + ConvSignature key(param); + key.AddSign(forward_compute_type); + key.AddSign(backward_compute_type); + key.AddSign(in_shape); + key.AddSign(out_shape); + key.AddSign(ctx.dev_id); + + auto it = ops.find(key); + if (it == ops.end()) { + std::shared_ptr> op(new CuDNNConvolutionOp()); + auto ins_ret = ops.insert(std::pair>>( + key, op)); + CHECK(ins_ret.second); + it = ins_ret.first; + it->second->Init(param, forward_compute_type, backward_compute_type, in_shape, + out_shape, ctx); + } + return *it->second; } #endif From 13ee0086b675951188d765ea4882b0455cea70bc Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 16 Mar 2018 19:04:05 +0000 Subject: [PATCH 18/24] Fix compile error. --- src/operator/operator_common.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index e87e3c7ae78b..691cbc6c6545 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -543,6 +543,12 @@ class OpSignature { #endif } + void AddSign(const std::vector &shapes) { + for (auto &shape : shapes) { + AddSign(shape); + } + } + void AddSign(const TShape &shape) { for (size_t i = 0; i < shape.ndim(); i++) { hash = hash * 2 + shape[i]; From b9fb1949c00ec04b01bf4ade12e6dd7c525f9e7a Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 16 Mar 2018 19:23:50 +0000 Subject: [PATCH 19/24] Fix compile error. --- src/operator/operator_common.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 691cbc6c6545..ac00a175d2a5 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -511,6 +511,7 @@ class OpSignature { * and the layout to sign the op. */ +#if MXNET_USE_MKLDNN == 1 void AddSign(const mkldnn::memory &mem) { auto desc = mem.get_primitive_desc().desc(); hash = hash * 2 + desc.data.format; @@ -522,6 +523,7 @@ class OpSignature { eles.push_back(desc.data.dims[i]); } } +#endif void AddSign(const std::vector &arrs) { for (auto &arr : arrs) { From c73ac46531a4d1832cb44a38a40b02958a2495ff Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 16 Mar 2018 21:31:27 +0000 Subject: [PATCH 20/24] Remove thread_local. 
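With this change BatchNormBackward and CuDNNBatchNormOp::Backward take the flat inputs/outputs vectors directly and unpack them internally, so the thread-local scratch vectors introduced in patch 03 are no longer needed. For reference, a sketch of the eight-element input layout the unpacking relies on; Blob is a placeholder type and the real code indexes with the batchnorm:: enum constants:

    #include <cassert>
    #include <vector>

    struct Blob {};  // placeholder for mxnet::TBlob

    // Backward inputs, in order:
    //   [0] out_grad (dL/dy)   [1] saved mean    [2] saved var
    //   [3] data               [4] gamma         [5] beta
    //   [6] moving_mean        [7] moving_var
    void UnpackBackwardInputs(const std::vector<Blob>& inputs) {
      assert(inputs.size() == 8);
      std::vector<Blob> out_grad{inputs[0]};
      std::vector<Blob> in_data{inputs[3], inputs[4], inputs[5]};
      std::vector<Blob> aux_states{inputs[6], inputs[7]};
      (void)out_grad; (void)in_data; (void)aux_states;
    }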
--- src/operator/nn/batch_norm-inl.h | 45 ++++++++--------- src/operator/nn/batch_norm.cu | 21 ++------ src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 51 ++++++++++---------- 3 files changed, 49 insertions(+), 68 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index f36b3933b00e..3f47d58bb8c3 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -224,16 +224,25 @@ void BatchNormForward(const OpContext &ctx, const BatchNormParam& param, */ template void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, + const std::vector &inputs, const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { - CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U); - CHECK_EQ(in_data.size(), 3U); - CHECK_EQ(out_data.size(), 3U); - CHECK_EQ(in_grad.size(), 3U); + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 8U); + CHECK_EQ(outputs.size(), 3U); + std::vector out_grad(1); + std::vector out_data(3); + std::vector in_data(3); + std::vector aux_states(2); + + out_grad[0] = inputs[0]; + out_data[batchnorm::kMean] = inputs[1]; + out_data[batchnorm::kVar] = inputs[2]; + in_data[batchnorm::kData] = inputs[3]; + in_data[batchnorm::kGamma] = inputs[4]; + in_data[batchnorm::kBeta] = inputs[5]; + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; + const std::vector &in_grad = outputs; mshadow::Stream *s = ctx.get_stream(); BatchNormBackwardImpl(s, ctx, param, out_grad, in_data, out_data, req, in_grad, aux_states); @@ -263,23 +272,9 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 8U); const BatchNormParam& param = nnvm::get(attrs.parsed); - static thread_local std::vector out_grad(1); - static thread_local std::vector out_data(3); - static thread_local std::vector in_data(3); - static thread_local std::vector aux_states(2); - out_grad[0] = inputs[0]; - out_data[batchnorm::kMean] = inputs[1]; - out_data[batchnorm::kVar] = inputs[2]; - in_data[batchnorm::kData] = inputs[3]; - in_data[batchnorm::kGamma] = inputs[4]; - in_data[batchnorm::kBeta] = inputs[5]; - aux_states[batchnorm::kMovingMean] = inputs[6]; - aux_states[batchnorm::kMovingVar] = inputs[7]; - const std::vector &in_grad = outputs; - MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, { - BatchNormBackward(ctx, param, out_grad, in_data, out_data, req, - in_grad, aux_states); + MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { + BatchNormBackward(ctx, param, inputs, req, outputs); }); } diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 6c1b63d70f28..c310a93d700f 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -692,16 +692,6 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 8U); BatchNormParam param = nnvm::get(attrs.parsed); - static thread_local std::vector out_grad(1); - static thread_local std::vector out_data(3); - static thread_local std::vector in_data(3); - static thread_local std::vector aux_states(2); - out_grad[0] = inputs[0]; - out_data[batchnorm::kMean] = inputs[1]; - out_data[batchnorm::kVar] = inputs[2]; - in_data[batchnorm::kData] = inputs[3]; - in_data[batchnorm::kGamma] = inputs[4]; - const std::vector &in_grad = outputs; int dtype = 
inputs[0].type_flag_; TShape shape = inputs[0].shape_; @@ -710,23 +700,18 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4 && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - GetCuDNNOp(param).Backward(ctx, out_grad, in_data, out_data, - req, in_grad, aux_states); + GetCuDNNOp(param).Backward(ctx, inputs, req, outputs); }) } else { - aux_states[batchnorm::kMovingMean] = inputs[6]; - aux_states[batchnorm::kMovingVar] = inputs[7]; MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, { - BatchNormBackward(ctx, param, out_grad, - in_data, out_data, req, in_grad, aux_states); + BatchNormBackward(ctx, param, inputs, req, outputs); }) } #else aux_states[batchnorm::kMovingMean] = inputs[6]; aux_states[batchnorm::kMovingVar] = inputs[7]; MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, { - BatchNormBackward(ctx, param, out_grad, - in_data, out_data, req, in_grad, aux_states); + BatchNormBackward(ctx, param, inputs, req, outputs); }); #endif } diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index e2337049060e..e3d5dd9204b9 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -67,10 +67,10 @@ class CuDNNBatchNormOp { } void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_states) { + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 3U); @@ -158,29 +158,30 @@ class CuDNNBatchNormOp { } void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(out_grad.size(), 1U); - CHECK_EQ(in_data.size(), 3U); - CHECK_EQ(out_data.size(), 3U); - CHECK_EQ(in_grad.size(), 3U); + CHECK_EQ(inputs.size(), 8U); + CHECK_EQ(outputs.size(), 3U); CHECK(ctx.is_train && !param_.use_global_stats) << "use global statistics is not yet supported in CuDNNBatchNorm"; - Init(in_data[cudnnbatchnorm::kData]); + // Rename the inputs and outputs. 
+ const TBlob &out_grad = inputs[0]; + const TBlob &out_mean = inputs[1]; + const TBlob &out_var = inputs[2]; + const TBlob &in_data = inputs[3]; + const TBlob &in_gamma = inputs[4]; + const std::vector &in_grad = outputs; + + Init(in_data); Stream *s = ctx.get_stream(); - Tensor x = - in_data[cudnnbatchnorm::kData].get_with_shape(shape_, s); + Tensor x = in_data.get_with_shape(shape_, s); Tensor dx = in_grad[cudnnbatchnorm::kData].get_with_shape(shape_, s); - Tensor dy = - out_grad[cudnnbatchnorm::kOut].get_with_shape(shape_, s); + Tensor dy = out_grad.get_with_shape(shape_, s); #if CUDNN_VERSION >= 4007 #if CUDNN_VERSION >= 7002 @@ -190,15 +191,15 @@ class CuDNNBatchNormOp { #endif MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, { Tensor gamma = - in_data[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); + in_gamma.get_with_shape(Shape1(shape_[1]), s); Tensor dbeta = in_grad[cudnnbatchnorm::kBeta].get_with_shape(Shape1(shape_[1]), s); Tensor dgamma = in_grad[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); Tensor save_mean = - out_data[cudnnbatchnorm::kMean].get_with_shape(Shape1(shape_[1]), s); + out_mean.get_with_shape(Shape1(shape_[1]), s); Tensor save_inv_var = - out_data[cudnnbatchnorm::kInvVar].get_with_shape(Shape1(shape_[1]), s); + out_var.get_with_shape(Shape1(shape_[1]), s); typename DataType::ScaleType a = 1.0f; typename DataType::ScaleType b = 0.0f; @@ -232,15 +233,15 @@ class CuDNNBatchNormOp { #else // CUDNN_VERSION < 4007 MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, { Tensor gamma = - in_data[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); + in_gamma.get_with_shape(Shape1(shape_[1]), s); Tensor dbeta = in_grad[cudnnbatchnorm::kBeta].get_with_shape(Shape1(shape_[1]), s); Tensor dgamma = in_grad[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); Tensor save_mean = - out_data[cudnnbatchnorm::kMean].get_with_shape(Shape1(shape_[1]), s); + out_mean.get_with_shape(Shape1(shape_[1]), s); Tensor save_inv_var = - out_data[cudnnbatchnorm::kInvVar].get_with_shape(Shape1(shape_[1]), s); + out_var.get_with_shape(Shape1(shape_[1]), s); typename DataType::ScaleType a = 1.0f; typename DataType::ScaleType b = 0.0f; From afc7f75442e68beac04065c366a6a64bf5778b44 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 19 Mar 2018 23:34:12 +0000 Subject: [PATCH 21/24] Reduce mem alloc when caching cudnn conv. --- src/operator/nn/convolution.cu | 8 ++++++++ src/operator/operator_common.h | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index d02e790454d1..50859b9205a0 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -50,6 +50,14 @@ static CuDNNConvolutionOp &GetCuDNNConvOp(const ConvolutionParam& param, OpHash> ops; #endif ConvSignature key(param); + size_t ndim = 0; + for (auto &s: in_shape) + ndim += s.ndim(); + for (auto &s: out_shape) + ndim += s.ndim(); + key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */ + + ndim + 1 /* for dev_id */); + key.AddSign(forward_compute_type); key.AddSign(backward_compute_type); key.AddSign(in_shape); diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index ac00a175d2a5..a629ba5eed8b 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -502,6 +502,13 @@ class OpSignature { this->hash = hash; } + /* + * This is to reserve space for the vector. 
+ */ + void Reserve(size_t num) { + eles.reserve(num); + } + /* * We provide different methods to add signature to an op. * For operations, such as convolutin and fully connected, which determines From f18dc20c86add61961608a3593be0ebbac117e38 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Tue, 20 Mar 2018 22:26:57 +0000 Subject: [PATCH 22/24] Fix a lint error. --- src/operator/nn/convolution.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index 50859b9205a0..f6d14e3558b8 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -51,9 +51,9 @@ static CuDNNConvolutionOp &GetCuDNNConvOp(const ConvolutionParam& param, #endif ConvSignature key(param); size_t ndim = 0; - for (auto &s: in_shape) + for (auto &s : in_shape) ndim += s.ndim(); - for (auto &s: out_shape) + for (auto &s : out_shape) ndim += s.ndim(); key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */ + ndim + 1 /* for dev_id */); From 287970f28d26e509d4985b51362b23aa292e72b6 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 21 Mar 2018 00:00:32 +0000 Subject: [PATCH 23/24] Cache CuDNN deconv. --- src/operator/nn/deconvolution-inl.h | 2 ++ src/operator/nn/deconvolution.cu | 33 +++++++++++++++++-- .../nn/mkldnn/mkldnn_deconvolution.cc | 8 ++--- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index badbb8b9d672..b41ecf4aa41e 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -169,6 +169,8 @@ struct DeconvolutionParam : public dmlc::Parameter { } }; +typedef ParamOpSign DeconvSignature; + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu index c7395428c2a0..c50de8e55fe2 100644 --- a/src/operator/nn/deconvolution.cu +++ b/src/operator/nn/deconvolution.cu @@ -40,9 +40,36 @@ static CuDNNDeconvolutionOp &GetCuDNNDeconvOp(const DeconvolutionParam& p const std::vector& in_shape, const std::vector& out_shape, const Context& ctx) { - static thread_local CuDNNDeconvolutionOp op; - op.Init(param, forward_compute_type, backward_compute_type, in_shape, out_shape, ctx); - return op; + + static thread_local std::unordered_map >, + OpHash> ops; + DeconvSignature key(param); + size_t ndim = 0; + for (auto &s : in_shape) + ndim += s.ndim(); + for (auto &s : out_shape) + ndim += s.ndim(); + key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */ + + ndim + 1 /* for dev_id */); + + key.AddSign(forward_compute_type); + key.AddSign(backward_compute_type); + key.AddSign(in_shape); + key.AddSign(out_shape); + key.AddSign(ctx.dev_id); + + auto it = ops.find(key); + if (it == ops.end()) { + std::shared_ptr> op(new CuDNNDeconvolutionOp()); + auto ins_ret = ops.insert(std::pair>>( + key, op)); + CHECK(ins_ret.second); + it = ins_ret.first; + it->second->Init(param, forward_compute_type, backward_compute_type, in_shape, + out_shape, ctx); + } + return *it->second; } #endif diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index 8e30a8f81376..af57b68cfd37 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -289,16 +289,14 @@ static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param, } } -typedef ParamOpSign MKLDNNDeconvSignature; - static inline MKLDNNDeconvForward 
&GetDeconvFwd( const nnvm::NodeAttrs& attrs, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { static thread_local - std::unordered_map fwds; + std::unordered_map fwds; const DeconvolutionParam& param = nnvm::get(attrs.parsed); - MKLDNNDeconvSignature key(param); + DeconvSignature key(param); // Here we can sign the conv op with NDArray because conv primitive will // decide the right layout for the, so we only need to get the shape and the // data type of the arrays. @@ -313,7 +311,7 @@ static inline MKLDNNDeconvForward &GetDeconvFwd( bool has_bias = (bias != nullptr); MKLDNNDeconvForward fwd(param, data, weights, has_bias, output); auto ins_ret = fwds.insert( - std::pair(key, fwd)); + std::pair(key, fwd)); CHECK(ins_ret.second); it = ins_ret.first; } From db21715e1758943f0fab048b461eb7ecf039c5ce Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Wed, 21 Mar 2018 00:16:41 +0000 Subject: [PATCH 24/24] Fix lint error. --- src/operator/nn/deconvolution.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu index c50de8e55fe2..086b47000b2c 100644 --- a/src/operator/nn/deconvolution.cu +++ b/src/operator/nn/deconvolution.cu @@ -40,7 +40,6 @@ static CuDNNDeconvolutionOp &GetCuDNNDeconvOp(const DeconvolutionParam& p const std::vector& in_shape, const std::vector& out_shape, const Context& ctx) { - static thread_local std::unordered_map >, OpHash> ops; @@ -62,8 +61,8 @@ static CuDNNDeconvolutionOp &GetCuDNNDeconvOp(const DeconvolutionParam& p auto it = ops.find(key); if (it == ops.end()) { std::shared_ptr> op(new CuDNNDeconvolutionOp()); - auto ins_ret = ops.insert(std::pair>>( - key, op)); + auto ins_ret = ops.insert( + std::pair>>(key, op)); CHECK(ins_ret.second); it = ins_ret.first; it->second->Init(param, forward_compute_type, backward_compute_type, in_shape,