Skip to content

Commit

Permalink
[Relax] Support left_shift and right_shift op
Browse files Browse the repository at this point in the history
Introduced left_shift and right_shift op in Relax with ONNX frontend
support.
  • Loading branch information
Hzfengsy committed Oct 7, 2024
1 parent accd582 commit b632238
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 8 deletions.
102 changes: 94 additions & 8 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class BinaryBase(OnnxOpConverter):
relax_op: Callable = None

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
def base_impl(cls, bb, inputs, attr, params):
if cls.numpy_op is None or cls.relax_op is None:
raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
if all([isinstance(inp, relax.Constant) for inp in inputs]):
Expand Down Expand Up @@ -274,83 +274,131 @@ class Add(BinaryBase):
numpy_op = _np.add
relax_op = relax.op.add

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Sub(BinaryBase):
"""Converts an onnx Sub node into an equivalent Relax expression."""

numpy_op = _np.subtract
relax_op = relax.op.subtract

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Mul(BinaryBase):
"""Converts an onnx Mul node into an equivalent Relax expression."""

numpy_op = _np.multiply
relax_op = relax.op.multiply

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Div(BinaryBase):
"""Converts an onnx Div node into an equivalent Relax expression."""

numpy_op = _np.divide
relax_op = relax.op.divide

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Pow(BinaryBase):
"""Converts an onnx Pow node into an equivalent Relax expression."""

numpy_op = _np.power
relax_op = relax.op.power

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class And(BinaryBase):
"""Converts an onnx And node into an equivalent Relax expression."""

numpy_op = _np.logical_and
relax_op = relax.op.logical_and

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Or(BinaryBase):
"""Converts an onnx Or node into an equivalent Relax expression."""

numpy_op = _np.logical_or
relax_op = relax.op.logical_or

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Xor(BinaryBase):
"""Converts an onnx Xor node into an equivalent Relax expression."""

numpy_op = _np.logical_xor
relax_op = relax.op.logical_xor

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Less(BinaryBase):
"""Converts an onnx Less node into an equivalent Relax expression."""

numpy_op = _np.less
relax_op = relax.op.less

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class LessOrEqual(BinaryBase):
"""Converts an onnx LessEqual node into an equivalent Relax expression."""

numpy_op = _np.less_equal
relax_op = relax.op.less_equal

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Greater(BinaryBase):
"""Converts an onnx Greater node into an equivalent Relax expression."""

numpy_op = _np.greater
relax_op = relax.op.greater

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class GreaterOrEqual(BinaryBase):
"""Converts an onnx GreaterEqual node into an equivalent Relax expression."""

numpy_op = _np.greater_equal
relax_op = relax.op.greater_equal

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Equal(OnnxOpConverter):
"""Converts an onnx Equal node into an equivalent Relax expression."""
Expand All @@ -374,39 +422,77 @@ class BitwiseBase(BinaryBase):
"""Converts an onnx BitwiseBase node into an equivalent Relax expression."""

@classmethod
def base_impl(cls, bb, inputs, attr, params, py_func, relax_op):
def base_impl(cls, bb, inputs, attr, params):
valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
for num, inp in enumerate(inputs):
if inp.struct_info.dtype not in valid_types:
raise ValueError(
f"Bitwise operations expect all inputs to have integer types, "
f"got {inp.struct_info.dtype} for input {num}"
)
return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op)
return super().base_impl(bb, inputs, attr, params)


class BitwiseAnd(BitwiseBase):
"""Converts an onnx BitwiseAnd node into an equivalent Relax expression."""

numpy_op = _np.bitwise_and
relax_op = relax.op.bitwise_and

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and)
return cls.base_impl(bb, inputs, attr, params)


class BitwiseOr(BitwiseBase):
"""Converts an onnx BitwiseOr node into an equivalent Relax expression."""

numpy_op = _np.bitwise_or
relax_op = relax.op.bitwise_or

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or)
return cls.base_impl(bb, inputs, attr, params)


class BitwiseXor(BitwiseBase):
"""Converts an onnx BitwiseXor node into an equivalent Relax expression."""

numpy_op = _np.bitwise_xor
relax_op = relax.op.bitwise_xor

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor)
return cls.base_impl(bb, inputs, attr, params)


class BitwiseNot(BitwiseBase):
"""Converts an onnx BitwiseNot node into an equivalent Relax expression."""

numpy_op = _np.bitwise_not
relax_op = relax.op.bitwise_not

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class BitShift(BitwiseBase):
"""Converts an onnx BitShift node into an equivalent Relax expression."""

@classmethod
def _impl_v11(cls, bb, inputs, attr, params):
direction = attr.get("direction", "LEFT").decode("ascii")
if direction == "LEFT":
cls.numpy_op = _np.left_shift
cls.relax_op = relax.op.left_shift
elif direction == "RIGHT":
cls.numpy_op = _np.right_shift
cls.relax_op = relax.op.right_shift
else:
raise ValueError("Unsupported Shift Direction: " + direction)

return cls.base_impl(bb, inputs, attr, params)


class Sigmoid(OnnxOpConverter):
Expand Down Expand Up @@ -2652,8 +2738,8 @@ def _get_convert_map():
"BitwiseAnd": BitwiseAnd,
"BitwiseOr": BitwiseOr,
"BitwiseXor": BitwiseXor,
# "BitwiseNot": BitwiseNot,
# "BitwiseShift": BitwiseShift,
"BitwiseNot": BitwiseNot,
"BitShift": BitShift,
"And": And,
"Or": Or,
"Xor": Xor,
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
floor_divide,
greater,
greater_equal,
left_shift,
less,
less_equal,
logical_and,
Expand All @@ -62,6 +63,7 @@
multiply,
not_equal,
power,
right_shift,
subtract,
)
from .create import (
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/relax/op/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,35 @@ def bitwise_xor(x1: Expr, x2: Expr) -> Expr:
The computed result.
"""
return _ffi_api.bitwise_xor(x1, x2)


def left_shift(x1: Expr, x2: Expr) -> Expr:
"""Bitwise Shift Left
Parameters
----------
x1 : relax.Expr
The input tensor to be shifted.
x2 : relax.Expr
The number of positions to shift.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.left_shift(x1, x2)


def right_shift(x1: Expr, x2: Expr) -> Expr:
"""Bitwise Shift Right
Parameters
----------
x1 : relax.Expr
The input tensor to be shifted.
x2 : relax.Expr
The number of positions to shift.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.right_shift(x1, x2)
2 changes: 2 additions & 0 deletions python/tvm/relax/transform/legalize_ops/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
register_legalize("relax.bitwise_and", _binary(topi.bitwise_and))
register_legalize("relax.bitwise_or", _binary(topi.bitwise_or))
register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor))
register_legalize("relax.left_shift", _binary(topi.left_shift))
register_legalize("relax.right_shift", _binary(topi.right_shift))

# logical
register_legalize("relax.logical_and", _binary(topi.logical_and))
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
isinf,
isnan,
layout_transform,
left_shift,
less,
less_equal,
linear,
Expand Down Expand Up @@ -133,6 +134,7 @@
quantize,
repeat,
reshape,
right_shift,
round,
rsqrt,
scatter_elements,
Expand Down Expand Up @@ -773,6 +775,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"isinf",
"isnan",
"layout_transform",
"left_shift",
"less",
"less_equal",
"linear",
Expand Down Expand Up @@ -809,6 +812,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"repeat",
"reshape",
"rewriter",
"right_shift",
"tensor_to_shape",
"shape_to_tensor",
"rocm",
Expand Down
2 changes: 2 additions & 0 deletions src/relax/op/distributed/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(logical_xor);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_and);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_or);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_xor);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(left_shift);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(right_shift);

} // namespace distributed
} // namespace relax
Expand Down
2 changes: 2 additions & 0 deletions src/relax/op/tensor/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(left_shift);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(right_shift);

} // namespace relax
} // namespace tvm
6 changes: 6 additions & 0 deletions src/relax/op/tensor/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ Expr bitwise_or(Expr x1, Expr x2);
/*! \brief Broadcasted element-wise bitwise xor */
Expr bitwise_xor(Expr x1, Expr x2);

/*! \brief Broadcasted element-wise bitwise shift left */
Expr left_shift(Expr x1, Expr x2);

/*! \brief Broadcasted element-wise bitwise shift right */
Expr right_shift(Expr x1, Expr x2);

} // namespace relax
} // namespace tvm

Expand Down
Loading

0 comments on commit b632238

Please sign in to comment.