Skip to content

Commit

Permalink
Added layout conversion to tensorcore friendly format for conv2d and …
Browse files Browse the repository at this point in the history
…conv3d.
  • Loading branch information
jwfromm committed Apr 9, 2020
1 parent 794dd68 commit 1216772
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,17 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
# pylint: disable=import-outside-toplevel
from tvm import relay
data, weight = inputs
assert desired_layout == 'NCHW', \
"Currently only transformation to NCHW layout is supported."
new_attrs = dict(attrs)
if desired_layout == 'NCHW':
new_attrs = dict(attrs)
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'OIHW'
return relay.nn.conv2d(data, weight, **new_attrs)
elif desired_layout == 'NHWC':
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'HWIO'
return relay.nn.conv2d(data, weight, **new_attrs)
else:
assert "Layout %s is not yet supported." % (desired_layout)
return None


Expand Down Expand Up @@ -183,6 +187,42 @@ def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
"""Alternate the layout of conv3d"""
return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)

@reg.register_convert_op_layout("nn.conv3d")
def convert_conv3d(attrs, inputs, tinfos, desired_layout):
"""Convert Layout pass registration for conv3d op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layout : str
The desired layout
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay
data, weight = inputs
new_attrs = dict(attrs)
if desired_layout == 'NCDHW':
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'OIDHW'
return relay.nn.conv3d(data, weight, **new_attrs)
elif desired_layout == "NDHWC":
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'DHWIO'
return relay.nn.conv3d(data, weight, **new_attrs)
else:
assert "Layout %s is not yet supported" % desired_layout
return None

# conv3d_winograd related operators
reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
strategy.conv3d_winograd_without_weight_transfrom_strategy)
Expand Down

0 comments on commit 1216772

Please sign in to comment.