Skip to content

Commit

Permalink
#10034: Binary shift operators
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW committed Dec 16, 2024
1 parent 388e187 commit d1e39e2
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 85 deletions.
46 changes: 45 additions & 1 deletion tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_mul_fp32(device, ttnn_function):
@pytest.mark.parametrize(
"ttnn_function",
[
ttnn.div,
ttnn.divide,
],
)
# Torch: num/0 = inf and 0/0 = nan; TT: num/0 = inf and 0/0 = nan (in fp32 tile)
Expand Down Expand Up @@ -551,3 +551,47 @@ def test_bitwise_xor(device, ttnn_function):

status = ttnn.pearson_correlation_coefficient(z_torch, tt_out) >= 0.999
assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "ttnn_function",
    [
        ttnn.bitwise_left_shift,
    ],
)
def test_bitwise_left_shift(device, ttnn_function):
    """Compare the ttnn binary bitwise_left_shift op against its torch golden on int32 inputs.

    Shift amounts cover the full int32 range boundaries [0, 31].
    """
    x_torch = torch.tensor([[99, 3, 100, 1, 72, 0]], dtype=torch.int32)
    y_torch = torch.tensor([[1, 2, 31, 4, 5, 0]], dtype=torch.int32)
    golden_fn = ttnn.get_golden_function(ttnn_function)
    z_torch = golden_fn(x_torch, y_torch)
    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
    y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
    # Call the parametrized function (not a hard-coded op) so the parametrize
    # list actually selects what is under test; the previous version also built
    # an unused device copy of the golden output (z_tt), now removed.
    z_tt_out = ttnn_function(x_tt, y_tt)
    tt_out = ttnn.to_torch(z_tt_out)

    status = ttnn.pearson_correlation_coefficient(z_torch, tt_out) >= 0.999
    assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "ttnn_function",
    [
        ttnn.bitwise_right_shift,
    ],
)
def test_bitwise_right_shift(device, ttnn_function):
    """Compare the ttnn binary bitwise_right_shift op against its torch golden on int32 inputs.

    Shift amounts cover the full int32 range boundaries [0, 31].
    """
    x_torch = torch.tensor([[19, 3, 101, 21, 47, 0]], dtype=torch.int32)
    y_torch = torch.tensor([[5, 2, 31, 4, 5, 0]], dtype=torch.int32)
    golden_fn = ttnn.get_golden_function(ttnn_function)
    z_torch = golden_fn(x_torch, y_torch)
    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
    y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
    # Call the parametrized function (not a hard-coded op) so the parametrize
    # list actually selects what is under test; the previous version also built
    # an unused device copy of the golden output (z_tt), now removed.
    z_tt_out = ttnn_function(x_tt, y_tt)
    tt_out = ttnn.to_torch(z_tt_out)

    status = ttnn.pearson_correlation_coefficient(z_torch, tt_out) >= 0.999
    assert status
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,5 +478,7 @@ template struct BinaryOperationSfpu<BinaryOpType::POWER>;
template struct BinaryOperationSfpu<BinaryOpType::BITWISE_AND>;
template struct BinaryOperationSfpu<BinaryOpType::BITWISE_XOR>;
template struct BinaryOperationSfpu<BinaryOpType::BITWISE_OR>;
template struct BinaryOperationSfpu<BinaryOpType::LEFT_SHIFT>;
template struct BinaryOperationSfpu<BinaryOpType::RIGHT_SHIFT>;

} // namespace ttnn::operations::binary
60 changes: 60 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,62 @@ struct ExecuteBitwiseXor {
const std::optional<Tensor>& optional_output_tensor = std::nullopt);
};

// Elementwise bitwise left shift (a << b) for INT32 tensors.
// Tensor-tensor overloads dispatch to the binary SFPU LEFT_SHIFT op;
// tensor-scalar overloads reuse the unary LEFT_SHIFT op with an integer
// shift parameter (see binary_composite definitions).
// NOTE(review): shift amounts appear to be expected within [0, 31] —
// confirm kernel behavior for out-of-range shifts.
struct ExecuteBitwiseLeftShift {
    // Tensor-tensor shift on an explicit command queue.
    static Tensor invoke(
        uint8_t queue_id,
        const Tensor& input_tensor_a_arg,
        const Tensor& input_tensor_b_arg,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        const std::optional<Tensor>& optional_output_tensor = std::nullopt);

    // Tensor-tensor shift on the default queue.
    static Tensor invoke(
        const Tensor& input_tensor_a_arg,
        const Tensor& input_tensor_b_arg,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        const std::optional<Tensor>& optional_output_tensor = std::nullopt);

    // Shift every element by the scalar input_b, on an explicit command queue.
    static Tensor invoke(
        uint8_t queue_id,
        const Tensor& input_tensor,
        int32_t input_b,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        const std::optional<Tensor>& optional_output_tensor = std::nullopt);

    // Shift every element by the scalar input_b, on the default queue.
    static Tensor invoke(
        const Tensor& input_tensor,
        int32_t input_b,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        const std::optional<Tensor>& optional_output_tensor = std::nullopt);
};

// Elementwise bitwise right shift (a >> b) for INT32 tensors.
// Tensor-tensor overloads dispatch to the binary SFPU RIGHT_SHIFT op;
// tensor-scalar overloads reuse the unary RIGHT_SHIFT op with an integer
// shift parameter (see binary_composite definitions).
// NOTE(review): shift amounts appear to be expected within [0, 31] —
// confirm kernel behavior for out-of-range shifts.
struct ExecuteBitwiseRightShift {
    // Tensor-tensor shift on an explicit command queue.
    static Tensor invoke(
        uint8_t queue_id,
        const Tensor& input_tensor_a_arg,
        const Tensor& input_tensor_b_arg,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        const std::optional<Tensor>& optional_output_tensor = std::nullopt);

    // Tensor-tensor shift on the default queue.
    static Tensor invoke(
        const Tensor& input_tensor_a_arg,
        const Tensor& input_tensor_b_arg,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        const std::optional<Tensor>& optional_output_tensor = std::nullopt);

    // Shift every element by the scalar input_b, on an explicit command queue.
    static Tensor invoke(
        uint8_t queue_id,
        const Tensor& input_tensor,
        int32_t input_b,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        const std::optional<Tensor>& optional_output_tensor = std::nullopt);

    // Shift every element by the scalar input_b, on the default queue.
    static Tensor invoke(
        const Tensor& input_tensor,
        int32_t input_b,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        const std::optional<Tensor>& optional_output_tensor = std::nullopt);
};

} // namespace binary
} // namespace operations

Expand Down Expand Up @@ -499,6 +555,10 @@ constexpr auto rsub = ttnn::register_operation_with_auto_launch_op<"ttnn::rsub",
constexpr auto bitwise_and = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_and", operations::binary::ExecuteBitwiseAnd>();
constexpr auto bitwise_or = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_or", operations::binary::ExecuteBitwiseOr>();
constexpr auto bitwise_xor = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_xor", operations::binary::ExecuteBitwiseXor>();
constexpr auto bitwise_left_shift = ttnn::
register_operation_with_auto_launch_op<"ttnn::bitwise_left_shift", operations::binary::ExecuteBitwiseLeftShift>();
constexpr auto bitwise_right_shift = ttnn::
register_operation_with_auto_launch_op<"ttnn::bitwise_right_shift", operations::binary::ExecuteBitwiseRightShift>();
constexpr auto pow = ttnn::register_operation_with_auto_launch_op<"ttnn::pow", operations::binary::ExecutePower>();

} // namespace ttnn
16 changes: 16 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1572,6 +1572,22 @@ void py_module(py::module& module) {
". ",
R"doc(INT32)doc");

// Bind ttnn.bitwise_left_shift. The math docstring previously said
// \verb|bitwise_and| (copy-paste from the bitwise_and binding) — corrected,
// and the shift range is documented as inclusive [0, 31] to match the tests.
detail::bind_bitwise_binary_ops_operation(
    module,
    ttnn::bitwise_left_shift,
    R"doc(Perform bitwise_left_shift operation on :attr:`input_tensor_a` by :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`. :attr:`input_tensor_b` has shift_bits which are integers within range [0, 31])doc",
    R"doc(\mathrm{{output\_tensor}}_i = \verb|bitwise_left_shift|(\mathrm{{input\_tensor\_a, input\_tensor\_b}}))doc",
    ". ",
    R"doc(INT32)doc");

// Bind ttnn.bitwise_right_shift. The math docstring previously said
// \verb|bitwise_and| (copy-paste from the bitwise_and binding) — corrected,
// and the shift range is documented as inclusive [0, 31] to match the tests.
detail::bind_bitwise_binary_ops_operation(
    module,
    ttnn::bitwise_right_shift,
    R"doc(Perform bitwise_right_shift operation on :attr:`input_tensor_a` by :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`. :attr:`input_tensor_b` has shift_bits which are integers within range [0, 31])doc",
    R"doc(\mathrm{{output\_tensor}}_i = \verb|bitwise_right_shift|(\mathrm{{input\_tensor\_a, input\_tensor\_b}}))doc",
    ". ",
    R"doc(INT32)doc");

auto prim_module = module.def_submodule("prim", "Primitive binary operations");

detail::bind_primitive_binary_operation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ enum class BinaryOpType {
POWER,
BITWISE_XOR,
BITWISE_AND,
BITWISE_OR
BITWISE_OR,
LEFT_SHIFT,
RIGHT_SHIFT
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ std::map<std::string, std::string> get_defines_fp32(
new_defines.insert({"BITWISE_INIT", fmt::format("binary_bitwise_tile_init();")});
op_name = "xor_binary_tile";
break;
case BinaryOpType::LEFT_SHIFT:
new_defines.insert({"SHIFT_INIT", fmt::format("binary_shift_tile_init();")});
op_name = "binary_left_shift_tile";
break;
case BinaryOpType::RIGHT_SHIFT:
new_defines.insert({"SHIFT_INIT", fmt::format("binary_shift_tile_init();")});
op_name = "binary_right_shift_tile";
break;
case BinaryOpType::LOGADDEXP:
// PRE_IN0_0 ===> Applies prescaling for first input
// PRE_IN1_0 ====> Applies prescaling for second input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -900,4 +900,84 @@ Tensor ExecuteBitwiseXor::invoke(
std::move(optional_output_tensor));
}

// Bitwise Left Shift
// Tensor-tensor elementwise a << b, executed via the binary SFPU op.
Tensor ExecuteBitwiseLeftShift::invoke(
    uint8_t queue_id,
    const Tensor& lhs,
    const Tensor& rhs,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<Tensor>& optional_output_tensor) {
    using ShiftOp = BinaryOperationSfpu<operations::binary::BinaryOpType::LEFT_SHIFT>;
    // No output dtype override (std::nullopt) — keep the op's default.
    return ShiftOp::invoke(queue_id, lhs, rhs, std::nullopt, memory_config, optional_output_tensor);
}

// Tensor-tensor convenience overload: runs on the default command queue.
Tensor ExecuteBitwiseLeftShift::invoke(
    const Tensor& lhs,
    const Tensor& rhs,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<Tensor>& optional_output_tensor) {
    return invoke(ttnn::DefaultQueueId, lhs, rhs, memory_config, optional_output_tensor);
}

// Tensor-scalar form: shift every element left by shift_bits, reusing the
// unary LEFT_SHIFT op with an integer parameter.
Tensor ExecuteBitwiseLeftShift::invoke(
    uint8_t queue_id,
    const Tensor& input,
    const int32_t shift_bits,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<Tensor>& optional_output_tensor) {
    using UnaryShift = ttnn::operations::unary::
        ExecuteUnaryWithIntegerParameter<ttnn::operations::unary::UnaryOpType::LEFT_SHIFT, int32_t>;
    return UnaryShift::invoke(queue_id, input, shift_bits, memory_config, optional_output_tensor);
}

// Tensor-scalar convenience overload: runs on the default command queue.
// NOTE: optional_output_tensor is a const reference, so the previous
// std::move on it was a no-op copy (clang-tidy performance-move-const-arg);
// pass it through directly, matching the tensor-tensor overload.
Tensor ExecuteBitwiseLeftShift::invoke(
    const Tensor& input_tensor_a,
    const int32_t input_b,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<Tensor>& optional_output_tensor) {
    return ExecuteBitwiseLeftShift::invoke(
        ttnn::DefaultQueueId, input_tensor_a, input_b, memory_config, optional_output_tensor);
}

// Bitwise Right Shift
// Tensor-tensor elementwise a >> b, executed via the binary SFPU op.
Tensor ExecuteBitwiseRightShift::invoke(
    uint8_t queue_id,
    const Tensor& lhs,
    const Tensor& rhs,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<Tensor>& optional_output_tensor) {
    using ShiftOp = BinaryOperationSfpu<operations::binary::BinaryOpType::RIGHT_SHIFT>;
    // No output dtype override (std::nullopt) — keep the op's default.
    return ShiftOp::invoke(queue_id, lhs, rhs, std::nullopt, memory_config, optional_output_tensor);
}

// Tensor-tensor convenience overload: runs on the default command queue.
Tensor ExecuteBitwiseRightShift::invoke(
    const Tensor& lhs,
    const Tensor& rhs,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<Tensor>& optional_output_tensor) {
    return invoke(ttnn::DefaultQueueId, lhs, rhs, memory_config, optional_output_tensor);
}

// Tensor-scalar form: shift every element right by shift_bits, reusing the
// unary RIGHT_SHIFT op with an integer parameter.
Tensor ExecuteBitwiseRightShift::invoke(
    uint8_t queue_id,
    const Tensor& input,
    const int32_t shift_bits,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<Tensor>& optional_output_tensor) {
    using UnaryShift = ttnn::operations::unary::
        ExecuteUnaryWithIntegerParameter<ttnn::operations::unary::UnaryOpType::RIGHT_SHIFT, int32_t>;
    return UnaryShift::invoke(queue_id, input, shift_bits, memory_config, optional_output_tensor);
}

// Tensor-scalar convenience overload: runs on the default command queue.
// NOTE: optional_output_tensor is a const reference, so the previous
// std::move on it was a no-op copy (clang-tidy performance-move-const-arg);
// pass it through directly, matching the tensor-tensor overload.
Tensor ExecuteBitwiseRightShift::invoke(
    const Tensor& input_tensor_a,
    const int32_t input_b,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<Tensor>& optional_output_tensor) {
    return ExecuteBitwiseRightShift::invoke(
        ttnn::DefaultQueueId, input_tensor_a, input_b, memory_config, optional_output_tensor);
}

} // namespace ttnn::operations::binary
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ namespace utils {
case BinaryOpType::LTE:
case BinaryOpType::EQ:
case BinaryOpType::NE: return (a == DataType::FLOAT32 && b == DataType::FLOAT32);
case BinaryOpType::LEFT_SHIFT:
case BinaryOpType::RIGHT_SHIFT:
case BinaryOpType::BITWISE_XOR:
case BinaryOpType::BITWISE_AND:
case BinaryOpType::BITWISE_OR: return (a == DataType::INT32 && b == DataType::INT32);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
#include "compute_kernel_api/eltwise_binary_sfpu.h"
#include "compute_kernel_api/binary_bitwise_sfpu.h"
#include "compute_kernel_api/binary_shift.h"
#include "compute_kernel_api/add_int32_sfpu.h"

#define PRE_SCALE defined SFPU_OP_INIT_PRE_IN0_0 || defined SFPU_OP_INIT_PRE_IN1_0

#if defined(ADD_INT32_INIT) || defined(BITWISE_INIT)
#if defined(ADD_INT32_INIT) || defined(BITWISE_INIT) || defined(SHIFT_INIT)
#define INT32_INIT
#endif

Expand Down Expand Up @@ -120,6 +121,9 @@ void MAIN {
#ifdef BITWISE_INIT
BITWISE_INIT
#endif
#ifdef SHIFT_INIT
SHIFT_INIT
#endif

#ifdef BINARY_SFPU_OP
BINARY_SFPU_OP
Expand Down
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,6 @@ REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(ne_unary, UNARY_NE);

// Unaries with integer parameter
REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(power, POWER, uint32_t);
REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_left_shift, LEFT_SHIFT, int32_t);
REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_right_shift, RIGHT_SHIFT, int32_t);

// Other unaries
constexpr auto dropout =
Expand Down
Loading

0 comments on commit d1e39e2

Please sign in to comment.