diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py index f0aad38592b..e7b6652bdc6 100644 --- a/tests/ttnn/unit_tests/operations/test_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_matmul.py @@ -1373,3 +1373,57 @@ def test_alternating_dst_sync_mode_matmul(device, M, K, N): assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) output_tensor = ttnn.to_torch(output3) assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) + + +def test_interleaved_input_sharded_output_matmul(device): + torch.manual_seed(0) + pcc = 0.99 + # Width sharded + torch_input_tensor_a = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16) + torch_input_tensor_b = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16) + torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + + out_mem_config = ttnn.create_sharded_memory_config( + shape=(32, 256), + core_grid=ttnn.CoreGrid(x=1, y=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + + output1 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config) + output_tensor = ttnn.to_torch(output1) + assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) + + # Block sharded + out_mem_config = ttnn.create_sharded_memory_config( + shape=(32, 256), + core_grid=ttnn.CoreGrid(x=1, y=8), + strategy=ttnn.ShardStrategy.BLOCK, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + + output2 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config) + output_tensor = ttnn.to_torch(output2) + assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) + + # Height sharded + torch_input_tensor_a = torch.randn([1, 1, 256, 32], dtype=torch.bfloat16) + torch_input_tensor_b = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16) + torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + + out_mem_config = ttnn.create_sharded_memory_config( + shape=(256, 32), + core_grid=ttnn.CoreGrid(x=8, y=1), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + + output3 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config) + output_tensor = ttnn.to_torch(output3) + assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index 467f903bea6..9b667bd401a 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -333,7 +333,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config( const Tensor& input_tensor_a, const Tensor& input_tensor_b, const std::optional compute_kernel_config, - const CoreCoord& compute_with_storage_grid_size) { + const CoreCoord& compute_with_storage_grid_size, + const MemoryConfig& mem_config) { const auto &ashape = input_tensor_a.get_legacy_shape(), bshape = input_tensor_b.get_legacy_shape(); uint32_t batch_size_a = get_batch_size(ashape); uint32_t num_output_tiles = batch_size_a * ashape[-2] * bshape[-1] / TILE_HW; // Output M x N @@ -362,9 +363,16 @@ inline MatmulProgramConfig create_simple_matmul_program_config( num_blocks_y = (Mt - 1) / per_core_M + 1; num_blocks_x = (Nt - 1) / per_core_N + 1; + // MatmulMultiCoreProgramConfig does not support sharded output. + // Reduce in0_block_w if necessary to choose other configs. + if (mem_config.is_sharded() and Kt % in0_block_w != 0) { + in0_block_w = 1; + } + if (num_blocks_x * num_blocks_y <= num_cores_x * num_cores_y and Kt % in0_block_w == 0) { CoreCoord core_range = get_core_range(num_blocks_y, num_blocks_x, num_cores_y, num_cores_x); - if (core_range.y == 1) { + bool use_mcast_config = mem_config.is_sharded() and core_range.y == 0; + if (core_range.y == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) { return get_mcast_1d_config( input_tensor_a, input_tensor_b, @@ -374,7 +382,7 @@ inline MatmulProgramConfig create_simple_matmul_program_config( false /* out_sharded */, std::nullopt /* compute_with_storage_grid_size */, compute_kernel_config); - } else if (core_range.x == 1) { + } else if (core_range.x == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) { return get_mcast_1d_config( input_tensor_a, input_tensor_b, @@ -384,7 +392,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config( false /* out_sharded */, std::nullopt /* compute_with_storage_grid_size */, compute_kernel_config); - } else if (core_range.y > 0 && num_blocks_x <= num_cores_x && num_blocks_y <= num_cores_y) { + } else if ((core_range.y > 0 and num_blocks_x <= num_cores_x and num_blocks_y <= num_cores_y) or + (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED)) { bool transpose_mcast = input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED && input_tensor_a.shard_spec().value().orientation == ShardOrientation::COL_MAJOR; out_subblock_h = 4; @@ -420,7 +429,8 @@ MatmulProgramConfig create_matmul_program_config( const Tensor& input_tensor_b, const std::optional user_core_coord, const std::optional fused_activation, - const std::optional compute_kernel_config) { + const std::optional compute_kernel_config, + const MemoryConfig& mem_config) { auto a_shape = input_tensor_a.get_shape(); auto b_shape = input_tensor_b.get_shape(); auto a_padded_shape = a_shape.with_tile_padding(); @@ -467,7 +477,7 @@ MatmulProgramConfig create_matmul_program_config( if (!can_cbs_fit_in_l1( input_tensor_a, input_tensor_b, m_tiles_per_core, n_tiles_per_core, k_tiles_per_core)) { return create_simple_matmul_program_config( - input_tensor_a, input_tensor_b, compute_kernel_config, core_coord); + input_tensor_a, input_tensor_b, compute_kernel_config, core_coord, mem_config); } } else if (a_is_sharded) { TT_FATAL( @@ -726,7 +736,7 @@ MatmulProgramConfig get_matmul_program_config( }; } return create_matmul_program_config( - input_tensor_a, input_tensor_b, user_core_coord, fused_activation, compute_kernel_config); + input_tensor_a, input_tensor_b, user_core_coord, fused_activation, compute_kernel_config, output_mem_config); } inline MatmulProgramConfig generate_matmul_program_config( @@ -743,12 +753,12 @@ inline MatmulProgramConfig generate_matmul_program_config( if (has_user_grid) { core_coord = user_core_coord.value(); return create_matmul_program_config( - input_tensor_a, input_tensor_b, user_core_coord, user_fused_activation, compute_kernel_config); + input_tensor_a, input_tensor_b, user_core_coord, user_fused_activation, compute_kernel_config, mem_config); } else { tt::tt_metal::Device* device = input_tensor_a.device(); auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); return create_simple_matmul_program_config( - input_tensor_a, input_tensor_b, compute_kernel_config, compute_with_storage_grid_size); + input_tensor_a, input_tensor_b, compute_kernel_config, compute_with_storage_grid_size, mem_config); } } else { bool bmm = user_run_batched;