Skip to content

Commit

Permalink
[Relay] Add conv2d_backward_weight op (without topi) (apache#9954)
Browse files Browse the repository at this point in the history
* python plumbing

* add cpp def

* legalize worked

* clean up

* layout conversion doesnt work

* extract wgrad body

* fix convert layout

* black

* fix kernel size

* revert irrelevant change

* add doc, clarify the meanings of parameters

* update layout convert

* test passed

* fixed layout conversion

* update convert layout

* remove print

* remove layout convert for now

* minor fix

* removed unused import

* add wgrad python reference

* add test stub

* add doc

* test other stride and pad

* tweak

* more pylint filter

* fix typo in doc

* swap arg order (data, grad) to be consistent with conv2d_transpose(dgrad)
  • Loading branch information
masahi authored and yuanfz98 committed Jan 24, 2022
1 parent 71d941c commit 2553a1a
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 49 deletions.
53 changes: 13 additions & 40 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
reshape_like,
strided_slice,
take,
tile,
transpose,
where,
repeat,
Expand Down Expand Up @@ -399,15 +398,14 @@ def conv2d_grad(orig, grad):
data_shape = get_const_tuple(data.checked_type.shape)
weight_shape = get_const_tuple(weight.checked_type.shape)
_, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
batch, in_channel, in_h, in_w = data_shape
out_channel, _, filter_h, filter_w = weight_shape
_, _, in_h, in_w = data_shape
_, _, filter_h, filter_w = weight_shape

# infer output_padding
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
get_const_tuple(attrs.padding), (filter_h, filter_w)
)
stride_h, stride_w = get_const_tuple(attrs.strides)
dilation_h, dilation_w = get_const_tuple(attrs.dilation)
out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
output_padding = (in_h - out_h, in_w - out_w)
Expand All @@ -425,46 +423,21 @@ def conv2d_grad(orig, grad):
groups=attrs.groups,
output_padding=output_padding,
)
grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
grad = reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow
data = reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw

backward_weight = _nn.conv2d(
data,
backward_weight = _nn.conv2d_backward_weight(
grad,
strides=attrs.dilation,
data,
strides=attrs.strides,
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch,
)
# infer shape of backward_weight
padded_weight_grad_h = (
in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
) // dilation_h + 1
padded_weight_grad_w = (
in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
) // dilation_w + 1
backward_weight = reshape(
backward_weight,
[
batch,
in_channel // attrs.groups,
out_channel,
padded_weight_grad_h,
padded_weight_grad_w,
],
dilation=attrs.dilation,
groups=attrs.groups,
channels=attrs.channels,
kernel_size=(filter_h, filter_w),
grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout,
data_layout=attrs.data_layout,
kernel_layout=attrs.kernel_layout,
out_dtype=attrs.out_dtype,
)
backward_weight = _sum(backward_weight, axis=0)
backward_weight = transpose(backward_weight, [1, 0, 2, 3])

assert padded_weight_grad_h >= filter_h
assert padded_weight_grad_w >= filter_w
if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = strided_slice(
backward_weight,
begin=[0, 0, 0, 0],
end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
)

return [backward_data, backward_weight]

Expand Down
78 changes: 78 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.runtime import convert
from tvm.te.hybrid import script
from tvm.topi.utils import get_const_tuple
from tvm.topi.nn.utils import get_pad_tuple

from ....ir import container
from ....tir import expr
Expand Down Expand Up @@ -1061,6 +1062,83 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
reg.register_injective_schedule("nn.batch_to_space_nd")


@reg.register_legalize("nn.conv2d_backward_weight")
def legalize_conv2d_backward_weight(attrs, inputs, types):
"""Legalize conv2d_backward_weight op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
grad, data = inputs
data_shape = get_const_tuple(data.checked_type.shape)
weight_shape = get_const_tuple(types[2].shape)
_, out_channel, grad_h, grad_w = get_const_tuple(grad.checked_type.shape)
batch, in_channel, in_h, in_w = data_shape
_, _, filter_h, filter_w = weight_shape
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
get_const_tuple(attrs.padding), (filter_h, filter_w)
)
stride_h, stride_w = get_const_tuple(attrs.strides)
dilation_h, dilation_w = get_const_tuple(attrs.dilation)

grad = relay.tile(grad, [1, in_channel // attrs.groups, 1, 1])
grad = relay.reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow
data = relay.reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw

backward_weight = relay.nn.conv2d(
data,
grad,
strides=attrs.dilation,
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch,
)

# infer shape of backward_weight
padded_weight_grad_h = (
in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
) // dilation_h + 1
padded_weight_grad_w = (
in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
) // dilation_w + 1

backward_weight = relay.reshape(
backward_weight,
[
batch,
in_channel // attrs.groups,
out_channel,
padded_weight_grad_h,
padded_weight_grad_w,
],
)
backward_weight = relay.sum(backward_weight, axis=0)
backward_weight = relay.transpose(backward_weight, [1, 0, 2, 3])

assert padded_weight_grad_h >= filter_h
assert padded_weight_grad_w >= filter_w

if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = relay.strided_slice(
backward_weight,
begin=[0, 0, 0, 0],
end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
)

return backward_weight


#####################
# Shape functions #
#####################
Expand Down
51 changes: 51 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3770,3 +3770,54 @@ def batch_to_space_nd(data, block_shape, crops):
"""

return _make.batch_to_space_nd(data, block_shape, crops)


def conv2d_backward_weight(
grad,
data,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
grad_layout="NCHW",
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="",
):
r"""The gradient of conv2d with respect to weight.
This operator takes the output gradient `grad` and convolves it with `data` as
the convolution kernel, to produce the gradient with respect to weight.
Note that the parameter `kernel_size` is the spatial size of the corresponding
forward convolution kernel, not that of `data`. `grad_layout` and
`kernel_layout` are the layouts of `grad` and the weight gradient respectively.
Other parameters are the same as the conv2d op. See its documentation for more
details.
"""
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(strides, int):
strides = (strides, strides)
if isinstance(dilation, int):
dilation = (dilation, dilation)
padding = get_pad_tuple2d(padding)

return _make.conv2d_backward_weight(
grad,
data,
strides,
padding,
dilation,
groups,
channels,
kernel_size,
grad_layout,
data_layout,
kernel_layout,
out_dtype,
)
1 change: 1 addition & 0 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def check_grad(

fwd_func = run_infer_type(func)
bwd_func = run_infer_type(gradient(fwd_func, mode=mode))
bwd_func = run_opt_pass(bwd_func, relay.transform.Legalize())

if scale is None:
scale = 10 * eps
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@
from .nll_loss import nll_loss
from .dense import dense
from .searchsorted import searchsorted_ref
from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python
76 changes: 76 additions & 0 deletions python/tvm/topi/testing/conv2d_backcward_weight_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-nested-blocks
"""Gradient of conv2d with respect to weight in python"""
import numpy as np


# Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
"""Gradient of the conv2d op with respect to weight, in NCHW layout.
Parameters
----------
dy_np : numpy.ndarray
4-D with shape [batch, in_channel, out_height, out_width]
x_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
kernel_size : tuple of two ints
Height and width of the weight
stride : tuple of two ints
Stride size, or [stride_height, stride_width]
padding : tuple of two ints
Spatial padding, or [pad_h, pad_w]
Returns
-------
b_np : np.ndarray
4-D with shape [num_filter, in_channel, filter_height, filter_width]
"""
N, C, H, W = x_np.shape
_, K, P, Q = dy_np.shape
R, S = kernel_size
pad_h, pad_w = padding
stride_h, stride_w = stride
dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)

for k in range(K):
for r in range(R):
for s in range(S):
for c in range(C):
acc = 0
for n in range(N):
for p in range(P):
for q in range(Q):
coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)

if (
coord[2] < H
and coord[2] >= 0
and coord[3] < W
and coord[3] >= 0
):
acc += dy_np[n, k, p, q] * x_np[coord]

dw[k, c, r, s] = acc

return dw
Loading

0 comments on commit 2553a1a

Please sign in to comment.