From b63223818441362b454b238a94d9f95ee758f241 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 7 Oct 2024 13:51:55 +0800 Subject: [PATCH] [Relax] Support left_shift and right_shift op Introduced left_shift and right_shift op in Relax with ONNX frontend support. --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 102 ++++++++++++++++-- python/tvm/relax/op/__init__.py | 2 + python/tvm/relax/op/binary.py | 32 ++++++ .../relax/transform/legalize_ops/binary.py | 2 + python/tvm/script/ir_builder/relax/ir.py | 4 + src/relax/op/distributed/binary.cc | 2 + src/relax/op/tensor/binary.cc | 2 + src/relax/op/tensor/binary.h | 6 ++ tests/python/relax/test_frontend_onnx.py | 36 +++++++ tests/python/relax/test_op_binary.py | 2 + 10 files changed, 182 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 5777f51fe296..861575cc1897 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -244,7 +244,7 @@ class BinaryBase(OnnxOpConverter): relax_op: Callable = None @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def base_impl(cls, bb, inputs, attr, params): if cls.numpy_op is None or cls.relax_op is None: raise ValueError("Numpy and Relax operators must be defined for BinaryBase.") if all([isinstance(inp, relax.Constant) for inp in inputs]): @@ -274,6 +274,10 @@ class Add(BinaryBase): numpy_op = _np.add relax_op = relax.op.add + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Sub(BinaryBase): """Converts an onnx Sub node into an equivalent Relax expression.""" @@ -281,6 +285,10 @@ class Sub(BinaryBase): numpy_op = _np.subtract relax_op = relax.op.subtract + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Mul(BinaryBase): """Converts an onnx Mul node into an equivalent Relax expression.""" @@ -288,6 +296,10 @@ class Mul(BinaryBase): numpy_op = _np.multiply relax_op = relax.op.multiply + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Div(BinaryBase): """Converts an onnx Div node into an equivalent Relax expression.""" @@ -295,6 +307,10 @@ class Div(BinaryBase): numpy_op = _np.divide relax_op = relax.op.divide + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Pow(BinaryBase): """Converts an onnx Pow node into an equivalent Relax expression.""" @@ -302,6 +318,10 @@ class Pow(BinaryBase): numpy_op = _np.power relax_op = relax.op.power + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class And(BinaryBase): """Converts an onnx And node into an equivalent Relax expression.""" @@ -309,6 +329,10 @@ class And(BinaryBase): numpy_op = _np.logical_and relax_op = relax.op.logical_and + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Or(BinaryBase): """Converts an onnx Or node into an equivalent Relax expression.""" @@ -316,6 +340,10 @@ class Or(BinaryBase): numpy_op = _np.logical_or relax_op = relax.op.logical_or + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Xor(BinaryBase): """Converts an onnx Xor node into an equivalent Relax expression.""" @@ -323,6 +351,10 @@ class Xor(BinaryBase): numpy_op = _np.logical_xor relax_op = relax.op.logical_xor + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Less(BinaryBase): """Converts an onnx Less node into an equivalent Relax expression.""" @@ -330,6 +362,10 @@ class Less(BinaryBase): numpy_op = _np.less relax_op = relax.op.less + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class LessOrEqual(BinaryBase): """Converts an onnx LessEqual node into an equivalent Relax expression.""" @@ -337,6 +373,10 @@ class LessOrEqual(BinaryBase): numpy_op = _np.less_equal relax_op = relax.op.less_equal + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Greater(BinaryBase): """Converts an onnx Greater node into an equivalent Relax expression.""" @@ -344,6 +384,10 @@ class Greater(BinaryBase): numpy_op = _np.greater relax_op = relax.op.greater + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class GreaterOrEqual(BinaryBase): """Converts an onnx GreaterEqual node into an equivalent Relax expression.""" @@ -351,6 +395,10 @@ class GreaterOrEqual(BinaryBase): numpy_op = _np.greater_equal relax_op = relax.op.greater_equal + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + class Equal(OnnxOpConverter): """Converts an onnx Equal node into an equivalent Relax expression.""" @@ -374,7 +422,7 @@ class BitwiseBase(BinaryBase): """Converts an onnx BitwiseBase node into an equivalent Relax expression.""" @classmethod - def base_impl(cls, bb, inputs, attr, params, py_func, relax_op): + def base_impl(cls, bb, inputs, attr, params): valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] for num, inp in enumerate(inputs): if inp.struct_info.dtype not in valid_types: @@ -382,31 +430,69 @@ def base_impl(cls, bb, inputs, attr, params, py_func, relax_op): f"Bitwise operations expect all inputs to have integer types, " f"got {inp.struct_info.dtype} for input {num}" ) - return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op) + return super().base_impl(bb, inputs, attr, params) class BitwiseAnd(BitwiseBase): """Converts an onnx BitwiseAnd node into an equivalent Relax expression.""" + numpy_op = _np.bitwise_and + relax_op = relax.op.bitwise_and + @classmethod def _impl_v18(cls, bb, inputs, attr, params): - return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and) + return cls.base_impl(bb, inputs, attr, params) class BitwiseOr(BitwiseBase): """Converts an onnx BitwiseOr node into an equivalent Relax expression.""" + numpy_op = _np.bitwise_or + relax_op = relax.op.bitwise_or + @classmethod def _impl_v18(cls, bb, inputs, attr, params): - return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or) + return cls.base_impl(bb, inputs, attr, params) class BitwiseXor(BitwiseBase): """Converts an onnx BitwiseXor node into an equivalent Relax expression.""" + numpy_op = _np.bitwise_xor + relax_op = relax.op.bitwise_xor + @classmethod def _impl_v18(cls, bb, inputs, attr, params): - return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor) + return cls.base_impl(bb, inputs, attr, params) + + +class BitwiseNot(BitwiseBase): + """Converts an onnx BitwiseNot node into an equivalent Relax expression.""" + + numpy_op = _np.bitwise_not + relax_op = relax.op.bitwise_not + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + + +class BitShift(BitwiseBase): + """Converts an onnx BitShift node into an equivalent Relax expression.""" + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + direction = attr.get("direction", "LEFT").decode("ascii") + if direction == "LEFT": + cls.numpy_op = _np.left_shift + cls.relax_op = relax.op.left_shift + elif direction == "RIGHT": + cls.numpy_op = _np.right_shift + cls.relax_op = relax.op.right_shift + else: + raise ValueError("Unsupported Shift Direction: " + direction) + + return cls.base_impl(bb, inputs, attr, params) class Sigmoid(OnnxOpConverter): @@ -2652,8 +2738,8 @@ def _get_convert_map(): "BitwiseAnd": BitwiseAnd, "BitwiseOr": BitwiseOr, "BitwiseXor": BitwiseXor, - # "BitwiseNot": BitwiseNot, - # "BitwiseShift": BitwiseShift, + "BitwiseNot": BitwiseNot, + "BitShift": BitShift, "And": And, "Or": Or, "Xor": Xor, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 4581defa1a77..c99201e969b5 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -52,6 +52,7 @@ floor_divide, greater, greater_equal, + left_shift, less, less_equal, logical_and, @@ -62,6 +63,7 @@ multiply, not_equal, power, + right_shift, subtract, ) from .create import ( diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 982b3a24f26c..7632235cb32c 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -386,3 +386,35 @@ def bitwise_xor(x1: Expr, x2: Expr) -> Expr: The computed result. """ return _ffi_api.bitwise_xor(x1, x2) + + +def left_shift(x1: Expr, x2: Expr) -> Expr: + """Bitwise Shift Left + Parameters + ---------- + x1 : relax.Expr + The input tensor to be shifted. + x2 : relax.Expr + The number of positions to shift. + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.left_shift(x1, x2) + + +def right_shift(x1: Expr, x2: Expr) -> Expr: + """Bitwise Shift Right + Parameters + ---------- + x1 : relax.Expr + The input tensor to be shifted. + x2 : relax.Expr + The number of positions to shift. + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.right_shift(x1, x2) diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py index 16d6c0269616..d28e100edb9f 100644 --- a/python/tvm/relax/transform/legalize_ops/binary.py +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -62,6 +62,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.bitwise_and", _binary(topi.bitwise_and)) register_legalize("relax.bitwise_or", _binary(topi.bitwise_or)) register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor)) +register_legalize("relax.left_shift", _binary(topi.left_shift)) +register_legalize("relax.right_shift", _binary(topi.right_shift)) # logical register_legalize("relax.logical_and", _binary(topi.logical_and)) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index c4be8afac4d2..e6ff35ebe56b 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -102,6 +102,7 @@ isinf, isnan, layout_transform, + left_shift, less, less_equal, linear, @@ -133,6 +134,7 @@ quantize, repeat, reshape, + right_shift, round, rsqrt, scatter_elements, @@ -773,6 +775,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "isinf", "isnan", "layout_transform", + "left_shift", "less", "less_equal", "linear", @@ -809,6 +812,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "repeat", "reshape", "rewriter", + "right_shift", "tensor_to_shape", "shape_to_tensor", "rocm", diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc index 63f4f356c03d..6ad71e0f85bf 100644 --- a/src/relax/op/distributed/binary.cc +++ b/src/relax/op/distributed/binary.cc @@ -68,6 +68,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(logical_xor); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_and); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_or); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_xor); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(left_shift); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(right_shift); } // namespace distributed } // namespace relax diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index afc0fb73031b..f1dc3d4904c8 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -207,6 +207,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(left_shift); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(right_shift); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index b28a6c33690b..003bcb7e27cf 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -129,6 +129,12 @@ Expr bitwise_or(Expr x1, Expr x2); /*! \brief Broadcasted element-wise bitwise xor */ Expr bitwise_xor(Expr x1, Expr x2); +/*! \brief Broadcasted element-wise bitwise shift left */ +Expr left_shift(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise bitwise shift right */ +Expr right_shift(Expr x1, Expr x2); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 2837ad2185e9..a3c1eb47f5b3 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -358,6 +358,42 @@ def test_binary_bool(op_name: str): verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.BOOL) +@pytest.mark.skip(reason="opset 18 is not supported in CI") +@pytest.mark.parametrize("op_name", ["BitwiseAnd", "BitwiseOr", "BitwiseXor"]) +def test_bitwise(op_name: str): + verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.UINT64, opset=18) + + +@pytest.mark.skip(reason="opset 18 is not supported in CI") +def test_bitwise_not(): + verify_unary( + "BitwiseNot", + [32, 32], + input_dtype=TensorProto.UINT64, + output_dtype=TensorProto.UINT64, + opset=18, + ) + + +@pytest.mark.parametrize("direction", ["LEFT", "RIGHT"]) +def test_bitwise_shift(direction: str): + shape = [32, 32] + dtype = TensorProto.UINT64 + test_node = helper.make_node("BitShift", ["a", "b"], ["c"], direction=direction) + graph = helper.make_graph( + [test_node], + "binary_test", + inputs=[ + helper.make_tensor_value_info("a", dtype, shape), + helper.make_tensor_value_info("b", dtype, shape), + ], + outputs=[helper.make_tensor_value_info("c", dtype, shape)], + ) + + model = helper.make_model(graph, producer_name="binary_test") + check_correctness(model, inputs={"b": np.random.randint(0, 8, shape).astype("uint64")}) + + @pytest.mark.parametrize( "op_name", [ diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index 85842f1578df..20c111495d6a 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -46,6 +46,8 @@ def test_op_correctness(): assert relax.op.bitwise_and(x, y).op == Op.get("relax.bitwise_and") assert relax.op.bitwise_or(x, y).op == Op.get("relax.bitwise_or") assert relax.op.bitwise_xor(x, y).op == Op.get("relax.bitwise_xor") + assert relax.op.left_shift(x, y).op == Op.get("relax.left_shift") + assert relax.op.right_shift(x, y).op == Op.get("relax.right_shift") x = relax.Var("x", R.Tensor((2, 3), "bool")) y = relax.Var("y", R.Tensor((2, 3), "bool"))