diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp index b9956a45fadd..8362d4717ca9 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp @@ -17,7 +17,7 @@ namespace detail { constexpr bool is_associative(BinaryOpType op) { return op == BinaryOpType::ADD || op == BinaryOpType::MUL || op == BinaryOpType::EQ || op == BinaryOpType::NE || op == BinaryOpType::LOGICAL_AND || op == BinaryOpType::LOGICAL_OR || op == BinaryOpType::LOGADDEXP || - op == BinaryOpType::LOGADDEXP2; + op == BinaryOpType::LOGADDEXP2 || op == BinaryOpType::LOGICAL_XOR; } // Tensor - Scalar @@ -387,6 +387,7 @@ template struct BinaryOperation; template struct InplaceBinaryOperation; template struct BinaryOperation; template struct BinaryOperation; +template struct BinaryOperation; template struct BinaryOperation; template struct BinaryOperation; template struct BinaryOperation; @@ -410,5 +411,7 @@ template struct InplaceRelationalBinary; template struct InplaceLogicalBinary; template struct InplaceLogicalBinary; +template struct InplaceLogicalBinary; + } // namespace ttnn::operations::binary diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp index 90d1babf8278..5287ac6829d8 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp @@ -187,6 +187,9 @@ constexpr auto logical_and = ttnn::register_operation_with_auto_launch_op< constexpr auto logical_or = ttnn::register_operation_with_auto_launch_op< "ttnn::logical_or", operations::binary::BinaryOperation>(); +constexpr auto logical_xor = ttnn::register_operation_with_auto_launch_op< + "ttnn::logical_xor", + operations::binary::BinaryOperation>(); constexpr auto ldexp = ttnn::register_operation_with_auto_launch_op< "ttnn::ldexp", operations::binary::BinaryOperation>(); @@ -220,6 +223,9 @@ 
constexpr auto logical_and_ = ttnn::register_operation_with_auto_launch_op< constexpr auto logical_or_ = ttnn::register_operation_with_auto_launch_op< "ttnn::logical_or_", operations::binary::InplaceLogicalBinary>(); +constexpr auto logical_xor_ = ttnn::register_operation_with_auto_launch_op< + "ttnn::logical_xor_", + operations::binary::InplaceLogicalBinary>(); constexpr auto eq_ = ttnn::register_operation_with_auto_launch_op< "ttnn::eq_", operations::binary::InplaceRelationalBinary>(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp index 9ff84b3a2321..6da21e63a87b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp @@ -233,9 +233,6 @@ constexpr auto maximum = ttnn::register_operation_with_auto_launch_op< constexpr auto atan2 = ttnn::register_operation_with_auto_launch_op< "ttnn::atan2", operations::binary::ExecuteBinaryCompositeOps>(); -constexpr auto logical_xor = ttnn::register_operation_with_auto_launch_op< - "ttnn::logical_xor", - operations::binary::ExecuteBinaryCompositeOps>(); constexpr auto nextafter = ttnn::register_operation_with_auto_launch_op< "ttnn::nextafter", operations::binary::ExecuteBinaryCompositeOps>(); @@ -263,9 +260,6 @@ constexpr auto div_no_nan = ttnn::register_operation_with_auto_launch_op< constexpr auto floor_div = ttnn::register_operation_with_auto_launch_op< "ttnn::floor_div", operations::binary::ExecuteDivLikeOps>(); -constexpr auto logical_xor_ = ttnn::register_operation_with_auto_launch_op< - "ttnn::logical_xor_", - operations::binary::ExecuteBinaryCompositeOps>(); constexpr auto bias_gelu = ttnn::register_operation_with_auto_launch_op< "ttnn::bias_gelu", operations::binary::ExecuteBiasGelu>(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp index ed54ef6ab593..8c6bb9291a67 
100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp @@ -976,7 +976,7 @@ void py_module(py::module& module) { R"doc(\mathrm{output\_tensor}_i = \arctan\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right) )doc"); - detail::bind_binary_composite( + detail::bind_binary_operation( module, ttnn::logical_xor, R"doc(Compute logical_xor :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc", @@ -1006,12 +1006,11 @@ void py_module(py::module& module) { +----------------------------+---------------------------------+-------------------+ )doc"); - detail::bind_binary_composite( + detail::bind_logical_inplace_operation( module, ttnn::logical_xor_, R"doc(Compute inplace logical XOR of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc", - R"doc((\mathrm{input\_tensor\_a}_i \land \lnot \mathrm{input\_tensor\_b}_i) \lor (\lnot \mathrm{input\_tensor\_a}_i \land \mathrm{input\_tensor\_b}_i) - )doc", + R"doc(Supported dtypes, layouts, and ranks: +----------------------------+---------------------------------+-------------------+ diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_types.hpp index 125eda1197c1..db1a36836aed 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_types.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_types.hpp @@ -21,6 +21,7 @@ enum class BinaryOpType { LOGADDEXP, LOGICAL_AND, LOGICAL_OR, + LOGICAL_XOR, LDEXP, LOGADDEXP2, DIV_FAST diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp index 71cec80b854b..512bd4717b5d 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp +++ 
b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp @@ -94,6 +94,13 @@ std::map get_defines( op_binary_type = "EltwiseBinaryType::ELWADD"; defines.merge(get_defines(UnaryOpType::GTZ, std::nullopt, "0", idst)); break; + case BinaryOpType::LOGICAL_XOR: + defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "PRE_IN0_0")); + defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "PRE_IN1_0")); + op_name = "sub_tiles"; + op_binary_type = "EltwiseBinaryType::ELWSUB"; + defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "0", idst)); + break; case BinaryOpType::LDEXP: defines.merge(get_defines(UnaryOpType::EXP2, std::nullopt, "PRE_IN1_0")); op_name = "mul_tiles"; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index 96936013a3b1..c0a516b34cb6 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -151,13 +151,6 @@ Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional< return res; } -Tensor _logical_xor(const Tensor& input_a, const Tensor& input_b, const std::optional& output_mem_config) { - Tensor in_a_eq_zero = ttnn::eqz(input_a, output_mem_config); - Tensor in_b_eq_zero = ttnn::eqz(input_b, output_mem_config); - Tensor in_b_neq_zero = ttnn::nez(input_b, output_mem_config); - Tensor result = ttnn::where(in_a_eq_zero, in_b_neq_zero, in_b_eq_zero); - return result; -} Tensor ExecuteDiv::invoke(uint8_t 
queue_id, const Tensor& input, float value, bool accurate_mode, const std::string& round_mode, const std::optional& output_mem_config, std::optional output_tensor) { TT_FATAL((round_mode == "None" || round_mode == "trunc" || round_mode == "floor"), "Incorrect rounding mode (expected 'None', 'trunc', or 'floor')"); @@ -324,13 +324,6 @@ Tensor _floor_div(const Tensor& input_a, const Tensor& input_b, const std::optio result); } -Tensor _logical_xor_(const Tensor& input_a, const Tensor& input_b, const std::optional& output_mem_config) { - Tensor in_a_eq_zero = ttnn::eqz(input_a, output_mem_config, input_a ); - Tensor in_b_eq_zero = ttnn::nez(input_b, output_mem_config, input_b ); - in_b_eq_zero = ttnn::eqz(input_b, output_mem_config); - Tensor result = ttnn::where(input_a, input_b, in_b_eq_zero, output_mem_config, input_a); - return result; -} Tensor _scatter(const Tensor& input_a, const Tensor& input_b, const std::optional& output_mem_config) { tt::tt_metal::Array4D start_index = {0, 0, 0, 0}; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp index 0c0b7f07314c..338be5c1a9ed 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp @@ -23,10 +23,8 @@ enum class BinaryCompositeOpType { MINIMUM, MAXIMUM, ATAN2, - LOGICAL_XOR, DIV_NO_NAN, FLOOR_DIV, - LOGICAL_XOR_, SCATTER, OUTER, POLYVAL, @@ -37,7 +35,6 @@ Tensor _xlogy(const 
Tensor&, const Tensor&, const std::optional&); Tensor _minimum(const Tensor&, const Tensor&, const std::optional&); Tensor _maximum(const Tensor&, const Tensor&, const std::optional&); Tensor _atan2(const Tensor&, const Tensor&, const std::optional&); -Tensor _logical_xor(const Tensor&, const Tensor&, const std::optional&); Tensor _nextafter(const Tensor&, const Tensor&, const std::optional&); Tensor _addalpha(const Tensor&, const Tensor&, float, const std::optional&); Tensor _subalpha(const Tensor&, const Tensor&, float, const std::optional&); @@ -46,7 +43,6 @@ Tensor _div_no_nan(const Tensor&, const Tensor&, const std::optional&); Tensor _floor_div(const Tensor&, const Tensor&, const std::optional&); Tensor _floor_div_overload(const Tensor&, float, const std::optional&); -Tensor _logical_xor_(const Tensor&, const Tensor&, const std::optional&); Tensor _scatter(const Tensor&, const Tensor&, const std::optional&); Tensor _outer(const Tensor&, const Tensor&, const std::optional&); Tensor _polyval(const Tensor&, const std::vector&, const std::optional&); @@ -98,20 +94,6 @@ struct OpHandler { } }; -template <> -struct OpHandler { - static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional& mem_cfg) { - return _logical_xor(t1, t2, mem_cfg); - } -}; - -template <> -struct OpHandler { - static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional& mem_cfg) { - return _logical_xor_(t1, t2, mem_cfg); - } -}; - template <> struct OpHandler { static Tensor handle(const Tensor& t1, const Tensor& t2, float alpha, const std::optional& mem_cfg) {