diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index e8da60a1afb5..6b3f144619dd 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -693,7 +693,10 @@ def _impl(inputs, attr, params, mod): raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) if "kernel_layout" not in attr: - attr["kernel_layout"] = "DHWIO" if attr["data_format"] == "NDHWC" else "OIDHW" + if opname == "conv": + attr["kernel_layout"] = "DHWIO" if attr["data_format"] == "NDHWC" else "OIDHW" + elif opname == "conv_transpose": + attr["kernel_layout"] = "DHWOI" if attr["data_format"] == "NDHWC" else "IODHW" use_bias = len(inputs) == (3 if opname != "conv_transpose" else 4) channel_axis = 1 if attr["data_format"] == "NCDHW" else 4