[MXNET-105] Fix CuDNN performance after code refactor #10116
@@ -690,13 +690,8 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx, const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
-  CHECK_EQ(inputs.size(), 11U);
+  CHECK_EQ(inputs.size(), 8U);
   BatchNormParam param = nnvm::get<BatchNormParam>(attrs.parsed);
-  std::vector<TBlob> out_grad(1, inputs[0]);
-  std::vector<TBlob> in_data(inputs.begin() + 3, inputs.begin() + 6);
-  std::vector<TBlob> aux_states(inputs.begin() + 6, inputs.begin() + 8);
-  std::vector<TBlob> out_data(inputs.begin() + 8, inputs.end());
-  std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);
   int dtype = inputs[0].type_flag_;
   TShape shape = inputs[0].shape_;
@@ -705,19 +700,18 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
   if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
       && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      GetCuDNNOp<DType>(param).Backward(ctx, out_grad, in_data, out_data,
-                                        req, in_grad, aux_states);
+      GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
     })
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
-      BatchNormBackward<gpu, DType, AccReal>(ctx, param, out_grad,
-                                             in_data, out_data, req, in_grad, aux_states);
+      BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
     })
   }
 #else
   aux_states[batchnorm::kMovingMean] = inputs[6];
   aux_states[batchnorm::kMovingVar] = inputs[7];
Review thread on the aux_states lines above:

- @zheng-da aux_states is not defined if USE_CUDNN is not enabled. @marcoabreu it seems there is no pure CUDA CI environment that is not built with cuDNN.
- I see. I'll update it.
- Agreed. @marcoabreu could you add a CI build with CUDA only?
- Sure, no problem at all! Compilation only, or do we need tests as well?
- I think it's better to run the code at least once. We probably don't need to try both Python 2 and Python 3, something like that.
- Done: #10281
   MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, {
-    BatchNormBackward<gpu, DType, AccReal>(ctx, param, out_grad,
-                                           in_data, out_data, req, in_grad, aux_states);
+    BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
   });
 #endif
 }
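For reference, a minimal sketch of what the updated non-cuDNN branch might look like once the sliced vectors are gone. This is an assumption about the follow-up fix, not the committed change: it drops the aux_states assignments and switches on inputs[0] instead of out_grad[0], which refer to the same blob because out_grad was previously built as a one-element copy of inputs[0].

#else
  // Sketch only: pass the flat inputs/outputs through, matching the new
  // BatchNormBackward signature shown in the diff above.
  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
    BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
  });
#endif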
A second review thread:

- Please use reserve().
- This code runs to build the computation graph. It only runs once. Do we still need to call reserve()?
- Yes, please.
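For context, the reviewer is asking for the usual pattern of calling std::vector::reserve() before a series of push_back calls, so the vector allocates its storage once instead of growing repeatedly. A minimal, self-contained sketch; the names here are illustrative, not taken from the PR:

#include <vector>

struct TBlob {};  // stand-in for the real mxnet::TBlob

std::vector<TBlob> collect(const std::vector<TBlob>& inputs) {
  std::vector<TBlob> selected;
  selected.reserve(inputs.size());  // allocate once up front
  for (const TBlob& blob : inputs) {
    selected.push_back(blob);       // no reallocation inside the loop
  }
  return selected;
}

Even for code that only runs once while the graph is being built, the call is essentially free and avoids a handful of reallocations, which is presumably why the reviewer still asked for it.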