Skip to content

Commit

Permalink
[NNVM][ONNX] Cast operator updated.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Jun 25, 2018
1 parent 7792a77 commit 450a4e5
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion nnvm/python/nnvm/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ def _impl(inputs, attr, params):


class Shape(OnnxOpConverter):
""" Operator converter for Shape.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
Expand All @@ -415,6 +417,24 @@ def _impl_v1(cls, inputs, attr, params):
print("Shape: Differently implemented in NNVM as a bypass (dummy operator)")
return inputs[0]

class Cast(OnnxOpConverter):
""" Operator converter for Cast.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)

@classmethod
def _impl_v5(cls, inputs, attr, params):
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
except ImportError as e:
raise ImportError(
"Unable to import onnx.mapping which is required {}".format(e))
return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)


# compatible operators that do NOT require any conversion.
_identity_list = []
Expand Down Expand Up @@ -516,7 +536,7 @@ def _get_convert_map(opset):
# 'ArgMin'

# defs/tensor
'Cast': AttrCvt('cast', {'to': 'dtype'}),
'Cast': Cast.get_converter(opset),
'Reshape': Reshape.get_converter(opset),
'Concat': Renamer('concatenate'),
'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
Expand Down

0 comments on commit 450a4e5

Please sign in to comment.