Skip to content

Commit

Permalink
Remove duplicate map as per Jenny's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
markrogersjr committed Apr 20, 2019
1 parent 95529f5 commit 3d5f193
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 159 deletions.
90 changes: 10 additions & 80 deletions nnvm/python/nnvm/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,80 +775,6 @@ def _get_convert_map(opset):
}


supported_ops = set([
'Constant',
'Identity',
'ThresholdedRelu',
'ScaledTanh',
'ParametricSoftplus',
'ConstantFill',
'FC',
'Scale',
'ImageScaler',
'Upsample' ,
'SpatialBN',
'Add',
'Sub',
'Mul',
'Div',
'Neg',
'Abs',
'Reciprocal',
'Floor',
'Ceil',
'Sqrt',
'Relu',
'LeakyRelu',
'Selu',
'Elu',
'Exp',
'Log',
'Tanh',
'Pow',
'PRelu',
'Sigmoid',
'HardSigmoid',
'Max',
'Min',
'Sum',
'Mean',
'Clip',
'Softmax',
'LogSoftmax',
'Softsign',
'SoftPlus',
'Gemm',
'MatMul',
'AveragePool',
'MaxPool',
'Conv',
'ConvTranspose',
'GlobalAveragePool',
'GlobalMaxPool',
'BatchNormalization',
'Dropout',
'Flatten',
'LRN',
'ReduceMax',
'ReduceMin',
'ReduceSum',
'ReduceMean',
'ArgMax',
'ArgMin',
'Cast',
'Reshape',
'Concat',
'Split',
'Slice',
'Transpose',
'Gather',
'Squeeze',
'Unsqueeze',
'Pad',
'Shape',
])


class GraphProto(object):
"""A helper class for handling nnvm graph copying from pb2.GraphProto.
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
Expand Down Expand Up @@ -899,15 +825,19 @@ def from_onnx(self, graph, opset):
self._num_input += 1
self._nodes[i_name] = _sym.Variable(name=i_name)
# get list of unsupported ops
unsupported_ops = []
convert_map = _get_convert_map(opset)
unsupported_ops = set()
for node in graph.node:
op_name = node.op_type
if op_name not in supported_ops:
unsupported_ops.append(op_name)
if op_name not in convert_map and op_name != 'Constant':
unsupported_ops.add(op_name)
if unsupported_ops:
msg = 'The following operators are not supported for frontend ONNX: {}'
unsupported_ops = str(unsupported_ops).strip('[]').replace("'", '')
raise tvm.error.OpNotImplemented(msg.format(unsupported_ops))
msg = ['The following operators are not supported for frontend ONNX: ']
for i, op_name in enumerate(unsupported_ops):
msg.append(op_name)
if i != len(unsupported_ops) - 1:
msg.append(', ')
raise tvm.error.OpNotImplemented(''.join(msg))
# construct nodes, nodes are stored as directed acyclic graph
for node in graph.node:
op_name = node.op_type
Expand Down
89 changes: 10 additions & 79 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,79 +879,6 @@ def _get_convert_map(opset):
'Shape': Shape.get_converter(opset),
}

supported_ops = set([
'Constant',
'Identity',
'ThresholdedRelu',
'ScaledTanh',
'ParametricSoftplus',
'ConstantFill',
'FC',
'Scale',
'Upsample' ,
'SpatialBN',
'Add',
'Sub',
'Mul',
'Div',
'Neg',
'Abs',
'Reciprocal',
'Floor',
'Ceil',
'Sqrt',
'Relu',
'LeakyRelu',
'Selu',
'Elu',
'Exp',
'Log',
'Tanh',
'Pow',
'PRelu',
'Sigmoid',
'HardSigmoid',
'Max',
'Min',
'Sum',
'Mean',
'Clip',
'Softmax',
'LogSoftmax',
'Softsign',
'SoftPlus',
'Gemm',
'MatMul',
'AveragePool',
'MaxPool',
'Conv',
'ConvTranspose',
'GlobalAveragePool',
'GlobalMaxPool',
'BatchNormalization',
'Dropout',
'Flatten',
'LRN',
'ReduceMax',
'ReduceMin',
'ReduceSum',
'ReduceMean',
'ReduceProd',
'ArgMax',
'ArgMin',
'Cast',
'Reshape',
'Concat',
'Split',
'Slice',
'Transpose',
'Gather',
'Squeeze',
'Unsqueeze',
'Pad',
'Shape',
])


class GraphProto(object):
"""A helper class for handling Relay expression copying from pb2.GraphProto.
Expand Down Expand Up @@ -1025,15 +952,19 @@ def from_onnx(self, graph, opset):
dtype = d_type
self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
# get list of unsupported ops
unsupported_ops = []
convert_map = _get_convert_map(opset)
unsupported_ops = set()
for node in graph.node:
op_name = node.op_type
if op_name not in supported_ops:
unsupported_ops.append(op_name)
if op_name not in convert_map and op_name != 'Constant':
unsupported_ops.add(op_name)
if unsupported_ops:
unsupported_ops = str(unsupported_ops).strip('[]').replace("'", '')
msg = 'The following operators are not supported for frontend ONNX: {}'
raise tvm.error.OpNotImplemented(msg.format(unsupported_ops))
msg = ['The following operators are not supported for frontend ONNX: ']
for i, op_name in enumerate(unsupported_ops):
msg.append(op_name)
if i != len(unsupported_ops) - 1:
msg.append(', ')
raise tvm.error.OpNotImplemented(''.join(msg))
# construct nodes, nodes are stored as directed acyclic graph
for node in graph.node:
op_name = node.op_type
Expand Down

0 comments on commit 3d5f193

Please sign in to comment.