Skip to content

Commit

Permalink
Additional fix for PR#2972 (apache#3044)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 authored and wweic committed May 13, 2019
1 parent 914c1e5 commit 6030b9a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):

new_attrs = {k: attrs[k] for k in attrs.keys()}

if F == tvm.relay.op:
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]

Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/conv2d_winograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
if "target" in new_attrs:
del new_attrs["target"]

if F == tvm.relay.op:
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]

Expand Down
5 changes: 2 additions & 3 deletions topi/python/topi/intel_graphics/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None

@conv2d_alter_layout.register(["intel_graphics"])
def _alter_conv2d_layout(attrs, inputs, tinfos, F):
import nnvm.symbol as sym

copy_inputs = [s for s in inputs]

Expand All @@ -75,7 +74,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs["kernel_layout"] = 'OIHW%do' % (oc_bn)

if F == tvm.relay.op:
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]

Expand All @@ -84,7 +83,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
if "target" in new_attrs:
del new_attrs["target"]

if F == sym:
if F.__name__ == 'nnvm.symbol':
out = F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
else:
out = F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
Expand Down
12 changes: 6 additions & 6 deletions topi/python/topi/x86/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,26 +323,26 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
import nnvm.symbol as sym

copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()}

if F == tvm.relay.op:
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]

data, kernel = tinfo[0], tinfo[1]
batch_size, in_channel, height, width = get_const_tuple(data.shape)

groups = attrs.get_int("groups")
out_channel = attrs.get_int("channels") if F == sym else new_attrs["channels"]
out_channel = attrs.get_int("channels") \
if F.__name__ == 'nnvm.symbol' else new_attrs["channels"]
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_dtype = attrs["out_dtype"]

layout_name = 'layout' if F == sym else 'data_layout'
layout_name = 'layout' if F.__name__ == 'nnvm.symbol' else 'data_layout'

layout = attrs[layout_name]
kh, kw = attrs.get_int_tuple("kernel_size")
Expand Down Expand Up @@ -404,12 +404,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
dispatch_ctx.update(target, new_workload, cfg)

if is_depthwise:
if F == sym:
if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for depthwise convolution on NNVM.")
return None
return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
else:
if F == sym:
if F.__name__ == 'nnvm.symbol':
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)

Expand Down

0 comments on commit 6030b9a

Please sign in to comment.