[Bugfix] Conv1DTranspose default kernel layout should be IOW (#14482)
* fix Conv1DTranspose kernel layout

* fix Conv1DTranspose type checker

* fix MXNet layout
rebel-jangys authored Apr 4, 2023
1 parent f8f7bc8 commit f5db8b7
Showing 6 changed files with 26 additions and 17 deletions.
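
For context: the IOW default matches how frameworks store transposed-convolution weights, with the input-channel axis leading. A minimal sketch (not part of the commit) comparing the two layouts in PyTorch:

# Sketch: default weight layouts of regular vs. transposed 1D convolution.
import torch

conv = torch.nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3)
deconv = torch.nn.ConvTranspose1d(in_channels=4, out_channels=8, kernel_size=3)

print(tuple(conv.weight.shape))    # (8, 4, 3) -> OIW
print(tuple(deconv.weight.shape))  # (4, 8, 3) -> IOW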
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/nn.h
@@ -671,10 +671,10 @@ struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
           "dimensions respectively. Convolution is applied on the"
           "'W' dimension.");
   TVM_ATTR_FIELD(kernel_layout)
-      .set_default("OIW")
+      .set_default("IOW")
       .describe(
-          "Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc."
-          "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+          "Dimension ordering of data and weight. Can be 'IOW', 'IOW16o16i', etc."
+          "'I', 'O', 'W' stands for input_channel, num_filter and width"
           "dimensions respectively.");
   TVM_ATTR_FIELD(out_layout)
       .set_default("")
12 changes: 8 additions & 4 deletions python/tvm/relay/frontend/keras.py
@@ -282,6 +282,8 @@ def _convert_dense(


 def _convert_convolution1d(inexpr, keras_layer, etab, data_layout, input_shape=None):
+    is_deconv = type(keras_layer).__name__ == "Conv1DTranspose"
+
     if input_shape is None:
         input_shape = keras_layer.input_shape
     _check_data_format(keras_layer)
@@ -290,19 +292,21 @@ def _convert_convolution1d(inexpr, keras_layer, etab, data_layout, input_shape=N

     if data_layout == "NWC":
         kernel_layout = "WIO"
+        if is_deconv:
+            kernel_layout = "WOI"
     else:
         kernel_layout = "OIW"
+        if is_deconv:
+            kernel_layout = "IOW"
         msg = (
             "Kernel layout with {} is not supported for operator Convolution1D "
             "in frontend Keras."
         )
         raise tvm.error.OpAttributeUnImplemented(msg.format(data_layout))
 
-    is_deconv = type(keras_layer).__name__ == "Conv1DTranspose"
-
     if is_deconv:
-        if kernel_layout == "OIW":
-            weight = weight.transpose([2, 0, 1])
+        if kernel_layout == "IOW":
+            weight = weight.transpose([2, 1, 0])
         kernel_w, n_filters, _ = weight.shape
     else:
         kernel_w, _, n_filters = weight.shape
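
A Keras Conv1DTranspose kernel arrives as WOI (width, out_channels, in_channels), which is why the non-NWC path above transposes axes [2, 1, 0] to reach IOW. A small NumPy sketch of that shuffle (illustrative shapes only):

# Sketch: reorder a WOI kernel to IOW, mirroring weight.transpose([2, 1, 0]).
import numpy as np

kernel_w, n_filters, in_channels = 3, 8, 4
weight_woi = np.zeros((kernel_w, n_filters, in_channels))  # WOI
weight_iow = weight_woi.transpose([2, 1, 0])               # IOW
assert weight_iow.shape == (in_channels, n_filters, kernel_w)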
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/mxnet.py
@@ -304,7 +304,7 @@ def _mx_conv1d_transpose(inputs, attrs):
     if data_layout != "NCW":
         raise tvm.error.OpAttributeInvalid('Only "NCW" data layout is supported for 1D Convolution')
     channel_axis = 1
-    kernel_layout = "OIW"
+    kernel_layout = "IOW"
     new_attrs = {}
     new_attrs["channels"] = attrs.get_int("num_filter")
     new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")
3 changes: 3 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1263,6 +1263,9 @@ def convolution(self, inputs, input_types):
         else:
             data_layout = "NCW"
             kernel_layout = "OIW"
+            if use_transpose:
+                # Transposed convolutions have IOW layout.
+                kernel_layout = "IOW"
 
         # Conv1d does not currently support grouped convolution so we convert it to conv2d
         is_grouped_conv1d = False
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/nn.py
@@ -604,7 +604,7 @@ def conv1d_transpose(
     channels=None,
     kernel_size=None,
     data_layout="NCW",
-    kernel_layout="OIW",
+    kernel_layout="IOW",
     out_layout="",
     output_padding=(0,),
     out_dtype="",
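
With the new default, a weight already laid out as (in_channels, out_channels, width) type-checks without passing kernel_layout explicitly. A hedged usage sketch (shapes chosen for illustration):

# Sketch: conv1d_transpose now defaults to kernel_layout="IOW".
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 4, 16))     # NCW: batch, in_channels, width
weight = relay.var("weight", shape=(4, 8, 3))  # IOW: in_channels, out_channels, width
out = relay.nn.conv1d_transpose(data, weight, channels=8, kernel_size=(3,))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
print(relay.transform.InferType()(mod))        # inferred output shape: (1, 8, 18)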
18 changes: 10 additions & 8 deletions src/relay/op/nn/convolution.cc
@@ -926,7 +926,7 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
   if (data == nullptr) return false;
 
   static const Layout kNCW("NCW");
-  static const Layout kOIW("OIW");
+  static const Layout kIOW("IOW");
 
   const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
   ICHECK(param != nullptr);
@@ -938,9 +938,9 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
<< "Conv only support input layouts that are convertible from NCW."
<< " But got " << in_layout;

const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOW);
ICHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIW."
<< "Conv only support kernel layouts that are convertible from IOW."
<< " But got " << kernel_layout;

Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -979,16 +979,18 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
     ICHECK_EQ(param->kernel_size.size(), 1);
     // check the size
     ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
-        << "Conv1D: shape of weight is inconsistent with kernel_size, "
+        << "Conv1DTraspose: shape of weight is inconsistent with kernel_size, "
         << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
   }
   if (param->channels.defined()) {
-    ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
-        << "Conv1D: shape of weight is inconsistent with channels, "
-        << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
+    ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
+        << "Conv1DTraspose: shape of weight is inconsistent with channels, "
+        << " out_channels // groups != weight.shape[1] "
+        << " out_channels=" << param->channels << " groups=" << param->groups
+        << " wshape=" << Array<IndexExpr>(wshape);
   }
   if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
-    ICHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
+    ICHECK(reporter->AssertEQ(dshape_ncw[1], wshape[0]));
   }
   channels = wshape[1];
   dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
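
Under IOW the grouped checks read directly off the weight shape: wshape[0] is the full input-channel count and wshape[1] is out_channels // groups. A worked example with illustrative values:

# Sketch: the updated type relation for a grouped transposed convolution,
# where weight shape is (in_channels, out_channels // groups, kernel_w).
in_channels, out_channels, groups, kernel_w = 8, 16, 2, 3
wshape = (in_channels, out_channels // groups, kernel_w)  # (8, 8, 3)

assert wshape[1] == out_channels // groups  # channels check
assert wshape[0] == in_channels             # data/weight channel check
assert wshape[2] == kernel_w                # kernel_size check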
