Skip to content

Commit

Permalink
#13864: Add forward support for the minimum op's tensor-scalar variant
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw committed Oct 25, 2024
1 parent d4ff5e9 commit 9ac8f20
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 17 deletions.
27 changes: 25 additions & 2 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,29 @@ def test_binary_minimum_ttnn(input_shapes, device):
assert comp_pass


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    "scalar",
    # Ordered sequence (not a set) so parametrize order and generated test IDs
    # are deterministic; values cover negative/zero/positive and int/float
    # scalars to exercise type promotion in the scalar overload.
    [-82.5, -45.7, 0.0, 12.5, 66.4, 96, 8],
)
def test_binary_minimum_scalar_ttnn(input_shapes, scalar, device):
    """Compare ttnn.minimum(tensor, scalar) against the torch golden result."""
    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)

    output_tensor = ttnn.minimum(input_tensor1, scalar)
    golden_function = ttnn.get_golden_function(ttnn.minimum)
    # Broadcast the scalar to a full tensor so the golden op sees matching shapes.
    golden_tensor = golden_function(in_data1, torch.full(input_shapes, scalar))

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down Expand Up @@ -149,14 +172,14 @@ def test_binary_maximum_ttnn(input_shapes, device):
)
@pytest.mark.parametrize(
    "scalar",
    # Fixed, ordered scalars (the diff left both the old random-sample line and
    # the new set literal in place); a plain list makes collection order and
    # failures deterministic and reproducible.
    [-82.5, -45.7, 0.0, 12.5, 66.4, 96, 8],
)
def test_binary_maximum_scalar_ttnn(input_shapes, scalar, device):
    """Compare ttnn.maximum(tensor, scalar) against the torch golden result."""
    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)

    output_tensor = ttnn.maximum(input_tensor1, scalar)
    golden_function = ttnn.get_golden_function(ttnn.maximum)
    # Broadcast the scalar to a full tensor so the golden op sees matching shapes.
    golden_tensor = golden_function(in_data1, torch.full(input_shapes, scalar))

    comp_pass = compare_pcc([output_tensor], [golden_tensor])
    assert comp_pass
Expand Down
16 changes: 15 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,20 @@ struct ExecuteMaximum

};

// Composite element-wise minimum op. Exposed as ttnn::minimum; supports a
// tensor-tensor overload and a tensor-scalar overload.
struct ExecuteMinimum
{
    // Element-wise minimum of two tensors.
    // memory_config: optional output memory configuration.
    static Tensor invoke(
        const Tensor& input_tensor_a,
        const Tensor& input_tensor_b,
        const std::optional<MemoryConfig>& memory_config = std::nullopt);

    // Element-wise minimum of a tensor and a float scalar (scalar is compared
    // against every element).
    static Tensor invoke(
        const Tensor& input_tensor,
        float scalar,
        const std::optional<MemoryConfig>& memory_config = std::nullopt);

};

} // namespace binary
} // namespace operations

Expand All @@ -240,7 +254,7 @@ constexpr auto xlogy = ttnn::register_operation_with_auto_launch_op<
operations::binary::ExecuteBinaryCompositeOps<operations::binary::BinaryCompositeOpType::XLOGY>>();
// Registers ttnn::minimum, backed by ExecuteMinimum so both the tensor-tensor
// and tensor-scalar overloads are dispatched. (The diff rendering left both
// the old ExecuteBinaryCompositeOps<MINIMUM> argument and the new one in
// place; only the new argument belongs here.)
constexpr auto minimum = ttnn::register_operation_with_auto_launch_op<
    "ttnn::minimum",
    operations::binary::ExecuteMinimum>();
// Registers ttnn::maximum, backed by the ExecuteMaximum composite op.
constexpr auto maximum = ttnn::register_operation_with_auto_launch_op<
    "ttnn::maximum",
    operations::binary::ExecuteMaximum>();
Expand Down
6 changes: 2 additions & 4 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -955,12 +955,10 @@ void py_module(py::module& module) {
R"doc(\mathrm{output\_tensor}_i = \begin{cases} \mathrm{next\_float}(\mathrm{input\_tensor\_a}_i, \mathrm{input\_tensor\_b}_i), & \text{if } \mathrm{input\_tensor\_a}_i \neq \mathrm{input\_tensor\_b}_i \\ \mathrm{input\_tensor\_a}_i, & \text{if } \mathrm{input\_tensor\_a}_i = \mathrm{input\_tensor\_b}_i \end{cases}
)doc");

detail::bind_binary_composite(
detail::bind_binary_composite_overload(
module,
ttnn::minimum,
R"doc(Compute minimum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc(\mathrm{output\_tensor}_i = \text{min}\left(\mathrm{input\_tensor\_a}_i , \mathrm{input\_tensor\_b}_i\right)
)doc");
R"doc(Compute minimum :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");

detail::bind_binary_composite(
module,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,18 @@ Tensor _isclose(
}

// minimum(a,b) = a - (a - b > 0 )*(a-b)
// Element-wise minimum of two tensors: d = a - b; where d > 0, b is the
// smaller value, otherwise keep a.
// NOTE(review): this relies on ttnn::where treating a strictly-positive
// condition as true (gtz semantics) — then d == 0 (a == b) and d < 0 both
// select input_a, which is correct for minimum. Confirm against ttnn::where's
// contract. (The diff rendering left the old free-function signature
// `Tensor _minimum(...)` above the new one; only the member form belongs here.)
Tensor ExecuteMinimum::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
    Tensor t_diff = ttnn::subtract(input_a, input_b, std::nullopt, output_mem_config);
    Tensor result = ttnn::where(t_diff, input_b, input_a);
    return result;
}

// Element-wise minimum of a tensor and a scalar: d = input_a - value; where
// d > 0 the scalar is the smaller operand, otherwise keep input_a.
// NOTE(review): assumes ttnn::where treats a strictly-positive condition as
// true, mirroring the tensor-tensor overload — confirm against ttnn::where's
// contract.
Tensor ExecuteMinimum::invoke(const Tensor& input_a, float value, const std::optional<MemoryConfig>& output_mem_config) {
    Tensor t_diff = ttnn::subtract(input_a, value, std::nullopt, output_mem_config);
    Tensor result = ttnn::where(t_diff, value, input_a);
    return result;
}

// maximum(a,b) = a + (b - a > 0 )*(b-a)
Tensor ExecuteMaximum::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
Tensor t_diff = ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ enum class BinaryCompositeOpType {
SUBALPHA,
NEXTAFTER,
ISCLOSE,
MINIMUM,
ATAN2,
LOGICAL_XOR,
DIV_NO_NAN,
Expand All @@ -33,7 +32,6 @@ enum class BinaryCompositeOpType {

Tensor _hypot(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _xlogy(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _minimum(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _atan2(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _logical_xor(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Tensor _nextafter(const Tensor&, const Tensor&, const std::optional<MemoryConfig>&);
Expand Down Expand Up @@ -75,13 +73,6 @@ struct OpHandler<BinaryCompositeOpType::NEXTAFTER> {
}
};

template <>
struct OpHandler<BinaryCompositeOpType::MINIMUM> {
static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional<MemoryConfig>& mem_cfg) {
return _minimum(t1, t2, mem_cfg);
}
};

template <>
struct OpHandler<BinaryCompositeOpType::ATAN2> {
static Tensor handle(const Tensor& t1, const Tensor& t2, const std::optional<MemoryConfig>& mem_cfg) {
Expand Down

0 comments on commit 9ac8f20

Please sign in to comment.