#13204: adjust matmul program config selection for some sharded output scenarios #13819

Merged: 3 commits, Oct 15, 2024
Changes from all commits
54 changes: 54 additions & 0 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -1373,3 +1373,57 @@ def test_alternating_dst_sync_mode_matmul(device, M, K, N):
assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
output_tensor = ttnn.to_torch(output3)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)


def test_interleaved_input_sharded_output_matmul(device):
torch.manual_seed(0)
pcc = 0.99
# Width sharded
torch_input_tensor_a = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16)
torch_input_tensor_b = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16)
torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

out_mem_config = ttnn.create_sharded_memory_config(
shape=(32, 256),
core_grid=ttnn.CoreGrid(x=1, y=8),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)

output1 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
output_tensor = ttnn.to_torch(output1)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)

# Block sharded
out_mem_config = ttnn.create_sharded_memory_config(
shape=(32, 256),
core_grid=ttnn.CoreGrid(x=1, y=8),
strategy=ttnn.ShardStrategy.BLOCK,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)

output2 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
output_tensor = ttnn.to_torch(output2)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)

# Height sharded
torch_input_tensor_a = torch.randn([1, 1, 256, 32], dtype=torch.bfloat16)
torch_input_tensor_b = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16)
torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

out_mem_config = ttnn.create_sharded_memory_config(
shape=(256, 32),
core_grid=ttnn.CoreGrid(x=8, y=1),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)

output3 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
output_tensor = ttnn.to_torch(output3)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
28 changes: 19 additions & 9 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -333,7 +333,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config,
const CoreCoord& compute_with_storage_grid_size) {
const CoreCoord& compute_with_storage_grid_size,
const MemoryConfig& mem_config) {
const auto &ashape = input_tensor_a.get_legacy_shape(), bshape = input_tensor_b.get_legacy_shape();
uint32_t batch_size_a = get_batch_size(ashape);
uint32_t num_output_tiles = batch_size_a * ashape[-2] * bshape[-1] / TILE_HW; // Output M x N
@@ -362,9 +363,16 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
num_blocks_y = (Mt - 1) / per_core_M + 1;
num_blocks_x = (Nt - 1) / per_core_N + 1;

// MatmulMultiCoreProgramConfig does not support sharded output.
// Reduce in0_block_w if necessary to choose other configs.
if (mem_config.is_sharded() and Kt % in0_block_w != 0) {
in0_block_w = 1;
}

if (num_blocks_x * num_blocks_y <= num_cores_x * num_cores_y and Kt % in0_block_w == 0) {
CoreCoord core_range = get_core_range(num_blocks_y, num_blocks_x, num_cores_y, num_cores_x);
if (core_range.y == 1) {
bool use_mcast_config = mem_config.is_sharded() and core_range.y == 0;
if (core_range.y == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
Comment on lines +366 to +375

Contributor:

Would it be better to do something like this for in0_block_w:

uint32_t in0_block_w = (Kt % 2 == 0) ? 2 : 1;

and remove the code from lines 366-371 and the check for Kt % in0_block_w == 0 on line 372?

Contributor Author:

Yes, it would. Unfortunately, that would require updating the PCCs for models/tests. I'm hoping to refactor this code, and will look at it then.
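
A minimal, self-contained sketch of the simplification discussed in this thread, under the assumption that Kt keeps its meaning from the surrounding diff; the helper name choose_in0_block_w is hypothetical and this is not the code merged in this PR:

#include <cstdint>

// Hypothetical sketch of the suggestion above: pick an in0_block_w that always
// divides Kt, so the sharded-output adjustment and the "Kt % in0_block_w == 0"
// guard in the diff would no longer be needed.
inline uint32_t choose_in0_block_w(uint32_t Kt) {
    return (Kt % 2 == 0) ? 2 : 1;  // 2 when Kt is even, otherwise 1; always divides Kt
}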

return get_mcast_1d_config(
input_tensor_a,
input_tensor_b,
@@ -374,7 +382,7 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
false /* out_sharded */,
std::nullopt /* compute_with_storage_grid_size */,
compute_kernel_config);
} else if (core_range.x == 1) {
} else if (core_range.x == 1 or (use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
return get_mcast_1d_config(
input_tensor_a,
input_tensor_b,
@@ -384,7 +392,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
false /* out_sharded */,
std::nullopt /* compute_with_storage_grid_size */,
compute_kernel_config);
} else if (core_range.y > 0 && num_blocks_x <= num_cores_x && num_blocks_y <= num_cores_y) {
} else if ((core_range.y > 0 and num_blocks_x <= num_cores_x and num_blocks_y <= num_cores_y) or
(use_mcast_config and mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED)) {
bool transpose_mcast = input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED &&
input_tensor_a.shard_spec().value().orientation == ShardOrientation::COL_MAJOR;
out_subblock_h = 4;
@@ -420,7 +429,8 @@ MatmulProgramConfig create_matmul_program_config(
const Tensor& input_tensor_b,
const std::optional<const CoreCoord> user_core_coord,
const std::optional<UnaryWithParam> fused_activation,
const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config) {
const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config,
const MemoryConfig& mem_config) {
auto a_shape = input_tensor_a.get_shape();
auto b_shape = input_tensor_b.get_shape();
auto a_padded_shape = a_shape.with_tile_padding();
@@ -467,7 +477,7 @@
if (!can_cbs_fit_in_l1(
input_tensor_a, input_tensor_b, m_tiles_per_core, n_tiles_per_core, k_tiles_per_core)) {
return create_simple_matmul_program_config(
input_tensor_a, input_tensor_b, compute_kernel_config, core_coord);
input_tensor_a, input_tensor_b, compute_kernel_config, core_coord, mem_config);
}
} else if (a_is_sharded) {
TT_FATAL(
@@ -726,7 +736,7 @@ MatmulProgramConfig get_matmul_program_config(
};
}
return create_matmul_program_config(
input_tensor_a, input_tensor_b, user_core_coord, fused_activation, compute_kernel_config);
input_tensor_a, input_tensor_b, user_core_coord, fused_activation, compute_kernel_config, output_mem_config);
}

inline MatmulProgramConfig generate_matmul_program_config(
@@ -743,12 +753,12 @@ inline MatmulProgramConfig generate_matmul_program_config(
if (has_user_grid) {
core_coord = user_core_coord.value();
return create_matmul_program_config(
input_tensor_a, input_tensor_b, user_core_coord, user_fused_activation, compute_kernel_config);
input_tensor_a, input_tensor_b, user_core_coord, user_fused_activation, compute_kernel_config, mem_config);
} else {
tt::tt_metal::Device* device = input_tensor_a.device();
auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
return create_simple_matmul_program_config(
input_tensor_a, input_tensor_b, compute_kernel_config, compute_with_storage_grid_size);
input_tensor_a, input_tensor_b, compute_kernel_config, compute_with_storage_grid_size, mem_config);
}
} else {
bool bmm = user_run_batched;