diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
index c17c0018039c..9b667bd401a3 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -371,7 +371,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
     if (num_blocks_x * num_blocks_y <= num_cores_x * num_cores_y and Kt % in0_block_w == 0) {
         CoreCoord core_range = get_core_range(num_blocks_y, num_blocks_x, num_cores_y, num_cores_x);
-        if (core_range.y == 1 or (mem_config.is_sharded() and core_range.y == 0)) {
+        bool use_mcast_config = mem_config.is_sharded() and core_range.y == 0;
+        if (core_range.y == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
             return get_mcast_1d_config(
                 input_tensor_a,
                 input_tensor_b,
@@ -381,7 +382,7 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
                 false /* out_sharded */,
                 std::nullopt /* compute_with_storage_grid_size */,
                 compute_kernel_config);
-        } else if (core_range.x == 1) {
+        } else if (core_range.x == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
             return get_mcast_1d_config(
                 input_tensor_a,
                 input_tensor_b,
@@ -391,7 +392,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
                 false /* out_sharded */,
                 std::nullopt /* compute_with_storage_grid_size */,
                 compute_kernel_config);
-        } else if (core_range.y > 0 && num_blocks_x <= num_cores_x && num_blocks_y <= num_cores_y) {
+        } else if ((core_range.y > 0 and num_blocks_x <= num_cores_x and num_blocks_y <= num_cores_y) or
+                   (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED)) {
             bool transpose_mcast = input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED &&
                                    input_tensor_a.shard_spec().value().orientation == ShardOrientation::COL_MAJOR;
             out_subblock_h = 4;
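
For reference, a minimal, self-contained sketch of the branch selection this change produces, taken directly from the conditions in the diff. `TensorMemoryLayout`, `MemoryConfig`, and `CoreCoord` here are simplified stand-ins for the real ttnn types, and `select_config` / `ConfigKind` are hypothetical names introduced only for illustration, not part of the ttnn API.

```cpp
#include <iostream>

// Hypothetical stand-ins for the ttnn types referenced in the diff.
enum class TensorMemoryLayout { INTERLEAVED, WIDTH_SHARDED, HEIGHT_SHARDED, BLOCK_SHARDED };

struct MemoryConfig {
    TensorMemoryLayout memory_layout = TensorMemoryLayout::INTERLEAVED;
    bool is_sharded() const { return memory_layout != TensorMemoryLayout::INTERLEAVED; }
};

struct CoreCoord { int x = 0; int y = 0; };

// Which program-config path the rewritten conditionals select. The first two
// paths both map to get_mcast_1d_config calls in the diff (their argument
// differences are elided there); the third is the 2D/block-sharded path.
enum class ConfigKind { Mcast1dWidth, Mcast1dHeight, Mcast2d, Fallback };

// Mirrors the post-change logic: when the output is sharded and
// get_core_range() returned an empty grid (core_range.y == 0), the shard
// layout alone decides which multicast config is chosen.
ConfigKind select_config(
    const CoreCoord& core_range, const MemoryConfig& mem_config,
    int num_blocks_x, int num_blocks_y, int num_cores_x, int num_cores_y) {
    bool use_mcast_config = mem_config.is_sharded() and core_range.y == 0;
    if (core_range.y == 1 or
        (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
        return ConfigKind::Mcast1dWidth;
    } else if (core_range.x == 1 or
               (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
        return ConfigKind::Mcast1dHeight;
    } else if ((core_range.y > 0 and num_blocks_x <= num_cores_x and num_blocks_y <= num_cores_y) or
               (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED)) {
        return ConfigKind::Mcast2d;
    }
    return ConfigKind::Fallback;
}

int main() {
    // A height-sharded output for which get_core_range() found no valid grid
    // now takes the second 1D-mcast branch instead of falling through.
    MemoryConfig cfg{TensorMemoryLayout::HEIGHT_SHARDED};
    std::cout << static_cast<int>(select_config({0, 0}, cfg, 4, 8, 8, 8)) << "\n";
    return 0;
}
```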