Skip to content

Commit

Permalink
code style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Menooker committed Apr 17, 2020
1 parent 13504e1 commit 498cf74
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -896,12 +896,11 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if(param->groups == 1 || is_depthwise_conv2d) {
if (param->groups == 1 || is_depthwise_conv2d) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || //simple layout
(ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) //blocked layout
{
if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
(ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
Array<Integer> arr{c_big_axis};
if (c_small_axis >= 0) {
arr.push_back(c_small_axis);
Expand Down Expand Up @@ -951,10 +950,10 @@ Expr Conv2DBackwardTransform(const Call& call,
} else {
auto& wshape = weight->type_as<TensorTypeNode>()->shape;
Array<Integer> arr;
for(size_t i=0; i<wshape.size(); i++){
for (size_t i = 0; i < wshape.size(); i++) {
if(i == static_cast<size_t>(small_ko_axis) || i == static_cast<size_t>(big_ko_axis)) {
auto node = wshape[i].as<IntImmNode>();
if(!node) {
if (!node) {
// if the shape is not a constant, use normal transform
return transformer->NormalCallTransform(call.operator->());
}
Expand Down

0 comments on commit 498cf74

Please sign in to comment.