diff --git a/python/tvm/relax/frontend/onnx_frontend.py b/python/tvm/relax/frontend/onnx_frontend.py
index 737160213be7..70b9f9b8ea61 100644
--- a/python/tvm/relax/frontend/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx_frontend.py
@@ -125,7 +125,7 @@ def get_converter(cls, opset):
             return getattr(cls, "_impl_v{}".format(version))
         raise NotImplementedError(
             "opset version {} of {} not implemented".format(version, cls.__name__)
-            )
+        )
 
 
 class MatMul(OnnxOpConverter):
@@ -135,14 +135,18 @@ class MatMul(OnnxOpConverter):
     def _impl_v13(cls, bb, inputs, attr):
         return bb.emit_te(topi.matmul, inputs[0], inputs[1])
 
+
 class Div(OnnxOpConverter):
     """Converts an onnx Div node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v14(cls, bb, inputs, attr):
         return bb.emit_te(topi.divide, inputs[0], inputs[1])
 
+
 class Sigmoid(OnnxOpConverter):
     """Converts an onnx Sigmoid node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         return bb.emit_te(topi.sigmoid, inputs[0])
@@ -150,26 +154,31 @@ def _impl_v13(cls, bb, inputs, attr):
 
 class Softmax(OnnxOpConverter):
     """Converts an onnx Softmax node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         axis = attr.get("axis", -1)
         return bb.emit_te(topi.nn.softmax, inputs[0], axis=axis)
 
+
 class Transpose(OnnxOpConverter):
     """Converts an onnx Transpose node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         perm = attr.get("perm", None)
         return bb.emit_te(topi.transpose, inputs[0], axes=perm)
 
+
 class Unsqueeze(OnnxOpConverter):
     """Converts an onnx Unsqueeze node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         input = inputs[0]
         axes = inputs[1]
-        if (isinstance(axes, relax.Constant)):
+        if isinstance(axes, relax.Constant):
             constant_axes = list(axes.data.numpy())
             constant_axes = list(map(int, constant_axes))
             constant_axes = sorted(constant_axes)
 
@@ -179,6 +188,7 @@ def _impl_v13(cls, bb, inputs, attr):
 
         raise NotImplementedError("Unsqueeze with dynamic axes is not supported.")
 
+
 class Concat(OnnxOpConverter):
     """Convert an onnx Concat node into an equivalent Relax expression."""
 
@@ -207,6 +217,7 @@ def _impl_v13(cls, bb, inputs, attr):
 
 class Cast(OnnxOpConverter):
     """Convert an onnx Cast node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         to_type = get_type(attr["to"])
@@ -216,6 +228,7 @@ def _impl_v13(cls, bb, inputs, attr):
 
 class Gather(OnnxOpConverter):
     """Convert an onnx Gather node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         # TODO This assumes positive only indices.
@@ -255,9 +269,11 @@ def _impl_v13(cls, bb, inputs, attr):
 
 class Reshape(OnnxOpConverter):
     """Convert an onnx Reshape node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         from tvm.script import relax as R
+
         data = inputs[0]
         # TODO We assume new_shape is a constant, need to enable tensor input to reshape
         # for full support.
@@ -277,14 +295,15 @@ def _impl_v13(cls, bb, inputs, attr):
                 if dim == -1:
                     new_shape[i] = int(total_elements / new_product)
 
-            return bb.emit_te(topi.reshape, data, new_shape)
+        return bb.emit_te(topi.reshape, data, new_shape)
 
 
 class Gelu(OnnxOpConverter):
     """Operator converter for Gelu from Microsoft onnxruntime contrib opset.
 
     gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
     """
+
     @classmethod
     def _impl_v1(cls, bb, inputs, attr):
         x = inputs[0]
@@ -297,15 +316,17 @@ def _impl_v1(cls, bb, inputs, attr):
 
         # Compute gelu
         term1 = bb.emit_te(topi.multiply, half, x)
-        erf = bb.emit_te(topi.erf, bb.emit_te(topi.divide, x, sqrt2)) 
+        erf = bb.emit_te(topi.erf, bb.emit_te(topi.divide, x, sqrt2))
         term2 = bb.emit_te(topi.add, one, erf)
         return bb.emit_te(topi.multiply, term1, term2)
 
+
 class BiasGelu(OnnxOpConverter):
     """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
 
     bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
     """
+
     @classmethod
     def _impl_v1(cls, bb, inputs, attr):
         x = inputs[0]
@@ -317,12 +338,134 @@ def _impl_v1(cls, bb, inputs, attr):
         inp = bb.emit_te(topi.add, x, b)
         return Gelu._impl_v1(bb, [inp], attr)
 
+
 class Where(OnnxOpConverter):
     """Convert an onnx Where node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v16(cls, bb, inputs, attr):
         return bb.emit_te(topi.where, *inputs)
 
+
+class Clip(OnnxOpConverter):
+    """Converts an onnx Clip node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        results = inputs[0]
+        if len(inputs) >= 2:
+            results = bb.emit_te(topi.maximum, results, inputs[1])
+        if len(inputs) >= 3:
+            results = bb.emit_te(topi.minimum, results, inputs[2])
+        return results
+
+
+class Equal(OnnxOpConverter):
+    """Converts an onnx Equal node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.equal, inputs[0], inputs[1])
+
+
+class Shape(OnnxOpConverter):
+    """Converts an onnx Shape node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.shape, inputs[0], "int64")
+
+
+class Not(OnnxOpConverter):
+    """Converts an onnx Not node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.bitwise_not, inputs[0])
+
+
+class Tanh(OnnxOpConverter):
+    """Converts an onnx Tanh node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.tanh, inputs[0])
+
+
+class Sqrt(OnnxOpConverter):
+    """Converts an onnx Sqrt node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.sqrt, inputs[0])
+
+
+class Relu(OnnxOpConverter):
+    """Converts an onnx Relu node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.nn.relu, inputs[0])
+
+
+class Pow(OnnxOpConverter):
+    """Converts an onnx Pow node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.power, inputs[0], inputs[1])
+
+
+class Conv(OnnxOpConverter):
+    """Convert an onnx Conv node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        # not supported yet
+        assert "auto_pad" not in attr
+        assert "group" not in attr
+        # supported conv2d
+        return bb.emit_te(
+            topi.add,
+            bb.emit_te(
+                topi.nn.conv2d,
+                inputs[0],
+                inputs[1],
+                strides=attr.get("strides", 1),
+                padding=attr.get("pads", 0),
+                dilation=attr.get("dilations", 1),
+            ),
+            bb.emit_te(topi.expand_dims, inputs[2], axis=1, num_newaxis=2),
+        )
+
+
+class Erf(OnnxOpConverter):
+    """Converts an onnx Erf node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.erf, inputs[0])
+
+
+class CumSum(OnnxOpConverter):
+    """Converts an onnx CumSum node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        assert attr.get("reverse", 0) == 0, "reverse is not supported yet"
+        if len(inputs) > 1:
+            # axis = int(infer_value(inputs[1], params).numpy())
+            axis = inputs[1]
+        else:
+            axis = None
+        return bb.emit_te(
+            topi.cumsum,
+            data=inputs[0],
+            axis=axis,
+            exclusive=attr.get("exclusive", None),
+        )
+
+
 def _get_convert_map(opset):
     return {
         "MatMul": MatMul.get_converter(opset),
@@ -341,6 +484,17 @@ def _get_convert_map(opset):
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
         "Where": Where.get_converter(opset),
+        "Clip": Clip.get_converter(opset),
+        "Equal": Equal.get_converter(opset),
+        "Shape": Shape.get_converter(opset),
+        "Not": Not.get_converter(opset),
+        "Tanh": Tanh.get_converter(opset),
+        "Sqrt": Sqrt.get_converter(opset),
+        "Relu": Relu.get_converter(opset),
+        "Conv": Conv.get_converter(opset),
+        "Pow": Pow.get_converter(opset),
+        "Erf": Erf.get_converter(opset),
+        "CumSum": CumSum.get_converter(opset),
     }
 
 
@@ -630,4 +784,4 @@ def from_onnx(model, shape=None, dtype="float32", opset=None):
     )
 
     # Use the graph proto as a scope so that ops can access other nodes if needed.
-    return g.from_onnx(graph, opset)
\ No newline at end of file
+    return g.from_onnx(graph, opset)
diff --git a/tests/python/relax/frontend/test_onnx_frontend.py b/tests/python/relax/frontend/test_onnx_frontend.py
index 24c7a4e1cf97..cfeeb0d78774 100644
--- a/tests/python/relax/frontend/test_onnx_frontend.py
+++ b/tests/python/relax/frontend/test_onnx_frontend.py
@@ -32,14 +32,19 @@
 
 import onnx
 from onnx import helper
-from onnx import TensorProto, ModelProto
+from onnx import TensorProto, ModelProto, ValueInfoProto
 import onnxruntime
 
 
-def generate_random_inputs(model: ModelProto) -> Dict[str, np.array]:
+def generate_random_inputs(
+    model: ModelProto, inputs: Dict[str, np.array] = None
+) -> Dict[str, np.array]:
     input_values = {}
     # Iterate through model inputs and extract their shape.
     for i in model.graph.input:
+        if inputs is not None and i.name in inputs:
+            input_values[i.name] = inputs[i.name]
+            continue
         shape = []
         for dim in i.type.tensor_type.shape.dim:
             shape.append(dim.dim_value)
@@ -73,8 +78,7 @@ def check_correctness(model: ModelProto, inputs: Optional[Dict[str, np.array]] =
     """
     # If inputs are not provided, extract them from the onnx graph and produce random
    # values that we'll use for testing.
-    if inputs is None:
-        inputs = generate_random_inputs(model)
+    inputs = generate_random_inputs(model, inputs)
 
     # Run the model through onnx to get the expected result.
     ort_session = onnxruntime.InferenceSession(model.SerializeToString())
@@ -257,7 +262,7 @@ def test_div():
             helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
             helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32]),
         ],
-        outputs = [helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 32])]
+        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 32])],
     )
 
     model = helper.make_model(graph, producer_name="div_test")
@@ -271,7 +276,7 @@ def test_sigmoid():
         [sigmoid_node],
         "sigmoid_test",
         inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])],
-        outputs = [helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32])]
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32])],
     )
 
     model = helper.make_model(graph, producer_name="sigmoid_test")
@@ -285,7 +290,7 @@ def test_softmax():
         [softmax_node],
         "softmax_test",
         inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32, 32])],
-        outputs = [helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32, 32])]
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32, 32])],
     )
 
     model = helper.make_model(graph, producer_name="softmax_test")
@@ -299,7 +304,7 @@ def test_transpose():
         [transpose_node],
         "transpose_test",
         inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32, 32])],
-        outputs = [helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32, 32])]
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32, 32])],
    )
 
     model = helper.make_model(graph, producer_name="transpose_test")
@@ -314,7 +319,7 @@ def test_unsqueeze():
         "unsqueeze",
         inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])],
         initializer=[helper.make_tensor("axes", TensorProto.INT64, [3], vals=[0, 2, 3])],
-        outputs = [helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 1, 1, 32])]
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 1, 1, 32])],
     )
 
     model = helper.make_model(graph, producer_name="unsqueeze_test")
@@ -328,7 +333,7 @@ def test_gelu():
         [gelu_node],
         "gelu_test",
         inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])],
-        outputs = [helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32])]
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32])],
     )
 
     model = helper.make_model(graph, producer_name="gelu_test")
@@ -345,7 +350,7 @@ def test_bias_gelu():
             helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
             helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
         ],
-        outputs = [helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 32])],
+        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 32])],
     )
 
     model = helper.make_model(graph, producer_name="bias_gelu_test")
@@ -363,13 +368,206 @@ def test_where():
             helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32]),
             helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 32]),
         ],
-        outputs = [helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32])],
+        outputs=[helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32])],
     )
 
     model = helper.make_model(graph, producer_name="where_test")
     check_correctness(model)
 
 
+def test_clip():
+    clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
+
+    graph = helper.make_graph(
+        [clip_node],
+        "clip_test",
+        inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 64]), + helper.make_tensor_value_info("min", TensorProto.FLOAT, ()), + helper.make_tensor_value_info("max", TensorProto.FLOAT, ()), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 64])], + ) + + model = helper.make_model(graph, producer_name="clip_test") + check_correctness(model) + + +def test_equal(): + equal_node = helper.make_node("Equal", ["a", "b"], ["output"]) + + graph = helper.make_graph( + [equal_node], + "equal_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32]), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="equal_test") + check_correctness( + model, {"a": np.zeros([32, 32], dtype="float32"), "b": np.zeros([32, 32], dtype="float32")} + ) + check_correctness( + model, {"a": np.ones([32, 32], dtype="float32"), "b": np.zeros([32, 32], dtype="float32")} + ) + check_correctness(model) + + +def test_shape(): + shape_node = helper.make_node("Shape", ["data"], ["output"]) + + graph = helper.make_graph( + [shape_node], + "shape_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4, 5, 6]), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.INT64, [4])], + ) + + model = helper.make_model(graph, producer_name="shape_test") + check_correctness(model) + + +def test_not(): + not_node = helper.make_node("Not", ["x"], ["y"]) + shape = [3, 4, 5, 6] + graph = helper.make_graph( + [not_node], + "not_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.BOOL, shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.BOOL, shape)], + ) + + model = helper.make_model(graph, producer_name="not_test") + check_correctness(model, {"x": np.zeros(shape, dtype="bool")}) + check_correctness(model, {"x": np.ones(shape, dtype="bool")}) + + +def test_tanh(): + tanh_node = helper.make_node("Tanh", ["x"], ["y"]) + shape = [9, 8, 7, 6] + graph = helper.make_graph( + [tanh_node], + "tanh_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="tanh_test") + check_correctness(model) + + +def test_sqrt(): + sqrt_node = helper.make_node("Sqrt", ["x"], ["y"]) + shape = [32, 32] + graph = helper.make_graph( + [sqrt_node], + "sqrt_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="sqrt_test") + check_correctness(model) + + +def test_relu(): + relu_node = helper.make_node("Relu", ["x"], ["y"]) + shape = [32, 32] + graph = helper.make_graph( + [relu_node], + "relu_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="relu_test") + check_correctness(model) + + +def test_conv(): + conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"]) + nchw_shape = [3, 12, 32, 32] + graph = helper.make_graph( + [conv_node], + "conv_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, nchw_shape), + helper.make_tensor_value_info("w", TensorProto.FLOAT, [4, 12, 3, 3]), + 
helper.make_tensor_value_info("b", TensorProto.FLOAT, [4]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4, 30, 30])], + ) + + model = helper.make_model(graph, producer_name="conv_test") + check_correctness(model) + + +def test_pow(): + pow_node = helper.make_node("Pow", ["x", "y"], ["z"]) + shape = [32, 32] + graph = helper.make_graph( + [pow_node], + "pow_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("y", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("z", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="pow_test") + check_correctness(model) + + +def test_erf(): + erf_node = helper.make_node("Erf", ["x"], ["y"]) + shape = [32, 32] + graph = helper.make_graph( + [erf_node], + "erf_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="erf_test") + check_correctness(model) + + +def test_cumsum(): + cumsum_node = helper.make_node("CumSum", ["x", "axis"], ["y"]) + shape = [32, 32] + type_proto = onnx.TypeProto() + tensor_type_proto = type_proto.tensor_type + tensor_type_proto.elem_type = TensorProto.INT64 + graph = helper.make_graph( + [cumsum_node], + "cumsum_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axis", TensorProto.INT64, ()), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="cumsum_test") + check_correctness(model, {"axis": [1]}) + + if __name__ == "__main__": test_matmul() test_concat() @@ -378,10 +576,22 @@ def test_where(): test_cast() test_gather() test_gemm() + test_equal() + test_not() + test_tanh() + test_sqrt() + test_relu() + test_clip() + test_conv() + test_pow() + test_erf() + # TODO, still has issues - #test_reshape() + # test_reshape() test_div() test_sigmoid() test_softmax() test_transpose() test_unsqueeze() + # test_shape() + # test_cumsum() # need axis as int