Skip to content

Commit

Permalink
[FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay. (a…
Browse files Browse the repository at this point in the history
…pache#2850)

* [FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay.

* 	* test cases

* 	* ci error
  • Loading branch information
srkreddy1238 authored and wweic committed Apr 7, 2019
1 parent 714cc68 commit d28e2fe
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 27 deletions.
9 changes: 8 additions & 1 deletion python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ def __call__(self, inputs, attrs, *args):
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)

# ignore 'tvm_custom' always
self._ignores.append('tvm_custom')

# convert attributes
new_attrs = {}
for k in attrs.keys():
Expand All @@ -329,7 +333,8 @@ def __call__(self, inputs, attrs, *args):
elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
elif k in self._ignores:
logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)
if k != 'tvm_custom':
logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name)
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
Expand Down Expand Up @@ -416,4 +421,6 @@ def __init__(self, new_name):
self._new_name = new_name

def __call__(self, inputs, attrs, *args):
if 'tvm_custom' in attrs:
attrs.pop('tvm_custom')
return get_relay_op(self._new_name)(*inputs, **attrs)
44 changes: 33 additions & 11 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _impl_v1(cls, inputs, attr, params):
'pads': ('padding', (0, 0), revert_caffe2_pad)
},
# very weird attributes here in onnx, force check
ignores=['dilations'],
ignores=['dilations', 'auto_pad'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=dimension_constraint())(inputs, attr, params)
Expand Down Expand Up @@ -160,6 +160,7 @@ def _impl_v1(cls, inputs, attr, params):
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), revert_caffe2_pad),
'group': ('groups', 1)},
ignores=['auto_pad'],
custom_check=dimension_constraint())(inputs[:2], attr, params)
use_bias = len(inputs) == 3
if use_bias:
Expand Down Expand Up @@ -332,7 +333,21 @@ def _impl_v1(cls, inputs, attr, params):
shape = tuple(params[inputs[1].name_hint].asnumpy())
out = _op.reshape(inputs[0], shape)
else:
out = _op.reshape_like(inputs[0], inputs[1])
# Try to infer shape by precompute prune if possible.
# TODO: good to check inputs to be in params.
# to be enhanced when relay support list_input_names API of NNVM
logging.warning("Infering Reshape argument by precompute")
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))

return out

Expand Down Expand Up @@ -477,10 +492,7 @@ class Shape(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
# Result of this operator is prominently used by reshape operator.
# Just pass the input as it is so that reshape_like can be used there.
logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)")
return inputs[0]
return _op.shape_of(inputs[0])

class Cast(OnnxOpConverter):
""" Operator converter for Cast.
Expand All @@ -494,7 +506,7 @@ def _impl_v1(cls, inputs, attr, params):
def _impl_v5(cls, inputs, attr, params):
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
except ImportError as e:
raise ImportError(
"Unable to import onnx.mapping which is required {}".format(e))
Expand Down Expand Up @@ -674,6 +686,11 @@ class ReduceMean(Reduce):
"""
name = 'mean'

class ReduceProd(Reduce):
""" Operator converter for ArgMax.
"""
name = 'prod'

class ArgMax(OnnxOpConverter):
""" Operator converter for ArgMax.
"""
Expand Down Expand Up @@ -826,6 +843,7 @@ def _get_convert_map(opset):
'ReduceMin': ReduceMin.get_converter(opset),
'ReduceSum': ReduceSum.get_converter(opset),
'ReduceMean': ReduceMean.get_converter(opset),
'ReduceProd': ReduceProd.get_converter(opset),
# 'ReduceProd'
# 'ReduceLogSumExp'
'ArgMax': ArgMax.get_converter(opset),
Expand All @@ -842,8 +860,7 @@ def _get_convert_map(opset):
'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset),
# TODO(zhreshold) Shape op is implemented as bypass op in relay
# 'Shape': Shape.get_converter(opset),
'Shape': Shape.get_converter(opset),
}


Expand Down Expand Up @@ -883,6 +900,7 @@ def from_onnx(self, graph, opset):
----------
graph : onnx protobuf object
The loaded onnx graph
opset : opset version
Returns
Expand Down Expand Up @@ -911,12 +929,12 @@ def from_onnx(self, graph, opset):
dtype=self._params[i_name].dtype)
else:
self._num_input += 1
shape = self._shape[i_name] if i_name in self._shape else ()
tshape = self._shape[i_name] if i_name in self._shape else ()
if isinstance(self._dtype, dict):
dtype = self._dtype[i_name] if i_name in self._dtype else d_type
else:
dtype = d_type
self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype)
self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
# construct nodes, nodes are stored as directed acyclic graph
for node in graph.node:
op_name = node.op_type
Expand All @@ -936,6 +954,10 @@ def from_onnx(self, graph, opset):
self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
inputs.append(self._nodes[i_name])

i_name = self._parse_value_proto(node)
attr['tvm_custom'] = {}
attr['tvm_custom']['name'] = i_name

op = self._convert_operator(op_name, inputs, attr, opset)
node_output = self._fix_outputs(op_name, node.output)
if not isinstance(op, _expr.TupleWrapper):
Expand Down
31 changes: 16 additions & 15 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,35 +113,36 @@ def test_reshape():

tvm.testing.assert_allclose(ref_shape, tvm_out.shape)

def test_reshape_like():
def test_shape():
in_shape = (4, 3, 3, 4)
ref_shape = (3, 4, 4, 3)
ref_shape = (6, 2, 4, 3)

ref_array = np.random.uniform(size=ref_shape).astype('float32')
ref_array = np.array(ref_shape)
ref_node = onnx.helper.make_node('Constant',
inputs=[],
outputs=['ref_in'],
value=onnx.helper.make_tensor(name = 'const_tensor',
data_type = onnx.TensorProto.FLOAT,
data_type = onnx.TensorProto.INT32,
dims = ref_array.shape,
vals = ref_array.flatten().astype(float)))
copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"])
reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"])
vals = ref_array.flatten().astype(int)))
reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])

shape_node = helper.make_node("Shape", ['out'], ['final_out'])

graph = helper.make_graph([ref_node, copy_node, reshape_node],
"reshape_like_test",
graph = helper.make_graph([ref_node, reshape_node, shape_node],
"shape_test",
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
outputs = [helper.make_tensor_value_info("final_out",
TensorProto.FLOAT, list(ref_shape))])

model = helper.make_model(graph, producer_name='reshape_like_test')
model = helper.make_model(graph, producer_name='shape_test')

for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
x = np.random.uniform(size=in_shape).astype('int32')
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'int32')

tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
tvm.testing.assert_allclose(ref_shape, tvm_out)

def _test_power_iteration(x_shape, y_shape):
if isinstance(y_shape, int):
Expand Down Expand Up @@ -995,7 +996,7 @@ def test_LogSoftmax():

if __name__ == '__main__':
test_reshape()
test_reshape_like()
test_shape()
test_power()
test_squeeze()
test_unsqueeze()
Expand Down

0 comments on commit d28e2fe

Please sign in to comment.