diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_mul_channel.py b/tests/ttnn/unit_tests/operations/eltwise/test_mul_channel.py new file mode 100644 index 00000000000..73eb845aafa --- /dev/null +++ b/tests/ttnn/unit_tests/operations/eltwise/test_mul_channel.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch + +import ttnn + +from tests.ttnn.utils_for_testing import assert_with_pcc +from torch.nn import functional as F + + +@pytest.mark.parametrize("h", [32]) +@pytest.mark.parametrize("w", [64]) +def test_mul_channel_bcast_repeat(device, h, w): + torch_input_tensor_a = torch.rand((16, 16, h, w), dtype=torch.bfloat16) + torch_input_tensor_b = torch.rand((16, 1, h, w), dtype=torch.bfloat16) + torch_output_tensor = torch.mul(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + output = ttnn.mul(input_tensor_a, input_tensor_b) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output_tensor, output, 0.9999) diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp index 65900855b4b..832ba616365 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp @@ -113,13 +113,11 @@ auto preprocess_inputs( // repeats second if it is smaller if (first_shape.rank() == 4 and second_shape.rank() == 4 and first_shape[0] > second_shape[0]) { - tt::log_warning(tt::LogOp, "Using repeat op to broadcast batch dim"); Shape repeats(std::array{first_shape[0], 1, 1, 1}); second = ttnn::repeat(second, repeats); } // repeats second if it is smaller if (first_shape.rank() == 4 and second_shape.rank() == 4 and first_shape[1] > second_shape[1]) { - tt::log_warning(tt::LogOp, "Using repeat op to broadcast channel dim"); Shape repeats(std::array{1, first_shape[1], 1, 1}); second = ttnn::repeat(second, repeats); }