diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 51e712869cf41..20ff5cf39d6db 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -139,13 +139,17 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout): # pylint: disable=import-outside-toplevel from tvm import relay data, weight = inputs - assert desired_layout == 'NCHW', \ - "Currently only transformation to NCHW layout is supported." + new_attrs = dict(attrs) if desired_layout == 'NCHW': - new_attrs = dict(attrs) new_attrs['data_layout'] = desired_layout new_attrs['kernel_layout'] = 'OIHW' return relay.nn.conv2d(data, weight, **new_attrs) + elif desired_layout == 'NHWC': + new_attrs['data_layout'] = desired_layout + new_attrs['kernel_layout'] = 'HWIO' + return relay.nn.conv2d(data, weight, **new_attrs) + else: + assert "Layout %s is not yet supported." % (desired_layout) return None @@ -183,6 +187,42 @@ def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type): """Alternate the layout of conv3d""" return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type) +@reg.register_convert_op_layout("nn.conv3d") +def convert_conv3d(attrs, inputs, tinfos, desired_layout): + """Convert Layout pass registration for conv3d op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layout : str + The desired layout + + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + # pylint: disable=import-outside-toplevel + from tvm import relay + data, weight = inputs + new_attrs = dict(attrs) + if desired_layout == 'NCDHW': + new_attrs['data_layout'] = desired_layout + new_attrs['kernel_layout'] = 'OIDHW' + return relay.nn.conv3d(data, weight, **new_attrs) + elif desired_layout == "NDHWC": + new_attrs['data_layout'] = desired_layout + new_attrs['kernel_layout'] = 'DHWIO' + return relay.nn.conv3d(data, weight, **new_attrs) + else: + assert "Layout %s is not yet supported" % desired_layout + return None + # conv3d_winograd related operators reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform", strategy.conv3d_winograd_without_weight_transfrom_strategy)