
tenstorrent#13204: adjust matmul program config selection for some sharded output scenarios (tenstorrent#13819)

* tenstorrent#13204: adjust matmul program config selection for some sharded output scenarios

* tenstorrent#13204: adjust matmul program config selection for some height/block sharded output scenarios

* tenstorrent#13204: add interleaved input sharded output matmul test
bbradelTT authored and Christopher Taylor committed Nov 9, 2024
1 parent 4166e17 commit c99d497
Showing 2 changed files with 73 additions and 9 deletions.
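
The scenario these changes target is a matmul whose inputs are interleaved but whose requested output memory config is sharded. A minimal sketch of that call pattern, mirroring the width-sharded case in the new test added below (device is assumed to come from the usual test fixture; shapes and grid are illustrative):

    import torch
    import ttnn

    # Interleaved inputs: ordinary from_torch tensors in TILE_LAYOUT.
    a = ttnn.from_torch(torch.randn([1, 1, 32, 32], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    b = ttnn.from_torch(torch.randn([1, 1, 32, 256], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

    # Sharded output: the 32x256 result is width-sharded across a 1x8 core grid.
    out_mem_config = ttnn.create_sharded_memory_config(
        shape=(32, 256),
        core_grid=ttnn.CoreGrid(x=1, y=8),
        strategy=ttnn.ShardStrategy.WIDTH,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
    )

    # Without the adjustment in this commit, auto-selection could land on
    # MatmulMultiCoreProgramConfig, which does not support sharded output.
    out = ttnn.matmul(a, b, memory_config=out_mem_config)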
54 changes: 54 additions & 0 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -1373,3 +1373,57 @@ def test_alternating_dst_sync_mode_matmul(device, M, K, N):
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
     output_tensor = ttnn.to_torch(output3)
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
+
+
+def test_interleaved_input_sharded_output_matmul(device):
+    torch.manual_seed(0)
+    pcc = 0.99
+    # Width sharded
+    torch_input_tensor_a = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16)
+    torch_input_tensor_b = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16)
+    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+
+    out_mem_config = ttnn.create_sharded_memory_config(
+        shape=(32, 256),
+        core_grid=ttnn.CoreGrid(x=1, y=8),
+        strategy=ttnn.ShardStrategy.WIDTH,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+    )
+
+    output1 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
+    output_tensor = ttnn.to_torch(output1)
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
+
+    # Block sharded
+    out_mem_config = ttnn.create_sharded_memory_config(
+        shape=(32, 256),
+        core_grid=ttnn.CoreGrid(x=1, y=8),
+        strategy=ttnn.ShardStrategy.BLOCK,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+    )
+
+    output2 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
+    output_tensor = ttnn.to_torch(output2)
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
+
+    # Height sharded
+    torch_input_tensor_a = torch.randn([1, 1, 256, 32], dtype=torch.bfloat16)
+    torch_input_tensor_b = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16)
+    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+
+    out_mem_config = ttnn.create_sharded_memory_config(
+        shape=(256, 32),
+        core_grid=ttnn.CoreGrid(x=8, y=1),
+        strategy=ttnn.ShardStrategy.HEIGHT,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+    )
+
+    output3 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
+    output_tensor = ttnn.to_torch(output3)
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
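
As a quick sanity check on the shard shapes the test requests (plain arithmetic, not a ttnn call; the block-sharded case splits the output along both grid dimensions and is left out of this sketch):

    def per_core_shard(shape, num_cores, strategy):
        # Rough per-core shard size for the WIDTH/HEIGHT strategies used above.
        height, width = shape
        if strategy == "WIDTH":   # last dim is split across the cores
            return (height, width // num_cores)
        if strategy == "HEIGHT":  # first dim is split across the cores
            return (height // num_cores, width)
        raise ValueError("only WIDTH/HEIGHT are handled in this sketch")

    print(per_core_shard((32, 256), 8, "WIDTH"))   # (32, 32): one tile per core on the 1x8 grid
    print(per_core_shard((256, 32), 8, "HEIGHT"))  # (32, 32): one tile per core on the 8x1 grid

So in both the width- and height-sharded cases each core ends up holding a single 32x32 tile of the output.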
28 changes: 19 additions & 9 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -333,7 +333,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
     const Tensor& input_tensor_a,
     const Tensor& input_tensor_b,
     const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config,
-    const CoreCoord& compute_with_storage_grid_size) {
+    const CoreCoord& compute_with_storage_grid_size,
+    const MemoryConfig& mem_config) {
     const auto &ashape = input_tensor_a.get_legacy_shape(), bshape = input_tensor_b.get_legacy_shape();
     uint32_t batch_size_a = get_batch_size(ashape);
     uint32_t num_output_tiles = batch_size_a * ashape[-2] * bshape[-1] / TILE_HW; // Output M x N
@@ -362,9 +363,16 @@
     num_blocks_y = (Mt - 1) / per_core_M + 1;
     num_blocks_x = (Nt - 1) / per_core_N + 1;
 
+    // MatmulMultiCoreProgramConfig does not support sharded output.
+    // Reduce in0_block_w if necessary to choose other configs.
+    if (mem_config.is_sharded() and Kt % in0_block_w != 0) {
+        in0_block_w = 1;
+    }
+
     if (num_blocks_x * num_blocks_y <= num_cores_x * num_cores_y and Kt % in0_block_w == 0) {
         CoreCoord core_range = get_core_range(num_blocks_y, num_blocks_x, num_cores_y, num_cores_x);
-        if (core_range.y == 1) {
+        bool use_mcast_config = mem_config.is_sharded() and core_range.y == 0;
+        if (core_range.y == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
             return get_mcast_1d_config(
                 input_tensor_a,
                 input_tensor_b,
@@ -374,7 +382,7 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
                 false /* out_sharded */,
                 std::nullopt /* compute_with_storage_grid_size */,
                 compute_kernel_config);
-        } else if (core_range.x == 1) {
+        } else if (core_range.x == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
             return get_mcast_1d_config(
                 input_tensor_a,
                 input_tensor_b,
@@ -384,7 +392,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
                 false /* out_sharded */,
                 std::nullopt /* compute_with_storage_grid_size */,
                 compute_kernel_config);
-        } else if (core_range.y > 0 && num_blocks_x <= num_cores_x && num_blocks_y <= num_cores_y) {
+        } else if ((core_range.y > 0 and num_blocks_x <= num_cores_x and num_blocks_y <= num_cores_y) or
+                   (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED)) {
             bool transpose_mcast = input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED &&
                                    input_tensor_a.shard_spec().value().orientation == ShardOrientation::COL_MAJOR;
             out_subblock_h = 4;
@@ -420,7 +429,8 @@ MatmulProgramConfig create_matmul_program_config(
     const Tensor& input_tensor_b,
     const std::optional<const CoreCoord> user_core_coord,
     const std::optional<UnaryWithParam> fused_activation,
-    const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config) {
+    const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config,
+    const MemoryConfig& mem_config) {
     auto a_shape = input_tensor_a.get_shape();
     auto b_shape = input_tensor_b.get_shape();
     auto a_padded_shape = a_shape.with_tile_padding();
@@ -467,7 +477,7 @@ MatmulProgramConfig create_matmul_program_config(
         if (!can_cbs_fit_in_l1(
                 input_tensor_a, input_tensor_b, m_tiles_per_core, n_tiles_per_core, k_tiles_per_core)) {
             return create_simple_matmul_program_config(
-                input_tensor_a, input_tensor_b, compute_kernel_config, core_coord);
+                input_tensor_a, input_tensor_b, compute_kernel_config, core_coord, mem_config);
         }
     } else if (a_is_sharded) {
         TT_FATAL(
@@ -726,7 +736,7 @@ MatmulProgramConfig get_matmul_program_config(
         };
     }
     return create_matmul_program_config(
-        input_tensor_a, input_tensor_b, user_core_coord, fused_activation, compute_kernel_config);
+        input_tensor_a, input_tensor_b, user_core_coord, fused_activation, compute_kernel_config, output_mem_config);
 }
 
 inline MatmulProgramConfig generate_matmul_program_config(
@@ -743,12 +753,12 @@ inline MatmulProgramConfig generate_matmul_program_config(
         if (has_user_grid) {
             core_coord = user_core_coord.value();
             return create_matmul_program_config(
-                input_tensor_a, input_tensor_b, user_core_coord, user_fused_activation, compute_kernel_config);
+                input_tensor_a, input_tensor_b, user_core_coord, user_fused_activation, compute_kernel_config, mem_config);
         } else {
             tt::tt_metal::Device* device = input_tensor_a.device();
             auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
             return create_simple_matmul_program_config(
-                input_tensor_a, input_tensor_b, compute_kernel_config, compute_with_storage_grid_size);
+                input_tensor_a, input_tensor_b, compute_kernel_config, compute_with_storage_grid_size, mem_config);
         }
     } else {
         bool bmm = user_run_batched;
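
Taken together, the matmul_op.cpp changes thread the output MemoryConfig into create_simple_matmul_program_config and steer sharded outputs away from MatmulMultiCoreProgramConfig. A rough Python restatement of the new branching (a sketch only, with hypothetical names; the real logic lives in the C++ above and also requires the output block grid to fit on the compute grid):

    def pick_matmul_config(output_is_sharded, output_layout, Kt, in0_block_w, core_range_x, core_range_y, blocks_fit):
        """output_layout is "WIDTH_SHARDED", "HEIGHT_SHARDED", "BLOCK_SHARDED", or None.

        core_range_y == 0 stands for get_core_range() reporting that the output
        block grid does not map cleanly onto the compute grid.
        """
        # MatmulMultiCoreProgramConfig does not support sharded output, so when the
        # output is sharded and Kt is not divisible by in0_block_w, drop in0_block_w
        # to 1 so that one of the mcast configs below remains reachable.
        if output_is_sharded and Kt % in0_block_w != 0:
            in0_block_w = 1

        use_mcast_config = output_is_sharded and core_range_y == 0

        if core_range_y == 1 or (use_mcast_config and output_layout == "WIDTH_SHARDED"):
            return "1D mcast config", in0_block_w
        if core_range_x == 1 or (use_mcast_config and output_layout == "HEIGHT_SHARDED"):
            return "1D mcast config", in0_block_w
        if (core_range_y > 0 and blocks_fit) or (use_mcast_config and output_layout == "BLOCK_SHARDED"):
            return "2D mcast (block) config", in0_block_w
        return "MatmulMultiCoreProgramConfig (interleaved output only)", in0_block_w

    # e.g. a tiny width-sharded output whose block grid does not map onto the core grid:
    print(pick_matmul_config(True, "WIDTH_SHARDED", Kt=1, in0_block_w=2, core_range_x=0, core_range_y=0, blocks_fit=False))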
