From 3fb2ad23d5ba182fd15a1637c3b6eaea67876ef7 Mon Sep 17 00:00:00 2001
From: mouliraj-mcw
Date: Sat, 19 Oct 2024 14:28:28 +0000
Subject: [PATCH] #13864: Add forward support for maximum op with tensor, scalar variant

---
 .../eltwise/test_binary_composite.py      | 23 +++++++++++++++++++++++
 .../eltwise/binary/binary_composite.hpp   | 16 +++++++++++++++-
 .../eltwise/binary/binary_pybind.hpp      | 18 ++++++++----------
 .../binary/device/binary_composite_op.cpp |  8 +++++++-
 .../binary/device/binary_composite_op.hpp |  9 ---------
 5 files changed, 53 insertions(+), 21 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
index 89e74e1f85b1..e84a386bbfb2 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -139,6 +139,29 @@ def test_binary_maximum_ttnn(input_shapes, device):
     assert comp_pass
 
 
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "scalar",
+    {random.randint(-100, 100) + 0.5 for _ in range(5)},
+)
+def test_binary_maximum_scalar_ttnn(input_shapes, scalar, device):
+    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
+
+    output_tensor = ttnn.maximum(input_tensor1, scalar)
+    golden_function = ttnn.get_golden_function(ttnn.maximum)
+    golden_tensor = golden_function(in_data1, torch.tensor(scalar))
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
+
+
 @pytest.mark.parametrize(
     "input_shapes",
     (
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
index 6da21e63a87b..90b7721cf068 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -215,6 +215,20 @@ struct ExecuteGCD {
         const std::optional<MemoryConfig>& memory_config = std::nullopt);
 };
 
+struct ExecuteMaximum
+{
+    static Tensor invoke(
+        const Tensor& input_tensor_a,
+        const Tensor& input_tensor_b,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt);
+
+    static Tensor invoke(
+        const Tensor& input_tensor,
+        float scalar,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt);
+
+};
+
 }  // namespace binary
 }  // namespace operations
 
@@ -229,7 +243,7 @@ constexpr auto minimum = ttnn::register_operation_with_auto_launch_op<
     operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::MINIMUM>>();
 constexpr auto maximum = ttnn::register_operation_with_auto_launch_op<
     "ttnn::maximum",
-    operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::MAXIMUM>>();
+    operations::binary::ExecuteMaximum>();
 constexpr auto atan2 = ttnn::register_operation_with_auto_launch_op<
     "ttnn::atan2",
     operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::ATAN2>>();
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
index 919b624b39bd..08239bc5e818 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -332,7 +332,7 @@ void bind_binary_composite_with_rtol_atol(py::module& module, const binary_opera
 }
 
 template <typename binary_operation_t>
-void bind_div_like_ops(py::module& module, const binary_operation_t& operation, const std::string& description) {
+void bind_binary_composite_overload(py::module& module, const binary_operation_t& operation, const std::string& description) {
     auto doc = fmt::format(
         R"doc(
         {2}
@@ -963,13 +963,6 @@ void py_module(py::module& module) {
         R"doc(\mathrm{output\_tensor}_i = \text{min}\left(\mathrm{input\_tensor\_a}_i , \mathrm{input\_tensor\_b}_i\right)
         )doc");
 
-    detail::bind_binary_composite(
-        module,
-        ttnn::maximum,
-        R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
-        R"doc(\mathrm{output\_tensor}_i = \text{max}\left(\mathrm{input\_tensor\_a}_i , \mathrm{input\_tensor\_b}_i\right)
-        )doc");
-
     detail::bind_binary_composite(
         module,
         ttnn::atan2,
@@ -1080,16 +1073,21 @@ void py_module(py::module& module) {
         R"doc(\mathrm{output}_i = \begin{cases} \mathrm{\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)}, & \text{if } \mathrm{round\_mode} = \mathrm{None} \\ \mathrm{\text{floor}\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)}, & \text{if } \mathrm{round\_mode} = \mathrm{floor} \\ \mathrm{\text{trunc}\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)}, & \text{if } \mathrm{round\_mode} = \mathrm{trunc} \end{cases}
         )doc");
 
-    detail::bind_div_like_ops(
+    detail::bind_binary_composite_overload(
         module,
         ttnn::div_no_nan,
         R"doc(Compute div_no_nan :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
 
-    detail::bind_div_like_ops(
+    detail::bind_binary_composite_overload(
         module,
         ttnn::floor_div,
         R"doc(Compute floor division :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
 
+    detail::bind_binary_composite_overload(
+        module,
+        ttnn::maximum,
+        R"doc(Compute maximum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
+
     detail::bind_binary_composite(
         module,
         ttnn::scatter,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
index ef74e088e1a0..b6193bd3bf30 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -113,12 +113,18 @@ Tensor _minimum(const Tensor& input_a, const Tensor& input_b, const std::optiona
 }
 
 // maximum(a,b) = a + (b - a > 0 )*(b-a)
-Tensor _maximum(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
+Tensor ExecuteMaximum::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
     Tensor t_diff = ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config);
     Tensor result = ttnn::where(t_diff, input_b, input_a);
     return result;
 }
 
+Tensor ExecuteMaximum::invoke(const Tensor& input_a, float value, const std::optional<MemoryConfig>& output_mem_config) {
+    Tensor t_diff = ttnn::rsub(input_a, value, output_mem_config);
+    Tensor result = ttnn::where(t_diff, value, input_a);
+    return result;
+}
+
 Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
     Tensor result(input_a);
     {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp
index 338be5c1a9ed..deded9bdbf4b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp
@@ -21,7 +21,6 @@ enum class BinaryCompositeOpType {
     NEXTAFTER,
     ISCLOSE,
     MINIMUM,
-    MAXIMUM,
     ATAN2,
     DIV_NO_NAN,
     FLOOR_DIV,
@@ -33,7 +32,6 @@ enum class BinaryCompositeOpType {
 Tensor _hypot(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _xlogy(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _minimum(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
-Tensor _maximum(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _atan2(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _nextafter(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _addalpha(const Tensor&, const Tensor&, float, const std::optional<MemoryConfig>&);
@@ -80,13 +78,6 @@ struct OpHandler<BinaryCompositeOpType::MINIMUM> {
     }
 };
 
-template <>
-struct OpHandler<BinaryCompositeOpType::MAXIMUM> {
-    static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional<MemoryConfig>& mem_cfg) {
-        return _maximum(t1, t2, mem_cfg);
-    }
-};
-
 template <>
 struct OpHandler<BinaryCompositeOpType::ATAN2> {
     static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional<MemoryConfig>& mem_cfg) {
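
With this patch, ttnn.maximum accepts either a tensor or a Python scalar as its second operand. A minimal usage sketch of the new scalar variant, modelled on test_binary_maximum_scalar_ttnn above; the device id, input shape, and the 0.5 scalar are illustrative assumptions, not taken from the patch:

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # Host-side input; bfloat16 and TILE_LAYOUT match what the unit tests use.
    torch_input = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

    # New overload: the second operand is a float scalar instead of a tensor.
    output_tensor = ttnn.maximum(input_tensor, 0.5)

    # Compare against the torch golden result, as the unit test does via
    # ttnn.get_golden_function(ttnn.maximum).
    golden_tensor = torch.maximum(torch_input, torch.tensor(0.5))
    result = ttnn.to_torch(output_tensor)

    ttnn.close_device(device)

The scalar overload reuses the same where-based identity as the tensor overload, max(a, s) = where(s - a > 0, s, a), computing s - a with ttnn::rsub, so it remains a composite built from existing eltwise primitives rather than a dedicated kernel.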