From 1acc0c1f90c783cf5b5ed949fc20bb305559e584 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 04:10:36 +0900 Subject: [PATCH] [Relay] Add `conv2d_backward_weight` op (without topi) (#9954) * python plumbing * add cpp def * legalize worked * clean up * layout conversion doesnt work * extract wgrad body * fix convert layout * black * fix kernel size * revert irrelevant change * add doc, clarify the meanings of parameters * update layout convert * test passed * fixed layout conversion * update convert layout * remove print * remove layout convert for now * minor fix * removed unused import * add wgrad python reference * add test stub * add doc * test other stride and pad * tweak * more pylint filter * fix typo in doc * swap arg order (data, grad) to be consistent with conv2d_transpose(dgrad) --- python/tvm/relay/op/_tensor_grad.py | 53 +++-------- python/tvm/relay/op/nn/_nn.py | 78 ++++++++++++++++ python/tvm/relay/op/nn/nn.py | 51 +++++++++++ python/tvm/relay/testing/__init__.py | 1 + python/tvm/topi/testing/__init__.py | 1 + .../testing/conv2d_backcward_weight_python.py | 76 ++++++++++++++++ src/relay/op/nn/convolution.cc | 89 +++++++++++++++++++ tests/python/relay/test_op_grad_level2.py | 44 +++++++-- 8 files changed, 344 insertions(+), 49 deletions(-) create mode 100644 python/tvm/topi/testing/conv2d_backcward_weight_python.py diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 3793f947c5cc..a3e5f110d365 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -52,7 +52,6 @@ reshape_like, strided_slice, take, - tile, transpose, where, repeat, @@ -399,15 +398,14 @@ def conv2d_grad(orig, grad): data_shape = get_const_tuple(data.checked_type.shape) weight_shape = get_const_tuple(weight.checked_type.shape) _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape) - batch, in_channel, in_h, in_w = data_shape - out_channel, _, filter_h, filter_w = weight_shape + _, _, in_h, in_w = data_shape + _, _, filter_h, filter_w = weight_shape # infer output_padding fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple( get_const_tuple(attrs.padding), (filter_h, filter_w) ) stride_h, stride_w = get_const_tuple(attrs.strides) - dilation_h, dilation_w = get_const_tuple(attrs.dilation) out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w output_padding = (in_h - out_h, in_w - out_w) @@ -425,46 +423,21 @@ def conv2d_grad(orig, grad): groups=attrs.groups, output_padding=output_padding, ) - grad = tile(grad, [1, in_channel // attrs.groups, 1, 1]) - grad = reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow - data = reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw - backward_weight = _nn.conv2d( - data, + backward_weight = _nn.conv2d_backward_weight( grad, - strides=attrs.dilation, + data, + strides=attrs.strides, padding=attrs.padding, - dilation=attrs.strides, - groups=in_channel * batch, - ) - # infer shape of backward_weight - padded_weight_grad_h = ( - in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom - ) // dilation_h + 1 - padded_weight_grad_w = ( - in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right - ) // dilation_w + 1 - backward_weight = reshape( - backward_weight, - [ - batch, - in_channel // attrs.groups, - out_channel, - padded_weight_grad_h, - padded_weight_grad_w, - ], + dilation=attrs.dilation, + groups=attrs.groups, + channels=attrs.channels, + 
kernel_size=(filter_h, filter_w), + grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout, + data_layout=attrs.data_layout, + kernel_layout=attrs.kernel_layout, + out_dtype=attrs.out_dtype, ) - backward_weight = _sum(backward_weight, axis=0) - backward_weight = transpose(backward_weight, [1, 0, 2, 3]) - - assert padded_weight_grad_h >= filter_h - assert padded_weight_grad_w >= filter_w - if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: - backward_weight = strided_slice( - backward_weight, - begin=[0, 0, 0, 0], - end=[out_channel, in_channel // attrs.groups, filter_h, filter_w], - ) return [backward_data, backward_weight] diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 9aa883d9b750..2a941cc8c28a 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -23,6 +23,7 @@ from tvm.runtime import convert from tvm.te.hybrid import script from tvm.topi.utils import get_const_tuple +from tvm.topi.nn.utils import get_pad_tuple from ....ir import container from ....tir import expr @@ -1061,6 +1062,83 @@ def compute_space_to_depth(attrs, inputs, out_dtype): reg.register_injective_schedule("nn.batch_to_space_nd") +@reg.register_legalize("nn.conv2d_backward_weight") +def legalize_conv2d_backward_weight(attrs, inputs, types): + """Legalize conv2d_backward_weight op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + grad, data = inputs + data_shape = get_const_tuple(data.checked_type.shape) + weight_shape = get_const_tuple(types[2].shape) + _, out_channel, grad_h, grad_w = get_const_tuple(grad.checked_type.shape) + batch, in_channel, in_h, in_w = data_shape + _, _, filter_h, filter_w = weight_shape + fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple( + get_const_tuple(attrs.padding), (filter_h, filter_w) + ) + stride_h, stride_w = get_const_tuple(attrs.strides) + dilation_h, dilation_w = get_const_tuple(attrs.dilation) + + grad = relay.tile(grad, [1, in_channel // attrs.groups, 1, 1]) + grad = relay.reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow + data = relay.reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw + + backward_weight = relay.nn.conv2d( + data, + grad, + strides=attrs.dilation, + padding=attrs.padding, + dilation=attrs.strides, + groups=in_channel * batch, + ) + + # infer shape of backward_weight + padded_weight_grad_h = ( + in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom + ) // dilation_h + 1 + padded_weight_grad_w = ( + in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right + ) // dilation_w + 1 + + backward_weight = relay.reshape( + backward_weight, + [ + batch, + in_channel // attrs.groups, + out_channel, + padded_weight_grad_h, + padded_weight_grad_w, + ], + ) + backward_weight = relay.sum(backward_weight, axis=0) + backward_weight = relay.transpose(backward_weight, [1, 0, 2, 3]) + + assert padded_weight_grad_h >= filter_h + assert padded_weight_grad_w >= filter_w + + if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: + backward_weight = relay.strided_slice( + backward_weight, + begin=[0, 0, 0, 0], + end=[out_channel, in_channel // attrs.groups, filter_h, filter_w], + ) + + return backward_weight + + ##################### # Shape functions # ##################### diff --git 
a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index c7b376ec3d64..857e4c3eb9ba 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -3770,3 +3770,54 @@ def batch_to_space_nd(data, block_shape, crops): """ return _make.batch_to_space_nd(data, block_shape, crops) + + +def conv2d_backward_weight( + grad, + data, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, + grad_layout="NCHW", + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="", +): + r"""The gradient of conv2d with respect to weight. + + This operator takes the output gradient `grad` and convolves it with `data` as + the convolution kernel, to produce the gradient with respect to weight. + + Note that the parameter `kernel_size` is the spatial size of the corresponding + forward convolution kernel, not that of `data`. `grad_layout` and + `kernel_layout` are the layouts of `grad` and the weight gradient respectively. + + Other parameters are the same as the conv2d op. See its documentation for more + details. + + """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + padding = get_pad_tuple2d(padding) + + return _make.conv2d_backward_weight( + grad, + data, + strides, + padding, + dilation, + groups, + channels, + kernel_size, + grad_layout, + data_layout, + kernel_layout, + out_dtype, + ) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 9fc75199bdf5..909712511061 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -127,6 +127,7 @@ def check_grad( fwd_func = run_infer_type(func) bwd_func = run_infer_type(gradient(fwd_func, mode=mode)) + bwd_func = run_opt_pass(bwd_func, relay.transform.Legalize()) if scale is None: scale = 10 * eps diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 345886c2be91..75eabffc957a 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -75,3 +75,4 @@ from .nll_loss import nll_loss from .dense import dense from .searchsorted import searchsorted_ref +from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python diff --git a/python/tvm/topi/testing/conv2d_backcward_weight_python.py b/python/tvm/topi/testing/conv2d_backcward_weight_python.py new file mode 100644 index 000000000000..587cd45b49c1 --- /dev/null +++ b/python/tvm/topi/testing/conv2d_backcward_weight_python.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, too-many-nested-blocks
+"""Gradient of conv2d with respect to weight in python"""
+import numpy as np
+
+
+# Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
+def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
+    """Gradient of the conv2d op with respect to weight, in NCHW layout.
+
+    Parameters
+    ----------
+    dy_np : numpy.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+
+    x_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    kernel_size : tuple of two ints
+        Height and width of the weight
+
+    stride : tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : tuple of two ints
+        Spatial padding, or [pad_h, pad_w]
+
+    Returns
+    -------
+    dw_np : np.ndarray
+        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+
+    """
+    N, C, H, W = x_np.shape
+    _, K, P, Q = dy_np.shape
+    R, S = kernel_size
+    pad_h, pad_w = padding
+    stride_h, stride_w = stride
+    dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
+
+    for k in range(K):
+        for r in range(R):
+            for s in range(S):
+                for c in range(C):
+                    acc = 0
+                    for n in range(N):
+                        for p in range(P):
+                            for q in range(Q):
+                                coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)
+
+                                if (
+                                    coord[2] < H
+                                    and coord[2] >= 0
+                                    and coord[3] < W
+                                    and coord[3] >= 0
+                                ):
+                                    acc += dy_np[n, k, p, q] * x_np[coord]
+
+                    dw[k, c, r, s] = acc
+
+    return dw
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 92164481807a..f1d4eb3d87ea 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -579,5 +579,94 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
                           kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                           "nn.deformable_conv2d");
     });
 
+inline Expr MakeConv2dBackwardWeight(Expr grad, Expr data, Array<IndexExpr> strides,
+                                     Array<IndexExpr> padding, Array<IndexExpr> dilation,
+                                     int groups, IndexExpr channels, Array<IndexExpr> kernel_size,
+                                     std::string grad_layout, std::string data_layout,
+                                     std::string kernel_layout, DataType out_dtype) {
+  auto attrs = make_object<Conv2DAttrs>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->out_dtype = std::move(out_dtype);
+  attrs->data_layout = std::move(grad_layout);
+  attrs->kernel_layout = std::move(data_layout);
+  attrs->out_layout = std::move(kernel_layout);
+  const Op& op = Op::Get("nn.conv2d_backward_weight");
+  return Call(op, {grad, data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_backward_weight")
+    .set_body_typed([](Expr grad, Expr data, Array<IndexExpr> strides, Array<IndexExpr> padding,
+                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
+                       Array<IndexExpr> kernel_size, String grad_layout, String data_layout,
+                       String kernel_layout, DataType out_dtype) {
+      return MakeConv2dBackwardWeight(grad, data, strides, padding, dilation, groups, channels,
+                                      kernel_size, grad_layout, data_layout, kernel_layout,
+                                      out_dtype);
+    });
+
+bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                             const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 3);
+  const auto* grad = types[0].as<TensorTypeNode>();
+  const auto* data = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
+  const auto* param = attrs.as<Conv2DAttrs>();
+  ICHECK(param != nullptr);
+  // Require kernel_size to be passed, to simplify the output shape determination.
+  ICHECK(param->kernel_size.defined()) << "kernel_size attribute needs to be specified";
+
+  // We repurpose Conv2DAttrs for Conv2DBackwardWeight; note the meanings of the layouts.
+  const Layout grad_layout(param->data_layout);
+  const Layout in_layout(param->kernel_layout);
+  const Layout kernel_layout(param->out_layout);
+
+  const auto trans_grad_layout = tir::BijectiveLayout(grad_layout, kNCHW);
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
+
+  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+  Array<IndexExpr> grad_shape_nchw = trans_grad_layout.ForwardShape(grad->shape);
+
+  auto in_channels = dshape_nchw[1];
+  auto out_channels = grad_shape_nchw[1];
+
+  Array<IndexExpr> wshape_oihw(
+      {out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]});
+
+  auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
+  reporter->Assign(types[2], TensorType(wshape, data->dtype));
+  return true;
+}
+
+RELAY_REGISTER_OP("nn.conv2d_backward_weight")
+    .describe(R"code(The gradient of the 2D convolution layer with respect to the weight.
+
+This layer computes the gradient of the conv2d op with respect to weight,
+given the original input data and the output gradient.
+
+- **grad**: (batch, channels, out_height, out_width) if `layout` is `NCHW`.
+- **data**: This depends on the `layout` parameter. Input is 4D array of shape
+            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
+- **out**: This depends on the `layout` parameter. Output is 4D array of shape
+           (channels, in_channels, kernel_size[0], kernel_size[1]) if `layout` is `NCHW`.
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<Conv2DAttrs>()
+    .set_num_inputs(2)
+    .add_argument("grad", "Tensor", "The gradient tensor.")
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(2)
+    .add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index 115ed48d5888..1efdb262245f 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -15,13 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
import numpy as np - +import pytest from tvm import topi import tvm.topi.testing import tvm from tvm import te from tvm import relay -from tvm.relay.testing import check_grad, run_infer_type +from tvm.relay.testing import check_grad, run_infer_type, run_opt_pass from tvm.relay.transform import gradient import tvm.testing @@ -229,11 +229,37 @@ def test_batch_flatten_grad(): verify_batch_flatten_grad((1, 8)) +def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding): + dtype = "float32" + dy = relay.var("dy", shape=dy_shape, dtype=dtype) + x = relay.var("x", shape=x_shape, dtype=dtype) + dw = relay.nn.conv2d_backward_weight( + dy, x, strides=stride, padding=padding, kernel_size=kernel_size + ) + dw_func = relay.Function([dy, x], dw) + dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize()) + + target = "llvm" + dev = tvm.device(target, 0) + dy_np = np.random.randn(*dy_shape).astype(dtype) + x_np = np.random.randn(*x_shape).astype(dtype) + + dw_np = ( + relay.create_executor(device=dev, target=target) + .evaluate(dw_func_legalized)(dy_np, x_np) + .numpy() + ) + ref_dw_np = tvm.topi.testing.conv2d_backward_weight_nchw_python( + dy_np, x_np, kernel_size, stride, padding + ) + + np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4) + + +def test_conv2d_backward_weight(): + verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1)) + verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0)) + + if __name__ == "__main__": - test_max_pool2d_grad() - test_avg_pool2d_grad() - test_global_avg_pool2d_grad() - test_conv2d_grad() - test_dense_grad() - test_matmul_grad() - test_batch_flatten_grad() + pytest.main([__file__])
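
Usage sketch (illustrative, not part of the patch): the new op can be exercised end-to-end the same way `verify_conv2d_backward_weight` above does — construct `nn.conv2d_backward_weight`, run Legalize (the op is registered as TNonComputational, and legalization rewrites it into the tile/reshape/conv2d pattern from `_nn.py`), then compare against the NumPy reference. Shapes are taken from the first test case; an `llvm` target is assumed to be available, and the variable names below are only for illustration.

    import numpy as np
    import tvm
    import tvm.topi.testing
    from tvm import relay
    from tvm.relay.testing import run_opt_pass

    dtype = "float32"
    dy_shape, x_shape = (2, 8, 32, 32), (2, 4, 32, 32)  # grad and data, NCHW
    kernel_size, stride, padding = (3, 3), (1, 1), (1, 1)

    dy = relay.var("dy", shape=dy_shape, dtype=dtype)
    x = relay.var("x", shape=x_shape, dtype=dtype)
    dw = relay.nn.conv2d_backward_weight(
        dy, x, strides=stride, padding=padding, kernel_size=kernel_size
    )
    # Legalize rewrites nn.conv2d_backward_weight into ordinary Relay ops.
    func = run_opt_pass(relay.Function([dy, x], dw), relay.transform.Legalize())

    dy_np = np.random.randn(*dy_shape).astype(dtype)
    x_np = np.random.randn(*x_shape).astype(dtype)
    dev = tvm.device("llvm", 0)
    dw_np = (
        relay.create_executor(device=dev, target="llvm")
        .evaluate(func)(dy_np, x_np)
        .numpy()
    )

    # Compare with the NumPy reference added in this patch.
    ref_np = tvm.topi.testing.conv2d_backward_weight_nchw_python(
        dy_np, x_np, kernel_size, stride, padding
    )
    np.testing.assert_allclose(dw_np, ref_np, rtol=1e-4, atol=1e-4)

As a concrete check of the shape arithmetic used in the legalization, take the second test case (data (2, 3, 32, 32), grad (2, 16, 15, 15), stride (2, 2), padding (0, 0)): padded_weight_grad_h = (32 - (15 - 1) * 2 - 1 + 0 + 0) // 1 + 1 = 4, which exceeds filter_h = 3, so the final strided_slice trims the intermediate gradient back to the 3x3 kernel shape.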