diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py
index 90e79276697a3..431c06abb6fcf 100644
--- a/nnvm/python/nnvm/frontend/onnx.py
+++ b/nnvm/python/nnvm/frontend/onnx.py
@@ -775,80 +775,6 @@ def _get_convert_map(opset):
     }
 
 
-supported_ops = set([
-    'Constant',
-    'Identity',
-    'ThresholdedRelu',
-    'ScaledTanh',
-    'ParametricSoftplus',
-    'ConstantFill',
-    'FC',
-    'Scale',
-    'ImageScaler',
-    'Upsample' ,
-    'SpatialBN',
-    'Add',
-    'Sub',
-    'Mul',
-    'Div',
-    'Neg',
-    'Abs',
-    'Reciprocal',
-    'Floor',
-    'Ceil',
-    'Sqrt',
-    'Relu',
-    'LeakyRelu',
-    'Selu',
-    'Elu',
-    'Exp',
-    'Log',
-    'Tanh',
-    'Pow',
-    'PRelu',
-    'Sigmoid',
-    'HardSigmoid',
-    'Max',
-    'Min',
-    'Sum',
-    'Mean',
-    'Clip',
-    'Softmax',
-    'LogSoftmax',
-    'Softsign',
-    'SoftPlus',
-    'Gemm',
-    'MatMul',
-    'AveragePool',
-    'MaxPool',
-    'Conv',
-    'ConvTranspose',
-    'GlobalAveragePool',
-    'GlobalMaxPool',
-    'BatchNormalization',
-    'Dropout',
-    'Flatten',
-    'LRN',
-    'ReduceMax',
-    'ReduceMin',
-    'ReduceSum',
-    'ReduceMean',
-    'ArgMax',
-    'ArgMin',
-    'Cast',
-    'Reshape',
-    'Concat',
-    'Split',
-    'Slice',
-    'Transpose',
-    'Gather',
-    'Squeeze',
-    'Unsqueeze',
-    'Pad',
-    'Shape',
-])
-
-
 class GraphProto(object):
     """A helper class for handling nnvm graph copying from pb2.GraphProto.
     Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
@@ -899,15 +825,21 @@ def from_onnx(self, graph, opset):
                 self._num_input += 1
                 self._nodes[i_name] = _sym.Variable(name=i_name)
         # get list of unsupported ops
-        unsupported_ops = []
+        convert_map = _get_convert_map(opset)
+        unsupported_ops = set()
         for node in graph.node:
             op_name = node.op_type
-            if op_name not in supported_ops:
-                unsupported_ops.append(op_name)
+            if op_name not in convert_map and \
+               op_name != 'Constant' and \
+               op_name not in _identity_list:
+                unsupported_ops.add(op_name)
         if unsupported_ops:
-            msg = 'The following operators are not supported for frontend ONNX: {}'
-            unsupported_ops = str(unsupported_ops).strip('[]').replace("'", '')
-            raise tvm.error.OpNotImplemented(msg.format(unsupported_ops))
+            msg = ['The following operators are not supported for frontend ONNX: ']
+            for i, op_name in enumerate(unsupported_ops):
+                msg.append(op_name)
+                if i != len(unsupported_ops) - 1:
+                    msg.append(', ')
+            raise tvm.error.OpNotImplemented(''.join(msg))
         # construct nodes, nodes are stored as directed acyclic graph
         for node in graph.node:
             op_name = node.op_type
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index f8fc719b4916e..4c989f6f2d987 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -879,79 +879,6 @@ def _get_convert_map(opset):
         'Shape': Shape.get_converter(opset),
     }
 
-supported_ops = set([
-    'Constant',
-    'Identity',
-    'ThresholdedRelu',
-    'ScaledTanh',
-    'ParametricSoftplus',
-    'ConstantFill',
-    'FC',
-    'Scale',
-    'Upsample' ,
-    'SpatialBN',
-    'Add',
-    'Sub',
-    'Mul',
-    'Div',
-    'Neg',
-    'Abs',
-    'Reciprocal',
-    'Floor',
-    'Ceil',
-    'Sqrt',
-    'Relu',
-    'LeakyRelu',
-    'Selu',
-    'Elu',
-    'Exp',
-    'Log',
-    'Tanh',
-    'Pow',
-    'PRelu',
-    'Sigmoid',
-    'HardSigmoid',
-    'Max',
-    'Min',
-    'Sum',
-    'Mean',
-    'Clip',
-    'Softmax',
-    'LogSoftmax',
-    'Softsign',
-    'SoftPlus',
-    'Gemm',
-    'MatMul',
-    'AveragePool',
-    'MaxPool',
-    'Conv',
-    'ConvTranspose',
-    'GlobalAveragePool',
-    'GlobalMaxPool',
-    'BatchNormalization',
-    'Dropout',
-    'Flatten',
-    'LRN',
-    'ReduceMax',
-    'ReduceMin',
-    'ReduceSum',
-    'ReduceMean',
-    'ReduceProd',
-    'ArgMax',
-    'ArgMin',
-    'Cast',
-    'Reshape',
-    'Concat',
-    'Split',
-    'Slice',
-    'Transpose',
-    'Gather',
-    'Squeeze',
-    'Unsqueeze',
-    'Pad',
-    'Shape',
-])
-
 
 class GraphProto(object):
     """A helper class for handling Relay expression copying from pb2.GraphProto.
@@ -1025,15 +952,21 @@ def from_onnx(self, graph, opset):
                     dtype = d_type
                 self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
         # get list of unsupported ops
-        unsupported_ops = []
+        convert_map = _get_convert_map(opset)
+        unsupported_ops = set()
         for node in graph.node:
             op_name = node.op_type
-            if op_name not in supported_ops:
-                unsupported_ops.append(op_name)
+            if op_name not in convert_map and \
+               op_name != 'Constant' and \
+               op_name not in _identity_list:
+                unsupported_ops.add(op_name)
         if unsupported_ops:
-            unsupported_ops = str(unsupported_ops).strip('[]').replace("'", '')
-            msg = 'The following operators are not supported for frontend ONNX: {}'
-            raise tvm.error.OpNotImplemented(msg.format(unsupported_ops))
+            msg = ['The following operators are not supported for frontend ONNX: ']
+            for i, op_name in enumerate(unsupported_ops):
+                msg.append(op_name)
+                if i != len(unsupported_ops) - 1:
+                    msg.append(', ')
+            raise tvm.error.OpNotImplemented(''.join(msg))
         # construct nodes, nodes are stored as directed acyclic graph
         for node in graph.node:
             op_name = node.op_type
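
For reference, a minimal standalone sketch of the validation logic this patch introduces: unsupported operators are now detected against the opset-specific converter table (plus the identity list and the special-cased `Constant`) rather than a hard-coded `supported_ops` set, deduplicated via a `set`, and reported in a single `OpNotImplemented` message. The `FakeNode`/`FakeGraph` types and the contents of `convert_map`/`identity_list` below are made-up stand-ins, not TVM's real data structures; only the control flow mirrors the diff.

```python
# Sketch of the unsupported-op check added by this patch (stand-in data, not TVM internals).
from collections import namedtuple

FakeNode = namedtuple('FakeNode', ['op_type'])
FakeGraph = namedtuple('FakeGraph', ['node'])

# Hypothetical converter table and identity list; in TVM these come from
# _get_convert_map(opset) and _identity_list.
convert_map = {'Conv': None, 'Relu': None, 'Add': None}
identity_list = ['Dropout']

graph = FakeGraph(node=[FakeNode('Conv'),
                        FakeNode('Erf'),
                        FakeNode('NonMaxSuppression'),
                        FakeNode('Erf')])

# Collect each unsupported op type once.
unsupported_ops = set()
for node in graph.node:
    op_name = node.op_type
    if op_name not in convert_map and \
       op_name != 'Constant' and \
       op_name not in identity_list:
        unsupported_ops.add(op_name)

# Build one combined, comma-separated error message.
if unsupported_ops:
    msg = ['The following operators are not supported for frontend ONNX: ']
    for i, op_name in enumerate(unsupported_ops):
        msg.append(op_name)
        if i != len(unsupported_ops) - 1:
            msg.append(', ')
    # The real frontend raises tvm.error.OpNotImplemented(''.join(msg)).
    print(''.join(msg))
```

Deriving the check from `_get_convert_map(opset)` keeps the error report in sync with the converters actually registered for a given opset, so the supported-op list can no longer drift from the converter map as it could with the removed hard-coded set.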