From 450a4e5637488f0521d22b38dc5fa439fd15e6a7 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Mon, 25 Jun 2018 19:33:54 +0530 Subject: [PATCH] [NNVM][ONNX] Cast operator updated. --- nnvm/python/nnvm/frontend/onnx.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index 3855d95086c8..a3f5f5b1494e 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -407,6 +407,8 @@ def _impl(inputs, attr, params): class Shape(OnnxOpConverter): + """ Operator converter for Shape. + """ @classmethod def _impl_v1(cls, inputs, attr, params): @@ -415,6 +417,24 @@ def _impl_v1(cls, inputs, attr, params): print("Shape: Differently implemented in NNVM as a bypass (dummy operator)") return inputs[0] +class Cast(OnnxOpConverter): + """ Operator converter for Cast. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr) + + @classmethod + def _impl_v5(cls, inputs, attr, params): + try: + from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']] + except ImportError as e: + raise ImportError( + "Unable to import onnx.mapping which is required {}".format(e)) + return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr) + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -516,7 +536,7 @@ def _get_convert_map(opset): # 'ArgMin' # defs/tensor - 'Cast': AttrCvt('cast', {'to': 'dtype'}), + 'Cast': Cast.get_converter(opset), 'Reshape': Reshape.get_converter(opset), 'Concat': Renamer('concatenate'), 'Split': AttrCvt('split', {'split': 'indices_or_sections'}),