From f91e9d9b4498fae3a29a37baf8598d3d603dd410 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 19 Jan 2022 09:33:57 +0900
Subject: [PATCH] swap arg order (data, grad) to be consistent with
 conv2d_transpose(dgrad)

---
 python/tvm/relay/op/_tensor_grad.py       |  4 +--
 python/tvm/relay/op/nn/_nn.py             |  2 +-
 python/tvm/relay/op/nn/nn.py              | 16 +++++-------
 src/relay/op/nn/convolution.cc            | 32 +++++++++++------------
 tests/python/relay/test_op_grad_level2.py |  2 +-
 5 files changed, 27 insertions(+), 29 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index cb3dcb6406fe..a3e5f110d365 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -425,16 +425,16 @@ def conv2d_grad(orig, grad):
     )
 
     backward_weight = _nn.conv2d_backward_weight(
-        data,
         grad,
+        data,
         strides=attrs.strides,
         padding=attrs.padding,
         dilation=attrs.dilation,
         groups=attrs.groups,
         channels=attrs.channels,
         kernel_size=(filter_h, filter_w),
-        data_layout=attrs.data_layout,
         grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout,
+        data_layout=attrs.data_layout,
         kernel_layout=attrs.kernel_layout,
         out_dtype=attrs.out_dtype,
     )
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 58dfde81413d..2a941cc8c28a 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1080,7 +1080,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-    data, grad = inputs
+    grad, data = inputs
     data_shape = get_const_tuple(data.checked_type.shape)
     weight_shape = get_const_tuple(types[2].shape)
     _, out_channel, grad_h, grad_w = get_const_tuple(grad.checked_type.shape)
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 257644994ffe..857e4c3eb9ba 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -3773,28 +3773,26 @@ def batch_to_space_nd(data, block_shape, crops):
 
 
 def conv2d_backward_weight(
-    data,
     grad,
+    data,
     strides=(1, 1),
     padding=(0, 0),
     dilation=(1, 1),
     groups=1,
     channels=None,
     kernel_size=None,
-    data_layout="NCHW",
     grad_layout="NCHW",
+    data_layout="NCHW",
     kernel_layout="OIHW",
     out_dtype="",
 ):
     r"""The gradient of conv2d with respect to weight.
 
     This operator takes the output gradient `grad` and convolves it with `data` as
     the convolution kernel, to produce the gradient with respect to weight.
 
     Note that the parameter `kernel_size` is the spatial size of the corresponding
     forward convolution kernel, not that of `data`. `grad_layout` and
     `kernel_layout` are the layouts of `grad` and the weight gradient respectively.
 
     Other parameters are the same as the conv2d op. See its documentation for more
@@ -3810,16 +3808,16 @@ def conv2d_backward_weight(
     padding = get_pad_tuple2d(padding)
 
     return _make.conv2d_backward_weight(
-        data,
         grad,
+        data,
         strides,
         padding,
         dilation,
         groups,
         channels,
         kernel_size,
-        data_layout,
         grad_layout,
+        data_layout,
         kernel_layout,
         out_dtype,
     )
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 1d87a455e442..f1d4eb3d87ea 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -579,10 +579,10 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
           kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d");
     });
 
-inline Expr MakeConv2dBackwardWeight(Expr data, Expr grad, Array<IndexExpr> strides,
+inline Expr MakeConv2dBackwardWeight(Expr grad, Expr data, Array<IndexExpr> strides,
                                      Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
                                      IndexExpr channels, Array<IndexExpr> kernel_size,
-                                     std::string data_layout, std::string grad_layout,
+                                     std::string grad_layout, std::string data_layout,
                                      std::string kernel_layout, DataType out_dtype) {
   auto attrs = make_object<Conv2DAttrs>();
   attrs->strides = std::move(strides);
@@ -591,29 +591,29 @@ inline Expr MakeConv2dBackwardWeight(Expr data, Expr grad, Array<IndexExpr> stri
   attrs->groups = groups;
   attrs->channels = std::move(channels);
   attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
   attrs->out_dtype = std::move(out_dtype);
-  attrs->kernel_layout = std::move(grad_layout);
+  attrs->data_layout = std::move(grad_layout);
+  attrs->kernel_layout = std::move(data_layout);
   attrs->out_layout = std::move(kernel_layout);
   const Op& op = Op::Get("nn.conv2d_backward_weight");
-  return Call(op, {data, grad}, Attrs(attrs), {});
+  return Call(op, {grad, data}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_backward_weight")
-    .set_body_typed([](Expr data, Expr grad, Array<IndexExpr> strides, Array<IndexExpr> padding,
+    .set_body_typed([](Expr grad, Expr data, Array<IndexExpr> strides, Array<IndexExpr> padding,
                        Array<IndexExpr> dilation, int groups, IndexExpr channels,
-                       Array<IndexExpr> kernel_size, String data_layout, String grad_layout,
+                       Array<IndexExpr> kernel_size, String grad_layout, String data_layout,
                        String kernel_layout, DataType out_dtype) {
-      return MakeConv2dBackwardWeight(data, grad, strides, padding, dilation, groups, channels,
-                                      kernel_size, data_layout, grad_layout, kernel_layout,
+      return MakeConv2dBackwardWeight(grad, data, strides, padding, dilation, groups, channels,
+                                      kernel_size, grad_layout, data_layout, kernel_layout,
                                       out_dtype);
     });
 
 bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                              const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* grad = types[1].as<TensorTypeNode>();
+  const auto* grad = types[0].as<TensorTypeNode>();
+  const auto* data = types[1].as<TensorTypeNode>();
   if (data == nullptr) return false;
 
   static const Layout kNCHW("NCHW");
@@ -625,12 +625,12 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
   ICHECK(param->kernel_size.defined()) << "kernel_size attribute needs to be specified";
 
   // We repurpose Conv2dAttrs for Conv2DBackwardWeight, note the meanings of layouts.
-  const Layout in_layout(param->data_layout);
-  const Layout grad_layout(param->kernel_layout);
+  const Layout grad_layout(param->data_layout);
+  const Layout in_layout(param->kernel_layout);
   const Layout kernel_layout(param->out_layout);
 
-  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
   const auto trans_grad_layout = tir::BijectiveLayout(grad_layout, kNCHW);
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
   const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
 
   Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
@@ -653,16 +653,16 @@ RELAY_REGISTER_OP("nn.conv2d_backward_weight")
 This layer computes the gradient of the conv2d op with respect to weight,
 given the original input data and the output gradient.
 
+- **grad**: (batch, channels, out_height, out_width) if `layout` is `NCHW`.
 - **data**: This depends on the `layout` parameter. Input is 4D array of shape
             (batch_size, in_channels, height, width) if `layout` is `NCHW`.
-- **grad**: (batch, channels, out_height, out_width) if `layout` is `NCHW`.
 - **out**: This depends on the `layout` parameter. Output is 4D array of shape
            (channels, in_channels, kernel_size[0], kernel_size[1]) if `layout` is `NCHW`.
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<Conv2DAttrs>()
     .set_num_inputs(2)
-    .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("grad", "Tensor", "The gradient tensor.")
+    .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(2)
     .add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel)
     .set_attr<TNonComputational>("TNonComputational", true)
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index a6535b1653b0..1efdb262245f 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -234,7 +234,7 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin
     dy = relay.var("dy", shape=dy_shape, dtype=dtype)
     x = relay.var("x", shape=x_shape, dtype=dtype)
     dw = relay.nn.conv2d_backward_weight(
-        x, dy, strides=stride, padding=padding, kernel_size=kernel_size
+        dy, x, strides=stride, padding=padding, kernel_size=kernel_size
    )
     dw_func = relay.Function([dy, x], dw)
     dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())
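
Note (illustrative, not part of the patch): with the new argument order the output gradient comes first, mirroring conv2d_transpose (dgrad). Below is a minimal sketch of calling the op after this change; the shapes are assumptions for the example only (forward conv2d with data (1, 4, 32, 32), weight (8, 4, 3, 3), stride 1, no padding, so the output gradient is (1, 8, 30, 30)).

import tvm
from tvm import relay

# Output gradient of the assumed forward conv2d.
dy = relay.var("dy", shape=(1, 8, 30, 30), dtype="float32")
# Input of the assumed forward conv2d.
x = relay.var("x", shape=(1, 4, 32, 32), dtype="float32")

# grad comes first after this patch; kernel_size is the spatial size of the
# forward convolution kernel, not the size of grad.
dw = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(0, 0), kernel_size=(3, 3)
)

mod = tvm.IRModule.from_expr(relay.Function([dy, x], dw))
mod = relay.transform.InferType()(mod)
# Expected weight-gradient type: Tensor[(8, 4, 3, 3), float32]
print(mod["main"].ret_type)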