Skip to content

Commit

Permalink
#13758: Update logical xor op and remove composite version
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw authored and Aswinmcw committed Oct 29, 2024
1 parent 7c65c01 commit 679927d
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 43 deletions.
5 changes: 4 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace detail {
// Predicate for binary ops whose operands may be swapped/reassociated by the
// dispatcher (symmetric ops). LOGICAL_XOR is included here because XOR is
// commutative, matching its promotion from a composite op to a primitive
// BinaryOpType in this change.
// NOTE(review): the diff briefly showed both the old and new return lines;
// this is the post-commit form (old line dropped, LOGICAL_XOR appended).
constexpr bool is_associative(BinaryOpType op) {
    return op == BinaryOpType::ADD || op == BinaryOpType::MUL || op == BinaryOpType::EQ || op == BinaryOpType::NE ||
           op == BinaryOpType::LOGICAL_AND || op == BinaryOpType::LOGICAL_OR || op == BinaryOpType::LOGADDEXP ||
           op == BinaryOpType::LOGADDEXP2 || op == BinaryOpType::LOGICAL_XOR;
}

// Tensor - Scalar
Expand Down Expand Up @@ -387,6 +387,7 @@ template struct BinaryOperation<BinaryOpType::MUL>;
template struct InplaceBinaryOperation<BinaryOpType::MUL>;
template struct BinaryOperation<BinaryOpType::LOGICAL_AND>;
template struct BinaryOperation<BinaryOpType::LOGICAL_OR>;
template struct BinaryOperation<BinaryOpType::LOGICAL_XOR>;
template struct BinaryOperation<BinaryOpType::LDEXP>;
template struct BinaryOperation<BinaryOpType::LOGADDEXP>;
template struct BinaryOperation<BinaryOpType::LOGADDEXP2>;
Expand All @@ -410,5 +411,7 @@ template struct InplaceRelationalBinary<BinaryOpType::NE>;

template struct InplaceLogicalBinary<BinaryOpType::LOGICAL_AND>;
template struct InplaceLogicalBinary<BinaryOpType::LOGICAL_OR>;
template struct InplaceLogicalBinary<BinaryOpType::LOGICAL_XOR>;


} // namespace ttnn::operations::binary
6 changes: 6 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ constexpr auto logical_and = ttnn::register_operation_with_auto_launch_op<
constexpr auto logical_or = ttnn::register_operation_with_auto_launch_op<
"ttnn::logical_or",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::LOGICAL_OR>>();
// ttnn::logical_xor — elementwise logical XOR of two tensors, registered on the
// primitive BinaryOperation path (this commit removes the former composite
// registration that used ExecuteBinaryCompositeOps<LOGICAL_XOR>).
constexpr auto logical_xor = ttnn::register_operation_with_auto_launch_op<
"ttnn::logical_xor",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::LOGICAL_XOR>>();
constexpr auto ldexp = ttnn::register_operation_with_auto_launch_op<
"ttnn::ldexp",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::LDEXP>>();
Expand Down Expand Up @@ -220,6 +223,9 @@ constexpr auto logical_and_ = ttnn::register_operation_with_auto_launch_op<
constexpr auto logical_or_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::logical_or_",
operations::binary::InplaceLogicalBinary<operations::binary::BinaryOpType::LOGICAL_OR>>();
// ttnn::logical_xor_ — in-place variant of logical XOR, registered through
// InplaceLogicalBinary like logical_and_/logical_or_ (replaces the former
// composite LOGICAL_XOR_ registration removed in this commit).
constexpr auto logical_xor_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::logical_xor_",
operations::binary::InplaceLogicalBinary<operations::binary::BinaryOpType::LOGICAL_XOR>>();
constexpr auto eq_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::eq_",
operations::binary::InplaceRelationalBinary<operations::binary::BinaryOpType::EQ>>();
Expand Down
6 changes: 0 additions & 6 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,6 @@ constexpr auto maximum = ttnn::register_operation_with_auto_launch_op<
constexpr auto atan2 = ttnn::register_operation_with_auto_launch_op<
"ttnn::atan2",
operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::ATAN2>>();
constexpr auto logical_xor = ttnn::register_operation_with_auto_launch_op<
"ttnn::logical_xor",
operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::LOGICAL_XOR>>();
constexpr auto nextafter = ttnn::register_operation_with_auto_launch_op<
"ttnn::nextafter",
operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::NEXTAFTER>>();
Expand Down Expand Up @@ -263,9 +260,6 @@ constexpr auto div_no_nan = ttnn::register_operation_with_auto_launch_op<
constexpr auto floor_div = ttnn::register_operation_with_auto_launch_op<
"ttnn::floor_div",
operations::binary::ExecuteDivLikeOps<operations::binary::BinaryCompositeOpType::FLOOR_DIV>>();
constexpr auto logical_xor_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::logical_xor_",
operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::LOGICAL_XOR_>>();
constexpr auto bias_gelu = ttnn::register_operation_with_auto_launch_op<
"ttnn::bias_gelu",
operations::binary::ExecuteBiasGelu<operations::binary::BinaryOpType::BIAS_GELU>>();
Expand Down
7 changes: 3 additions & 4 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ void py_module(py::module& module) {
R"doc(\mathrm{output\_tensor}_i = \arctan\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)
)doc");

detail::bind_binary_composite(
detail::bind_binary_operation(
module,
ttnn::logical_xor,
R"doc(Compute logical_xor :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
Expand Down Expand Up @@ -1006,12 +1006,11 @@ void py_module(py::module& module) {
+----------------------------+---------------------------------+-------------------+
)doc");

detail::bind_binary_composite(
detail::bind_logical_inplace_operation(
module,
ttnn::logical_xor_,
R"doc(Compute inplace logical XOR of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc((\mathrm{input\_tensor\_a}_i \land \lnot \mathrm{input\_tensor\_b}_i) \lor (\lnot \mathrm{input\_tensor\_a}_i \land \mathrm{input\_tensor\_b}_i)
)doc",

R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ enum class BinaryOpType {
LOGADDEXP,
LOGICAL_AND,
LOGICAL_OR,
LOGICAL_XOR,
LDEXP,
LOGADDEXP2,
DIV_FAST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ std::map<std::string, std::string> get_defines(
op_binary_type = "EltwiseBinaryType::ELWADD";
defines.merge(get_defines(UnaryOpType::GTZ, std::nullopt, "0", idst));
break;
case BinaryOpType::LOGICAL_XOR:
defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "PRE_IN0_0"));
defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "PRE_IN1_0"));
op_name = "sub_tiles";
op_binary_type = "EltwiseBinaryType::ELWSUB";
defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "0", idst));
break;
case BinaryOpType::LDEXP:
defines.merge(get_defines(UnaryOpType::EXP2, std::nullopt, "PRE_IN1_0"));
op_name = "mul_tiles";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional<
return res;
}

Tensor _logical_xor(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
Tensor in_a_eq_zero = ttnn::eqz(input_a, output_mem_config);
Tensor in_b_eq_zero = ttnn::eqz(input_b, output_mem_config);
Tensor in_b_neq_zero = ttnn::nez(input_b, output_mem_config);
Tensor result = ttnn::where(in_a_eq_zero, in_b_neq_zero, in_b_eq_zero);
return result;
}
// Tensor _logical_xor(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
// Tensor in_a_eq_zero = ttnn::eqz(input_a, output_mem_config);
// Tensor in_b_eq_zero = ttnn::eqz(input_b, output_mem_config);
// Tensor in_b_neq_zero = ttnn::nez(input_b, output_mem_config);
// Tensor result = ttnn::where(in_a_eq_zero, in_b_neq_zero, in_b_eq_zero);
// return result;
//}

Tensor ExecuteDiv::invoke(uint8_t queue_id, const Tensor& input, float value, bool accurate_mode, const std::string& round_mode, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> output_tensor) {
TT_FATAL((round_mode == "None" || round_mode == "trunc" || round_mode == "floor"), "Incorrect rounding mode (expected 'None', 'trunc', or 'floor')");
Expand Down Expand Up @@ -324,13 +324,13 @@ Tensor _floor_div(const Tensor& input_a, const Tensor& input_b, const std::optio
result);
}

Tensor _logical_xor_(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
Tensor in_a_eq_zero = ttnn::eqz(input_a, output_mem_config, input_a );
Tensor in_b_eq_zero = ttnn::nez(input_b, output_mem_config, input_b );
in_b_eq_zero = ttnn::eqz(input_b, output_mem_config);
Tensor result = ttnn::where(input_a, input_b, in_b_eq_zero, output_mem_config, input_a);
return result;
}
// Tensor _logical_xor_(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
// Tensor in_a_eq_zero = ttnn::eqz(input_a, output_mem_config, input_a );
// Tensor in_b_eq_zero = ttnn::nez(input_b, output_mem_config, input_b );
// in_b_eq_zero = ttnn::eqz(input_b, output_mem_config);
// Tensor result = ttnn::where(input_a, input_b, in_b_eq_zero, output_mem_config, input_a);
// return result;
// }

Tensor _scatter(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
tt::tt_metal::Array4D start_index = {0, 0, 0, 0};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ enum class BinaryCompositeOpType {
MINIMUM,
MAXIMUM,
ATAN2,
LOGICAL_XOR,
DIV_NO_NAN,
FLOOR_DIV,
LOGICAL_XOR_,
SCATTER,
OUTER,
POLYVAL,
Expand All @@ -37,7 +35,6 @@ Tensor _xlogy(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _minimum(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _maximum(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _atan2(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _logical_xor(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _nextafter(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _addalpha(const Tensor&, const Tensor&, float, const std::optional<MemoryConfig>&);
Tensor _subalpha(const Tensor&, const Tensor&, float, const std::optional<MemoryConfig>&);
Expand All @@ -46,7 +43,6 @@ Tensor _div_no_nan(const Tensor&, const Tensor&, const std::optional<MemoryConfi
Tensor _div_no_nan_overload(const Tensor&, float, const std::optional<MemoryConfig>&);
Tensor _floor_div(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _floor_div_overload(const Tensor&, float, const std::optional<MemoryConfig>&);
Tensor _logical_xor_(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _scatter(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _outer(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _polyval(const Tensor&, const std::vector<float>&, const std::optional<MemoryConfig>&);
Expand Down Expand Up @@ -98,20 +94,6 @@ struct OpHandler<BinaryCompositeOpType::ATAN2> {
}
};

template <>
struct OpHandler<BinaryCompositeOpType::LOGICAL_XOR> {
static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional<MemoryConfig>& mem_cfg) {
return _logical_xor(t1, t2, mem_cfg);
}
};

template <>
struct OpHandler<BinaryCompositeOpType::LOGICAL_XOR_> {
static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional<MemoryConfig>& mem_cfg) {
return _logical_xor_(t1, t2, mem_cfg);
}
};

template <>
struct OpHandler<BinaryCompositeOpType::ADDALPHA> {
static Tensor handle(const Tensor& t1, const Tensor& t2, float alpha, const std::optional<MemoryConfig>& mem_cfg) {
Expand Down

0 comments on commit 679927d

Please sign in to comment.