#13204: adjust matmul program config selection for some height/block sharded output scenarios
bbradelTT committed Oct 15, 2024
1 parent d90a929 commit e97f09e
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -371,7 +371,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
 
     if (num_blocks_x * num_blocks_y <= num_cores_x * num_cores_y and Kt % in0_block_w == 0) {
         CoreCoord core_range = get_core_range(num_blocks_y, num_blocks_x, num_cores_y, num_cores_x);
-        if (core_range.y == 1 or (mem_config.is_sharded() and core_range.y == 0)) {
+        bool use_mcast_config = mem_config.is_sharded() and core_range.y == 0;
+        if (core_range.y == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
             return get_mcast_1d_config(
                 input_tensor_a,
                 input_tensor_b,
@@ -381,7 +382,7 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
                 false /* out_sharded */,
                 std::nullopt /* compute_with_storage_grid_size */,
                 compute_kernel_config);
-        } else if (core_range.x == 1) {
+        } else if (core_range.x == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
             return get_mcast_1d_config(
                 input_tensor_a,
                 input_tensor_b,
@@ -391,7 +392,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
                 false /* out_sharded */,
                 std::nullopt /* compute_with_storage_grid_size */,
                 compute_kernel_config);
-        } else if (core_range.y > 0 && num_blocks_x <= num_cores_x && num_blocks_y <= num_cores_y) {
+        } else if ((core_range.y > 0 and num_blocks_x <= num_cores_x and num_blocks_y <= num_cores_y) or
+                   (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED)) {
             bool transpose_mcast = input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED &&
                                    input_tensor_a.shard_spec().value().orientation == ShardOrientation::COL_MAJOR;
             out_subblock_h = 4;
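For readers following the change, below is a small self-contained C++ sketch of the updated branch selection. It is illustrative only: the enum and struct are simplified stand-ins for the real ttnn types, select_config_family is a hypothetical helper, and the num_blocks/num_cores bounds check as well as the real program-config construction are elided.

#include <iostream>
#include <string>

// Simplified stand-ins for the ttnn types named in the diff.
enum class TensorMemoryLayout { INTERLEAVED, WIDTH_SHARDED, HEIGHT_SHARDED, BLOCK_SHARDED };
struct CoreCoord { int x; int y; };

// Mirrors only the branch conditions touched by this commit; the actual
// get_mcast_1d_config / block-sharded config construction is reduced to a label.
std::string select_config_family(bool output_sharded, TensorMemoryLayout layout, CoreCoord core_range) {
    // New in this commit: a sharded output whose block grid does not form a
    // full rectangle (core_range.y == 0) is routed by its memory layout.
    bool use_mcast_config = output_sharded && core_range.y == 0;
    if (core_range.y == 1 || (use_mcast_config && layout == TensorMemoryLayout::WIDTH_SHARDED)) {
        return "1D mcast config (core_range.y == 1 branch)";
    } else if (core_range.x == 1 || (use_mcast_config && layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
        return "1D mcast config (core_range.x == 1 branch)";
    } else if (core_range.y > 0 || (use_mcast_config && layout == TensorMemoryLayout::BLOCK_SHARDED)) {
        return "2D mcast (block-sharded) config";
    }
    return "fallback matmul config";
}

int main() {
    // Before this commit, any sharded output with core_range.y == 0 fell into
    // the first branch; now the output layout selects the matching branch.
    std::cout << select_config_family(true, TensorMemoryLayout::HEIGHT_SHARDED, {0, 0}) << "\n";
    std::cout << select_config_family(true, TensorMemoryLayout::BLOCK_SHARDED, {0, 0}) << "\n";
}

The practical effect the sketch shows: a height- or block-sharded output that previously always received the 1D mcast config can now fall through to the branch matching its layout.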
