
Commit

swap arg order (data, grad) to be consistent with conv2d_transpose(dgrad)
masahi committed Jan 19, 2022
1 parent 1aad114 commit f91e9d9
Showing 5 changed files with 27 additions and 29 deletions.
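
In short, every call site now passes the output gradient first and the input data second, matching the argument order used on the data-gradient (conv2d_transpose) path. A minimal before/after sketch of the Relay-level call, with hypothetical shapes:

import tvm
from tvm import relay

# Hypothetical shapes: a forward NCHW conv2d over data (1, 3, 32, 32) with a
# (16, 3, 3, 3) kernel, stride 1, padding 1 produces output (1, 16, 32, 32),
# so the output gradient dy has that shape.
dy = relay.var("dy", shape=(1, 16, 32, 32), dtype="float32")  # output gradient
x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")     # forward input

# Old argument order (before this commit): data first, then grad.
# dw = relay.nn.conv2d_backward_weight(x, dy, padding=(1, 1), kernel_size=(3, 3))

# New argument order (this commit): grad first, then data.
dw = relay.nn.conv2d_backward_weight(dy, x, padding=(1, 1), kernel_size=(3, 3))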
4 changes: 2 additions & 2 deletions python/tvm/relay/op/_tensor_grad.py
@@ -425,16 +425,16 @@ def conv2d_grad(orig, grad):
     )

     backward_weight = _nn.conv2d_backward_weight(
-        data,
         grad,
+        data,
         strides=attrs.strides,
         padding=attrs.padding,
         dilation=attrs.dilation,
         groups=attrs.groups,
         channels=attrs.channels,
         kernel_size=(filter_h, filter_w),
-        data_layout=attrs.data_layout,
         grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout,
+        data_layout=attrs.data_layout,
         kernel_layout=attrs.kernel_layout,
         out_dtype=attrs.out_dtype,
     )
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -1080,7 +1080,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-    data, grad = inputs
+    grad, data = inputs
     data_shape = get_const_tuple(data.checked_type.shape)
     weight_shape = get_const_tuple(types[2].shape)
     _, out_channel, grad_h, grad_w = get_const_tuple(grad.checked_type.shape)
16 changes: 7 additions & 9 deletions python/tvm/relay/op/nn/nn.py
@@ -3773,28 +3773,26 @@ def batch_to_space_nd(data, block_shape, crops):


 def conv2d_backward_weight(
-    data,
     grad,
+    data,
     strides=(1, 1),
     padding=(0, 0),
     dilation=(1, 1),
     groups=1,
     channels=None,
     kernel_size=None,
-    data_layout="NCHW",
     grad_layout="NCHW",
+    data_layout="NCHW",
     kernel_layout="OIHW",
     out_dtype="",
 ):
     r"""The gradient of conv2d with respect to weight.
-    This operator takes the output gradient `grad` as the convolution kernel
-    and convolves it with `data` to produce the gradient with respect to weight.
-    Depending on an implementation, the roles of `data` and `grad` can be swapped
-    (For example, in CUTLASS `data` acts as the filter).
+    This operator takes the output gradient `grad` and convolves it with `data` as
+    the convolution kernel, to produce the gradient with respect to weight.
     Note that the parameter `kernel_size` is the spatial size of the corresponding
-    forward convolution kernel, not that of `grad`. `grad_layout` and
+    forward convolution kernel, not that of `data`. `grad_layout` and
     `kernel_layout` are the layouts of `grad` and the weight gradient respectively.
     Other parameters are the same as the conv2d op. See its documentation for more
@@ -3810,16 +3808,16 @@ def conv2d_backward_weight(
     padding = get_pad_tuple2d(padding)

     return _make.conv2d_backward_weight(
-        data,
         grad,
+        data,
         strides,
         padding,
         dilation,
         groups,
         channels,
         kernel_size,
-        data_layout,
         grad_layout,
+        data_layout,
         kernel_layout,
         out_dtype,
     )
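
To make the docstring concrete, here is a small shape-checking sketch (shapes hypothetical, default NCHW/OIHW layouts assumed): `kernel_size=(3, 3)` refers to the forward convolution kernel, not to the spatial size of `grad`, and the inferred weight-gradient type follows (channels, in_channels, kernel_size[0], kernel_size[1]).

import tvm
from tvm import relay

dy = relay.var("dy", shape=(1, 16, 32, 32), dtype="float32")  # grad, NCHW
x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")     # data, NCHW
dw = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3)
)

mod = tvm.IRModule.from_expr(relay.Function([dy, x], dw))
mod = relay.transform.InferType()(mod)
# Expected: Tensor[(16, 3, 3, 3), float32], i.e. (channels, in_channels, kh, kw) in OIHW.
print(mod["main"].ret_type)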
32 changes: 16 additions & 16 deletions src/relay/op/nn/convolution.cc
@@ -579,10 +579,10 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
           kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d");
     });

-inline Expr MakeConv2dBackwardWeight(Expr data, Expr grad, Array<IndexExpr> strides,
+inline Expr MakeConv2dBackwardWeight(Expr grad, Expr data, Array<IndexExpr> strides,
                                      Array<IndexExpr> padding, Array<IndexExpr> dilation,
                                      int groups, IndexExpr channels, Array<IndexExpr> kernel_size,
-                                     std::string data_layout, std::string grad_layout,
+                                     std::string grad_layout, std::string data_layout,
                                      std::string kernel_layout, DataType out_dtype) {
   auto attrs = make_object<Conv2DAttrs>();
   attrs->strides = std::move(strides);
@@ -591,29 +591,29 @@ inline Expr MakeConv2dBackwardWeight(Expr data, Expr grad, Array<IndexExpr> stri
   attrs->groups = groups;
   attrs->channels = std::move(channels);
   attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
   attrs->out_dtype = std::move(out_dtype);
-  attrs->kernel_layout = std::move(grad_layout);
+  attrs->data_layout = std::move(grad_layout);
+  attrs->kernel_layout = std::move(data_layout);
   attrs->out_layout = std::move(kernel_layout);
   const Op& op = Op::Get("nn.conv2d_backward_weight");
-  return Call(op, {data, grad}, Attrs(attrs), {});
+  return Call(op, {grad, data}, Attrs(attrs), {});
 }

 TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_backward_weight")
-    .set_body_typed([](Expr data, Expr grad, Array<IndexExpr> strides, Array<IndexExpr> padding,
+    .set_body_typed([](Expr grad, Expr data, Array<IndexExpr> strides, Array<IndexExpr> padding,
                        Array<IndexExpr> dilation, int groups, IndexExpr channels,
-                       Array<IndexExpr> kernel_size, String data_layout, String grad_layout,
+                       Array<IndexExpr> kernel_size, String grad_layout, String data_layout,
                        String kernel_layout, DataType out_dtype) {
-      return MakeConv2dBackwardWeight(data, grad, strides, padding, dilation, groups, channels,
-                                      kernel_size, data_layout, grad_layout, kernel_layout,
+      return MakeConv2dBackwardWeight(grad, data, strides, padding, dilation, groups, channels,
+                                      kernel_size, grad_layout, data_layout, kernel_layout,
                                       out_dtype);
     });

 bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                              const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* grad = types[1].as<TensorTypeNode>();
+  const auto* grad = types[0].as<TensorTypeNode>();
+  const auto* data = types[1].as<TensorTypeNode>();
   if (data == nullptr) return false;

   static const Layout kNCHW("NCHW");
@@ -625,12 +625,12 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
   ICHECK(param->kernel_size.defined()) << "kernel_size attribute needs to be specified";

   // We repurpose Conv2dAttrs for Conv2DBackwardWeight, note the meanings of layouts.
-  const Layout in_layout(param->data_layout);
-  const Layout grad_layout(param->kernel_layout);
+  const Layout grad_layout(param->data_layout);
+  const Layout in_layout(param->kernel_layout);
   const Layout kernel_layout(param->out_layout);

-  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
   const auto trans_grad_layout = tir::BijectiveLayout(grad_layout, kNCHW);
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
   const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);

   Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
@@ -653,16 +653,16 @@ RELAY_REGISTER_OP("nn.conv2d_backward_weight")
 This layer computes the gradient of the conv2d op with respect to weight,
 given the original input data and the output gradient.
+- **grad**: (batch, channels, out_height, out_width) if `layout` is `NCHW`.
 - **data**: This depends on the `layout` parameter. Input is 4D array of shape
             (batch_size, in_channels, height, width) if `layout` is `NCHW`.
-- **grad**: (batch, channels, out_height, out_width) if `layout` is `NCHW`.
 - **out**: This depends on the `layout` parameter. Output is 4D array of shape
            (channels, in_channels, kernel_size[0], kernel_size[1]) if `layout` is `NCHW`.
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<Conv2DAttrs>()
     .set_num_inputs(2)
-    .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("grad", "Tensor", "The gradient tensor.")
+    .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(2)
     .add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel)
     .set_attr<TNonComputational>("TNonComputational", true)
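
For intuition about what the registered op computes, here is a hedged NumPy reference for the simplest configuration (NCHW, stride 1, dilation 1, groups 1); the helper name and shapes are illustrative only and not part of TVM:

import numpy as np

def conv2d_backward_weight_ref(grad, data, kernel_size, padding=(0, 0)):
    # Reference weight gradient for NCHW, stride 1, dilation 1, groups 1.
    # grad: (batch, channels, out_h, out_w), data: (batch, in_channels, h, w).
    n, oc, gh, gw = grad.shape
    _, ic, _, _ = data.shape
    kh, kw = kernel_size
    ph, pw = padding
    data_p = np.pad(data, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
    dw = np.zeros((oc, ic, kh, kw), dtype=data.dtype)
    for r in range(kh):
        for c in range(kw):
            # Slice of the padded input seen by kernel tap (r, c) across every
            # output position, shaped (batch, in_channels, out_h, out_w).
            window = data_p[:, :, r:r + gh, c:c + gw]
            dw[:, :, r, c] = np.einsum("nohw,nihw->oi", grad, window)
    return dw  # (channels, in_channels, kernel_size[0], kernel_size[1]), i.e. OIHW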
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_grad_level2.py
@@ -234,7 +234,7 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin
     dy = relay.var("dy", shape=dy_shape, dtype=dtype)
     x = relay.var("x", shape=x_shape, dtype=dtype)
     dw = relay.nn.conv2d_backward_weight(
-        x, dy, strides=stride, padding=padding, kernel_size=kernel_size
+        dy, x, strides=stride, padding=padding, kernel_size=kernel_size
     )
     dw_func = relay.Function([dy, x], dw)
     dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())
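
The helper is then exercised with concrete shapes; the actual test parameters sit outside the shown hunk, so the values below are purely illustrative (a 3x3 kernel, stride 1, padding 1):

# dy_shape,          x_shape,         kernel_size, stride,  padding
verify_conv2d_backward_weight((1, 16, 32, 32), (1, 3, 32, 32), (3, 3), (1, 1), (1, 1))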
