Skip to content

Commit

Permalink
#13864: Add forward support for maximum op with tensor, scalar variant
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw committed Oct 30, 2024
1 parent e05135b commit 3fb2ad2
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 21 deletions.
23 changes: 23 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,29 @@ def test_binary_maximum_ttnn(input_shapes, device):
assert comp_pass


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    # Fixed scalar values instead of an unordered, unseeded random set:
    # a set comprehension over random.randint gives nondeterministic
    # parametrize IDs and inputs on every collection, which breaks
    # pytest-xdist (workers must agree on the test list) and makes
    # failures unreproducible. Values span negative, near-zero, and
    # positive non-integers, matching the old -100..100 (+0.5) range.
    "scalar",
    [-100.5, -54.5, -0.5, 0.5, 42.5, 99.5],
)
def test_binary_maximum_scalar_ttnn(input_shapes, scalar, device):
    """Forward test for ttnn.maximum with the (tensor, scalar) overload.

    Compares the on-device result against the host golden function
    registered for ttnn.maximum, using PCC as the pass criterion.
    """
    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)

    # Device result for the tensor-scalar variant.
    output_tensor = ttnn.maximum(input_tensor1, scalar)
    # Host golden: the scalar is wrapped in a tensor so the golden
    # function sees the same two-operand signature as the tensor case.
    golden_function = ttnn.get_golden_function(ttnn.maximum)
    golden_tensor = golden_function(in_data1, torch.tensor(scalar))

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down
16 changes: 15 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,20 @@ struct ExecuteGCD {
const std::optional<MemoryConfig>& memory_config = std::nullopt);
};

// Elementwise maximum, registered as ttnn::maximum. Two overloads:
// tensor-tensor and tensor-scalar. Brace placement matches the sibling
// op structs in this header (e.g. ExecuteGCD).
struct ExecuteMaximum {
    // max(input_tensor_a, input_tensor_b), elementwise.
    static Tensor invoke(
        const Tensor& input_tensor_a,
        const Tensor& input_tensor_b,
        const std::optional<MemoryConfig>& memory_config = std::nullopt);

    // max(input_tensor, scalar), elementwise against a float scalar.
    static Tensor invoke(
        const Tensor& input_tensor,
        float scalar,
        const std::optional<MemoryConfig>& memory_config = std::nullopt);
};

} // namespace binary
} // namespace operations

Expand All @@ -229,7 +243,7 @@ constexpr auto minimum = ttnn::register_operation_with_auto_launch_op<
operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::MINIMUM>>();
constexpr auto maximum = ttnn::register_operation_with_auto_launch_op<
"ttnn::maximum",
operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::MAXIMUM>>();
operations::binary::ExecuteMaximum>();
constexpr auto atan2 = ttnn::register_operation_with_auto_launch_op<
"ttnn::atan2",
operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::ATAN2>>();
Expand Down
18 changes: 8 additions & 10 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ void bind_binary_composite_with_rtol_atol(py::module& module, const binary_opera
}

template <typename binary_operation_t>
void bind_div_like_ops(py::module& module, const binary_operation_t& operation, const std::string& description) {
void bind_binary_composite_overload(py::module& module, const binary_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
R"doc(
{2}
Expand Down Expand Up @@ -963,13 +963,6 @@ void py_module(py::module& module) {
R"doc(\mathrm{output\_tensor}_i = \text{min}\left(\mathrm{input\_tensor\_a}_i , \mathrm{input\_tensor\_b}_i\right)
)doc");

detail::bind_binary_composite(
module,
ttnn::maximum,
R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc(\mathrm{output\_tensor}_i = \text{max}\left(\mathrm{input\_tensor\_a}_i , \mathrm{input\_tensor\_b}_i\right)
)doc");

detail::bind_binary_composite(
module,
ttnn::atan2,
Expand Down Expand Up @@ -1080,16 +1073,21 @@ void py_module(py::module& module) {
R"doc(\mathrm{output}_i = \begin{cases} \mathrm{\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)}, & \text{if } \mathrm{round\_mode} = \mathrm{None} \\ \mathrm{\text{floor}\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)}, & \text{if } \mathrm{round\_mode} = \mathrm{floor} \\ \mathrm{\text{trunc}\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)}, & \text{if } \mathrm{round\_mode} = \mathrm{trunc} \end{cases}
)doc");

detail::bind_div_like_ops(
detail::bind_binary_composite_overload(
module,
ttnn::div_no_nan,
R"doc(Compute div_no_nan :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

detail::bind_div_like_ops(
detail::bind_binary_composite_overload(
module,
ttnn::floor_div,
R"doc(Compute floor division :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

detail::bind_binary_composite_overload(
module,
ttnn::maximum,
R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

detail::bind_binary_composite(
module,
ttnn::scatter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,18 @@ Tensor _minimum(const Tensor& input_a, const Tensor& input_b, const std::optiona
}

// maximum(a, b): elementwise max, expressed as a select on the sign of (b - a).
// NOTE(review): this relies on ttnn::where treating only a strictly positive
// condition as "true" (gtz-style), so a negative (b - a) keeps a — confirm
// against the ttnn::where composite implementation.
Tensor ExecuteMaximum::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
    // Where (b - a) is positive pick b, otherwise keep a.
    Tensor diff = ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config);
    return ttnn::where(diff, input_b, input_a);
}

// Scalar variant: maximum(a, value) as a select on the sign of (value - a).
Tensor ExecuteMaximum::invoke(const Tensor& input_a, float value, const std::optional<MemoryConfig>& output_mem_config) {
    // rsub yields (value - input_a); where that is positive the scalar wins,
    // otherwise the original tensor element is kept.
    Tensor scalar_minus_input = ttnn::rsub(input_a, value, output_mem_config);
    return ttnn::where(scalar_minus_input, value, input_a);
}

Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
Tensor result(input_a);
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ enum class BinaryCompositeOpType {
NEXTAFTER,
ISCLOSE,
MINIMUM,
MAXIMUM,
ATAN2,
DIV_NO_NAN,
FLOOR_DIV,
Expand All @@ -33,7 +32,6 @@ enum class BinaryCompositeOpType {
Tensor _hypot(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _xlogy(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _minimum(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _maximum(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _atan2(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _nextafter(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _addalpha(const Tensor&, const Tensor&, float, const std::optional<MemoryConfig>&);
Expand Down Expand Up @@ -80,13 +78,6 @@ struct OpHandler<BinaryCompositeOpType::MINIMUM> {
}
};

template <>
struct OpHandler<BinaryCompositeOpType::MAXIMUM> {
static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional<MemoryConfig>& mem_cfg) {
return _maximum(t1, t2, mem_cfg);
}
};

template <>
struct OpHandler<BinaryCompositeOpType::ATAN2> {
static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional<MemoryConfig>& mem_cfg) {
Expand Down

0 comments on commit 3fb2ad2

Please sign in to comment.