diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp index 49980f2dc428..65900855b4bf 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp @@ -117,6 +117,12 @@ auto preprocess_inputs( Shape repeats(std::array{first_shape[0], 1, 1, 1}); second = ttnn::repeat(second, repeats); } + // repeats second if it is smaller + if (first_shape.rank() == 4 and second_shape.rank() == 4 and first_shape[1] > second_shape[1]) { + tt::log_warning(tt::LogOp, "Using repeat op to broadcast channel dim"); + Shape repeats(std::array{1, first_shape[1], 1, 1}); + second = ttnn::repeat(second, repeats); + } }; repeat_smaller(input_tensor_a, input_tensor_b); repeat_smaller(input_tensor_b, input_tensor_a);