
Commit d90a929

#13204: adjust matmul program config selection for some sharded output scenarios

bbradelTT committed Oct 15, 2024
1 parent aee03c7 commit d90a929
Showing 1 changed file with 15 additions and 7 deletions.
ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp (22 changes: 15 additions & 7 deletions)

@@ -333,7 +333,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
     const Tensor& input_tensor_a,
     const Tensor& input_tensor_b,
     const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config,
-    const CoreCoord& compute_with_storage_grid_size) {
+    const CoreCoord& compute_with_storage_grid_size,
+    const MemoryConfig& mem_config) {
     const auto &ashape = input_tensor_a.get_legacy_shape(), bshape = input_tensor_b.get_legacy_shape();
     uint32_t batch_size_a = get_batch_size(ashape);
     uint32_t num_output_tiles = batch_size_a * ashape[-2] * bshape[-1] / TILE_HW;  // Output M x N
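A quick worked example for the tile count above, assuming the usual 32 x 32 tile (TILE_HW = 1024): with a single batch, ashape[-2] = 1024, and bshape[-1] = 2048, num_output_tiles = 1 * 1024 * 2048 / 1024 = 2048 output tiles.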
@@ -362,9 +363,15 @@ inline MatmulProgramConfig create_simple_matmul_program_config(
     num_blocks_y = (Mt - 1) / per_core_M + 1;
     num_blocks_x = (Nt - 1) / per_core_N + 1;
 
+    // MatmulMultiCoreProgramConfig does not support sharded output.
+    // Reduce in0_block_w if necessary to choose other configs.
+    if (mem_config.is_sharded() and Kt % in0_block_w != 0) {
+        in0_block_w = 1;
+    }
+
     if (num_blocks_x * num_blocks_y <= num_cores_x * num_cores_y and Kt % in0_block_w == 0) {
         CoreCoord core_range = get_core_range(num_blocks_y, num_blocks_x, num_cores_y, num_cores_x);
-        if (core_range.y == 1) {
+        if (core_range.y == 1 or (mem_config.is_sharded() and core_range.y == 0)) {
             return get_mcast_1d_config(
                 input_tensor_a,
                 input_tensor_b,
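The two edits in this hunk are small but load-bearing. Below is a minimal, self-contained C++ sketch of the selection logic they implement (names are mine, not the library's; get_core_range is assumed, from the surrounding code, to return {0, 0} when no exact rectangular core grid fits the work):

#include <cassert>
#include <cstdint>

// Sketch only: for a sharded output, force in0_block_w to 1 whenever Kt is
// not divisible by it. Since Kt % 1 == 0 always holds, the divisibility
// guard in the hunk above then passes and a mcast config can be selected
// instead of MatmulMultiCoreProgramConfig, which cannot produce sharded output.
uint32_t adjust_in0_block_w(bool output_sharded, uint32_t Kt, uint32_t in0_block_w) {
    if (output_sharded && Kt % in0_block_w != 0) {
        return 1;
    }
    return in0_block_w;
}

// Sketch of the widened branch: y == 1 is the pre-existing single-row case;
// y == 0 (no exact grid found) is now also routed to the 1D mcast config
// when the output is sharded, instead of falling through.
bool choose_1d_mcast(uint32_t core_range_y, bool output_sharded) {
    return core_range_y == 1 || (output_sharded && core_range_y == 0);
}

int main() {
    assert(adjust_in0_block_w(true, 7, 2) == 1);   // sharded, 7 % 2 != 0: reduced to 1
    assert(adjust_in0_block_w(false, 7, 2) == 2);  // interleaved output: left unchanged
    assert(adjust_in0_block_w(true, 8, 2) == 2);   // already divisible: left unchanged
    assert(choose_1d_mcast(0, true));              // sharded, no exact grid: 1D mcast
    assert(!choose_1d_mcast(0, false));            // interleaved: falls through as before
    return 0;
}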
@@ -420,7 +427,8 @@ MatmulProgramConfig create_matmul_program_config(
     const Tensor& input_tensor_b,
     const std::optional<const CoreCoord> user_core_coord,
     const std::optional<UnaryWithParam> fused_activation,
-    const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config) {
+    const std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config,
+    const MemoryConfig& mem_config) {
     auto a_shape = input_tensor_a.get_shape();
     auto b_shape = input_tensor_b.get_shape();
     auto a_padded_shape = a_shape.with_tile_padding();
@@ -467,7 +475,7 @@ MatmulProgramConfig create_matmul_program_config(
             if (!can_cbs_fit_in_l1(
                     input_tensor_a, input_tensor_b, m_tiles_per_core, n_tiles_per_core, k_tiles_per_core)) {
                 return create_simple_matmul_program_config(
-                    input_tensor_a, input_tensor_b, compute_kernel_config, core_coord);
+                    input_tensor_a, input_tensor_b, compute_kernel_config, core_coord, mem_config);
             }
         } else if (a_is_sharded) {
             TT_FATAL(
@@ -726,7 +734,7 @@ MatmulProgramConfig get_matmul_program_config(
         };
     }
     return create_matmul_program_config(
-        input_tensor_a, input_tensor_b, user_core_coord, fused_activation, compute_kernel_config);
+        input_tensor_a, input_tensor_b, user_core_coord, fused_activation, compute_kernel_config, output_mem_config);
 }
 
 inline MatmulProgramConfig generate_matmul_program_config(
@@ -743,12 +751,12 @@ inline MatmulProgramConfig generate_matmul_program_config(
         if (has_user_grid) {
             core_coord = user_core_coord.value();
             return create_matmul_program_config(
-                input_tensor_a, input_tensor_b, user_core_coord, user_fused_activation, compute_kernel_config);
+                input_tensor_a, input_tensor_b, user_core_coord, user_fused_activation, compute_kernel_config, mem_config);
         } else {
             tt::tt_metal::Device* device = input_tensor_a.device();
             auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
             return create_simple_matmul_program_config(
-                input_tensor_a, input_tensor_b, compute_kernel_config, compute_with_storage_grid_size);
+                input_tensor_a, input_tensor_b, compute_kernel_config, compute_with_storage_grid_size, mem_config);
         }
     } else {
         bool bmm = user_run_batched;
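Apart from the config-selection logic in the second hunk, the rest of the diff is plumbing: the hunks in create_matmul_program_config, get_matmul_program_config, and generate_matmul_program_config thread the output MemoryConfig down the existing call chain so the new checks can see it. A condensed view, reconstructed from the hunks above (signatures abbreviated):

get_matmul_program_config(..., output_mem_config)
  -> create_matmul_program_config(..., mem_config)
       -> create_simple_matmul_program_config(..., mem_config)   // when CBs do not fit in L1

generate_matmul_program_config(..., mem_config)
  -> create_matmul_program_config(..., mem_config)               // user core grid provided
  -> create_simple_matmul_program_config(..., mem_config)        // default device grid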
