Skip to content

Commit

Permalink
#13646: Add unit test for channel bcast using repeat
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Oct 21, 2024
1 parent dd5dd8d commit 9165ab0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
27 changes: 27 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_mul_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from torch.nn import functional as F


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
def test_mul_channel_bcast_repeat(device, h, w):
torch_input_tensor_a = torch.rand((16, 16, h, w), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((16, 1, h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.mul(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
output = ttnn.mul(input_tensor_a, input_tensor_b)
output = ttnn.to_torch(output)

assert_with_pcc(torch_output_tensor, output, 0.9999)
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,11 @@ auto preprocess_inputs(

// repeats second if it is smaller
if (first_shape.rank() == 4 and second_shape.rank() == 4 and first_shape[0] > second_shape[0]) {
tt::log_warning(tt::LogOp, "Using repeat op to broadcast batch dim");
Shape repeats(std::array<uint32_t, 4>{first_shape[0], 1, 1, 1});
second = ttnn::repeat(second, repeats);
}
// repeats second if it is smaller
if (first_shape.rank() == 4 and second_shape.rank() == 4 and first_shape[1] > second_shape[1]) {
tt::log_warning(tt::LogOp, "Using repeat op to broadcast channel dim");
Shape repeats(std::array<uint32_t, 4>{1, first_shape[1], 1, 1});
second = ttnn::repeat(second, repeats);
}
Expand Down

0 comments on commit 9165ab0

Please sign in to comment.