diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py index a1c0496f211..cfe555d0367 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py @@ -4,10 +4,12 @@ import ttnn import torch +from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_model_utils import get_conv_input_memory_config from models.utility_functions import ( is_grayskull, is_wormhole_b0, pad_and_fold_conv_activation_for_unity_stride, + nearest_y, ) from typing import List @@ -391,8 +393,25 @@ def __init__( self.conv1_bias_tensor = parameters.conv1.bias self.conv1_input_channels = self.conv1_weight_tensor.shape[1] self.conv1_output_channels = self.conv1_weight_tensor.shape[0] + self.conv1_input_height = 259 + self.conv1_input_width = 259 + self.conv1_output_height = ttnn.get_conv_output_dim(self.conv1_input_height, 4, 1, 0) + self.conv1_output_width = ttnn.get_conv_output_dim(self.conv1_input_width, 4, 1, 0) assert self.conv1_weight_tensor.shape[2] == 4 + self.grayskull_conv1_input_memory_config = get_conv_input_memory_config( + self.batch_size, + self.conv1_input_channels, + self.conv1_input_height, + self.conv1_input_width, + self.conv1_output_channels, + self.conv1_output_height, + self.conv1_output_width, + device.compute_with_storage_grid_size(), + 16, + True, + ) + self.layer1 = self._make_layer( parameters=parameters.layer1, planes=64, @@ -522,6 +541,11 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt else: act_block_h_override = 0 + if is_grayskull(): + input_tensor = ttnn.to_device( + input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config + ) + x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, @@ -533,8 +557,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt stride=(1, 1), padding=(0, 0), batch_size=self.batch_size, - input_height=259, - input_width=259, + input_height=self.conv1_input_height, + input_width=self.conv1_input_width, conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], @@ -828,6 +852,11 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c else: act_block_h_override = 0 + if is_grayskull(): + input_tensor = ttnn.to_device( + input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config + ) + x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, @@ -839,8 +868,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c stride=(1, 1), padding=(0, 0), batch_size=self.batch_size, - input_height=259, - input_width=259, + input_height=self.conv1_input_height, + input_width=self.conv1_input_width, conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_model_utils.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_model_utils.py new file mode 100644 index 00000000000..eff32fdee1c --- /dev/null +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_model_utils.py @@ -0,0 +1,89 @@ +# 
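The hunks above precompute conv1's output extent and a height-sharded input memory config, and on Grayskull move the activation into that config before `ttnn.conv2d`. Below is a minimal sketch of the `override_num_cores` arithmetic performed by the `get_conv_input_memory_config` helper introduced just below, assuming a hypothetical batch size of 16 and a 12x9 compute grid (both are left to the caller by the diff):

```python
import math

def find_closest_largest_divisor(num: int, start_divisor: int) -> int:
    # Walk downward from start_divisor to the largest exact divisor of num.
    divisor = start_divisor
    while num % divisor != 0:
        divisor -= 1
    return divisor

batch_size = 16                              # assumption, not fixed by the diff
out_h = out_w = (259 - 4) // 1 + 1           # 4x4 kernel, stride 1, no padding -> 256
nhw_ntiles = math.ceil(batch_size * out_h * out_w / 32)       # 32768 tiles of 32 rows
num_cores = find_closest_largest_divisor(nhw_ntiles, 12 * 9)  # -> 64 of 108 cores
```

Snapping the core count to an exact divisor of the NHW tile count keeps every shard unpadded, which is the motivation stated in the helper's comment: padded shards can confuse downstream data-movement ops.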
SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+import ttnn
+from models.utility_functions import nearest_y
+
+
+def get_core_grid_from_num_cores(num_cores: int, grid_rows: int, grid_cols: int):
+    columns = num_cores // grid_rows
+    assert columns <= grid_cols, "Not enough cores for specified core grid"
+    ranges = []
+    if columns != 0:
+        ranges.append(
+            ttnn.CoreRange(
+                ttnn.CoreCoord(0, 0),
+                ttnn.CoreCoord(grid_rows - 1, columns - 1),
+            )
+        )
+    remainder = num_cores % grid_rows
+    if remainder != 0:
+        assert columns + 1 <= grid_cols, "Not enough cores for specified core grid"
+        ranges.append(
+            ttnn.CoreRange(
+                ttnn.CoreCoord(0, columns),
+                ttnn.CoreCoord(remainder - 1, columns),
+            )
+        )
+    return ttnn.CoreRangeSet({*ranges})
+
+
+def find_closest_largest_divisor(num: int, start_divisor: int) -> int:
+    divisor = start_divisor
+    while num % divisor != 0:
+        divisor -= 1
+    return divisor
+
+
+# Determines the input memory config for a height-sharded conv operation.
+# If override_num_cores is set to True, the number of cores will be overridden to the closest largest divisor of the number of tiles.
+# This avoids the default conv codepath, which can pad up the nhw num tiles and produce padded output.
+# That padding can lead to issues with data-movement ops not handling it correctly.
+def get_conv_input_memory_config(
+    batch_size: int,
+    input_channels: int,
+    input_height: int,
+    input_width: int,
+    output_channels: int,
+    output_height: int,
+    output_width: int,
+    compute_grid: ttnn.CoreGrid,
+    input_channels_alignment: int,
+    override_num_cores: bool,
+) -> ttnn.MemoryConfig:
+    parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config(
+        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+        batch_size=batch_size,
+        input_channels=input_channels,
+        output_height=output_height,
+        output_width=output_width,
+        output_channels=output_channels,
+        compute_grid_size=compute_grid,
+        block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        enable_channels_padding=True,
+        is_out_tiled=True,
+    )
+
+    if override_num_cores:
+        nhw_ntiles = math.ceil(batch_size * output_height * output_width / 32)
+        num_cores_nwh = find_closest_largest_divisor(nhw_ntiles, compute_grid.x * compute_grid.y)
+        parallel_config.grid = get_core_grid_from_num_cores(num_cores_nwh, compute_grid.x, compute_grid.y)
+
+    memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
+        tensor_shape=ttnn.Shape(
+            [
+                1,
+                1,
+                input_width * input_height * batch_size,
+                nearest_y(
+                    input_channels,
+                    input_channels_alignment,
+                ),
+            ]
+        ),
+        parallel_config=parallel_config,
+        tile_size=32,
+    )
+    return memory_config
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
index 3a5c75967e9..44d90cb0f34 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
@@ -4,11 +4,11 @@
 import ttnn
 import torch
+from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_model_utils import get_conv_input_memory_config
 from models.utility_functions import (
     is_grayskull,
     is_wormhole_b0,
     _nearest_y,
-    pad_and_fold_conv_activation_for_unity_stride,
 )
 from typing import List
 from loguru import logger
@@ -632,15 +632,18 @@ def __init__(
         conv_dummy_tensor = torch.rand((self.fold_output_shape), dtype=torch.bfloat16)
         conv_dummy_tensor = ttnn.from_torch(conv_dummy_tensor,
layout=ttnn.ROW_MAJOR_LAYOUT) - _, self.override_fold_mem_config, _, _ = ttnn.get_conv_padded_input_shape_and_mem_config( - device=device, - input_tensor=conv_dummy_tensor, - conv_config=self.conv1_config, - batch_size=self.batch_size, - height=self.conv1_output_height, - width=self.conv1_output_width, - in_channels=self.conv1_input_channels, - out_channels=self.conv1_output_channels, + + self.override_fold_mem_config = get_conv_input_memory_config( + self.batch_size, + self.conv1_input_channels, + self.conv1_input_height, + self.conv1_input_width, + self.conv1_output_channels, + self.conv1_output_height, + self.conv1_output_width, + device.compute_with_storage_grid_size(), + self.conv1_config.input_channels_alignment, + is_grayskull(), ) def __del__(self): diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py index b6643d55d4a..f2e266e1d8b 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py @@ -4,6 +4,7 @@ import ttnn import torch +from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_model_utils import get_conv_input_memory_config from models.utility_functions import ( is_grayskull, is_wormhole_b0, @@ -388,8 +389,25 @@ def __init__( self.conv1_bias_tensor = parameters.conv1.bias self.conv1_input_channels = self.conv1_weight_tensor.shape[1] self.conv1_output_channels = self.conv1_weight_tensor.shape[0] + self.conv1_input_height = 451 + self.conv1_input_width = 451 + self.conv1_output_height = ttnn.get_conv_output_dim(self.conv1_input_height, 4, 1, 0) + self.conv1_output_width = ttnn.get_conv_output_dim(self.conv1_input_width, 4, 1, 0) assert self.conv1_weight_tensor.shape[2] == 4 + self.grayskull_conv1_input_memory_config = get_conv_input_memory_config( + self.batch_size, + self.conv1_input_channels, + self.conv1_input_height, + self.conv1_input_width, + self.conv1_output_channels, + self.conv1_output_height, + self.conv1_output_width, + device.compute_with_storage_grid_size(), + 16, + True, + ) + self.layer1 = self._make_layer( parameters=parameters.layer1, planes=64, @@ -518,6 +536,11 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt elif batch_size == 20: act_block_h_override = 640 + if is_grayskull(): + input_tensor = ttnn.to_device( + input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config + ) + x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, @@ -529,8 +552,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt stride=(1, 1), padding=(0, 0), batch_size=self.batch_size, - input_height=451, - input_width=451, + input_height=self.conv1_input_height, + input_width=self.conv1_input_width, conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py index 3635026d809..2ad02078d71 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py @@ -74,6 +74,33 @@ def __init__( 
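Both ResNet variants above now derive the conv1 output size with `ttnn.get_conv_output_dim` instead of hardcoding it, and the `downsample_2d` hunk below does the same for its 3x3 convolutions. A short sketch of the standard convolution size arithmetic this implements, with the argument order assumed from the call sites in this diff:

```python
def conv_output_dim(size: int, kernel: int, stride: int, pad: int) -> int:
    # Standard convolution output-size formula, mirroring the
    # ttnn.get_conv_output_dim(size, kernel, stride, pad) call sites above.
    return (size + 2 * pad - kernel) // stride + 1

assert conv_output_dim(259, 4, 1, 0) == 256  # resnet50_large conv1
assert conv_output_dim(451, 4, 1, 0) == 448  # resnet50_xlarge conv1
```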
self.output_height = ttnn.get_conv_output_dim(input_height, 3, self.stride, 1) self.output_width = ttnn.get_conv_output_dim(input_width, 3, self.stride, 1) + self.shard_layout = ( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if self.in_channels < 320 else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ) + + self.input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttnn.Shape( + [ + 1, + 1, + self.batch_size * self.input_height * self.input_width, + self.out_channels, + ] + ), + parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config( + shard_layout=self.shard_layout, + batch_size=self.batch_size, + input_channels=self.in_channels, + output_height=self.output_height, + output_width=self.output_width, + output_channels=self.out_channels, + compute_grid_size=self.device.compute_with_storage_grid_size(), + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + enable_channels_padding=False, + is_out_tiled=True, + ), + tile_size=32, + ) def __call__( self, @@ -104,13 +131,15 @@ def __call__( math_approx_mode_enabled=True, fp32_dest_acc_enabled=True, packer_l1_accum_enabled=False, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if self.in_channels < 320 - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=self.shard_layout, input_channels_alignment=32, transpose_shards=False, - reshard_if_not_optimal=True, + reshard_if_not_optimal=False, ) + + if hidden_states.memory_config() != self.input_memory_config: + hidden_states = ttnn.to_memory_config(hidden_states, self.input_memory_config) + if self.conv_config_override and "act_block_h" in self.conv_config_override: conv_config.act_block_h_override = self.conv_config_override["act_block_h"] diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py index 4e63fc9b13c..cdcea705626 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py @@ -106,6 +106,33 @@ def __init__( self.conv1_input_width = input_width self.conv1_in_channels = split_input_channels self.conv1_out_channels = out_channels + self.conv1_output_height = ttnn.get_conv_output_dim(self.conv1_input_height, 3, 1, 1) + self.conv1_output_width = ttnn.get_conv_output_dim(self.conv1_input_width, 3, 1, 1) + self.conv1_shard_layout = ttnn.TensorMemoryLayout.BLOCK_SHARDED + + self.conv1_input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttnn.Shape( + [ + 1, + 1, + self.batch_size * self.conv1_input_height * self.conv1_input_width, + self.conv1_in_channels, + ] + ), + parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config( + shard_layout=self.conv1_shard_layout, + batch_size=self.batch_size, + input_channels=self.conv1_in_channels, + output_height=self.conv1_output_height, + output_width=self.conv1_output_width, + output_channels=self.conv1_out_channels, + compute_grid_size=self.device.compute_with_storage_grid_size(), + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + enable_channels_padding=False, + is_out_tiled=True, + ), + tile_size=32, + ) for i in range(conv1_split_chunks): self.conv1s_weights.append(ttnn.from_torch(split_weight_tensors[i], ttnn.float32)) @@ -165,6 +192,29 @@ def __init__( self.conv2_in_channels = parameters.conv2.weight.shape[1] self.conv2_out_channels = 
parameters.conv2.weight.shape[0] # self.conv2_config_override = config_override[(out_channels, out_channels, input_height, input_width)] + self.conv2_input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttnn.Shape( + [ + 1, + 1, + self.batch_size * self.conv2_input_height * self.conv2_input_width, + out_channels, + ] + ), + parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config( + shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, + batch_size=self.batch_size, + input_channels=self.conv2_in_channels, + output_height=self.conv2_input_height, + output_width=self.conv2_input_width, + output_channels=self.conv2_out_channels, + compute_grid_size=self.device.compute_with_storage_grid_size(), + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + enable_channels_padding=False, + is_out_tiled=True, + ), + tile_size=32, + ) self.groups = 32 # if use_in_shortcut: @@ -402,12 +452,14 @@ def __call__( # hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states)) # hidden_states = self.conv1s[0](hidden_states) + hidden_states = ttnn.to_memory_config(hidden_states, self.conv1_input_memory_config) + conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, math_fidelity=ttnn.MathFidelity.LoFi, activation="", - shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=self.conv1_shard_layout, math_approx_mode_enabled=True, fp32_dest_acc_enabled=True, packer_l1_accum_enabled=False, @@ -598,6 +650,7 @@ def __call__( # hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG) hidden_states = ttnn.sharded_to_interleaved(hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype) + hidden_states = ttnn.to_memory_config(hidden_states, self.conv2_input_memory_config) # hidden_states = self.conv2(hidden_states) conv_config = ttnn.Conv2dConfig( diff --git a/models/demos/wormhole/yolov4/test_yolov4_performant.py b/models/demos/wormhole/yolov4/test_yolov4_performant.py index 5bf15281a7a..c7fdd1de271 100644 --- a/models/demos/wormhole/yolov4/test_yolov4_performant.py +++ b/models/demos/wormhole/yolov4/test_yolov4_performant.py @@ -24,7 +24,7 @@ def test_run_yolov4_inference(device, use_program_cache, batch_size, act_dtype, @run_for_wormhole_b0() -@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 1622720}], indirect=True) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 1636352}], indirect=True) @pytest.mark.parametrize( "batch_size, act_dtype, weight_dtype", ((1, ttnn.bfloat16, ttnn.bfloat16),), @@ -50,7 +50,7 @@ def test_run_yolov4_trace_inference( @run_for_wormhole_b0() @pytest.mark.parametrize( - "device_params", [{"l1_small_size": 24576, "trace_region_size": 1622720, "num_command_queues": 2}], indirect=True + "device_params", [{"l1_small_size": 24576, "trace_region_size": 1636352, "num_command_queues": 2}], indirect=True ) @pytest.mark.parametrize( "batch_size, act_dtype, weight_dtype", diff --git a/models/demos/yolov4/ttnn/downsample1.py b/models/demos/yolov4/ttnn/downsample1.py index cc2f2cff37f..9937457fa94 100644 --- a/models/demos/yolov4/ttnn/downsample1.py +++ b/models/demos/yolov4/ttnn/downsample1.py @@ -48,7 +48,7 @@ def __call__(self, device, input_tensor): output_tensor = ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT) output_tensor_left = ttnn.to_layout(output_tensor_left, layout=ttnn.ROW_MAJOR_LAYOUT) 
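In the `downsample1` hunk that follows, the hardcoded `[512, 128]` shard shape is replaced with one derived from `output_tensor`'s own shard spec, with the width doubled. A hypothetical sketch of the intent, assuming the two branch outputs are concatenated along channels so the shard height is unchanged while the per-shard channel width doubles:

```python
# Hypothetical illustration: derive the post-concat shard shape from an
# existing shard spec rather than hardcoding it.
shard_h, shard_w = output_tensor.memory_config().shard_spec.shape
concat_shard_shape = [shard_h, 2 * shard_w]  # e.g. [512, 64] -> [512, 128]
```

Deriving the shape this way keeps the memory config consistent if a different input resolution or batch size changes the shard height.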
output_sharded_memory_config = ttnn.create_sharded_memory_config( - [512, 128], + [output_tensor.memory_config().shard_spec.shape[0], 2 * output_tensor.memory_config().shard_spec.shape[1]], core_grid=output_tensor_left.memory_config().shard_spec.grid, strategy=ttnn.ShardStrategy.HEIGHT, use_height_and_width_as_shard_shape=True, diff --git a/models/experimental/functional_unet/tests/common.py b/models/experimental/functional_unet/tests/common.py index b05a562cb7f..accd4ffde44 100644 --- a/models/experimental/functional_unet/tests/common.py +++ b/models/experimental/functional_unet/tests/common.py @@ -7,7 +7,7 @@ from tests.ttnn.utils_for_testing import assert_with_pcc -UNET_FULL_MODEL_PCC = 0.99995 +UNET_FULL_MODEL_PCC = 0.999 def is_n300_with_eth_dispatch_cores(mesh_device) -> bool: diff --git a/models/experimental/functional_unet/tests/test_unet_perf.py b/models/experimental/functional_unet/tests/test_unet_perf.py index c4245e3d53c..2d35a23a188 100644 --- a/models/experimental/functional_unet/tests/test_unet_perf.py +++ b/models/experimental/functional_unet/tests/test_unet_perf.py @@ -34,7 +34,7 @@ @pytest.mark.models_device_performance_bare_metal @pytest.mark.parametrize( "batch, groups, expected_device_perf_fps", - ((1, 2, 975.0),), + ((1, 2, 1115.0),), ) def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float): command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]" @@ -44,7 +44,7 @@ def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: flo inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" post_processed_results = run_device_perf( - command, subdir="unet_shallow", num_iterations=1, cols=cols, batch_size=total_batch + command, subdir="unet_shallow", num_iterations=3, cols=cols, batch_size=total_batch ) expected_perf_cols = {inference_time_key: expected_device_perf_fps} expected_results = check_device_perf( diff --git a/models/experimental/functional_unet/tt/model_preprocessing.py b/models/experimental/functional_unet/tt/model_preprocessing.py index edc5c83d54a..ff77e0083fa 100644 --- a/models/experimental/functional_unet/tt/model_preprocessing.py +++ b/models/experimental/functional_unet/tt/model_preprocessing.py @@ -111,7 +111,6 @@ def create_unet_model_parameters( parameters.c6_3["use_activation_double_buffer"] = True parameters.c6_3["input_channels_alignment"] = 16 - parameters.c7["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32} parameters.c7["use_activation_double_buffer"] = True parameters.c7["use_split_reader"] = True parameters.c7["input_channels_alignment"] = 16 diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 86de66b4cd7..215399ea23b 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -88,6 +88,7 @@ def __init__( activation_dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, output_layout=ttnn.TILE_LAYOUT, + reshard_if_not_optimal=False, mesh_mapper=None, ): self.device = device @@ -116,7 +117,6 @@ def __init__( math_fidelity=ttnn.MathFidelity.LoFi, shard_layout=shard_layout, deallocate_activation=self.deallocate_activation, - fp32_dest_acc_enabled=True, packer_l1_accum_enabled=False, enable_act_double_buffer=( conv.use_activation_double_buffer if "use_activation_double_buffer" in conv else False @@ -126,6 +126,7 @@ def __init__( 
activation=activation, output_layout=output_layout, input_channels_alignment=conv.input_channels_alignment if "input_channels_alignment" in conv else 32, + reshard_if_not_optimal=reshard_if_not_optimal, ) config_override = conv.conv_blocking_and_parallelization_config_override if config_override and "act_block_h" in config_override: @@ -193,10 +194,11 @@ def __init__( pool, device, conv_cache={}, - should_reshard=False, mesh_mapper=None, ): - self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache, mesh_mapper=mesh_mapper) + self.conv1 = UNetConv2D( + conv1, bn=bn1, device=device, cache=conv_cache, reshard_if_not_optimal=True, mesh_mapper=mesh_mapper + ) self.conv2 = UNetConv2D( conv2, bn=bn2, @@ -206,32 +208,6 @@ def __init__( ) self.pool1 = UNetMaxPool2D(pool, conv2.out_channels, device=device) - self.should_reshard = should_reshard - if self.should_reshard: - self.parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config( - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - batch_size=self.conv1.batch_size, - input_channels=self.conv1.in_channels, - output_height=self.conv2.input_height, - output_width=self.conv2.input_width, - output_channels=self.conv1.out_channels, - compute_grid_size=device.compute_with_storage_grid_size(), - block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, - is_out_tiled=True, - ) - self.sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( - tensor_shape=ttnn.Shape( - [ - 1, - 1, - self.conv1.input_width * self.conv1.input_height * self.conv1.batch_size, - nearest_32(self.conv1.in_channels), - ] - ), - parallel_config=self.parallel_config, - tile_size=32, - ) - def __call__(self, x): assert list(x.shape) == [ 1, @@ -239,12 +215,6 @@ def __call__(self, x): self.conv1.input_height * self.conv1.input_width * self.conv1.batch_size, x.shape[-1], # Channels can be padded ], f"Expected downblock input to flattened into [1, 1, BHW, C] but was {list(x.shape)}" - if self.should_reshard: - x = ttnn.to_memory_config( - x, - memory_config=self.sharded_memory_config, - ) - x = self.conv1(x) x = self.conv2(x) residual = x @@ -254,39 +224,22 @@ def __call__(self, x): class UNetUpblock: def __init__( - self, conv1, bn1, conv2, bn2, conv3, bn3, device, conv_cache={}, should_reshard=False, mesh_mapper=None + self, + conv1, + bn1, + conv2, + bn2, + conv3, + bn3, + device, + conv_cache={}, + mesh_mapper=None, ): self.device = device - self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, mesh_mapper=mesh_mapper) + self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, reshard_if_not_optimal=True, mesh_mapper=mesh_mapper) self.conv2 = UNetConv2D(conv2, bn2, device, conv_cache, mesh_mapper=mesh_mapper) self.conv3 = UNetConv2D(conv3, bn3, device, conv_cache, mesh_mapper=mesh_mapper) - self.should_reshard = should_reshard - if self.should_reshard: - self.parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config( - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - batch_size=self.conv1.batch_size, - input_channels=self.conv1.in_channels, - output_height=self.conv2.input_height, - output_width=self.conv2.input_width, - output_channels=self.conv1.out_channels, - compute_grid_size=device.compute_with_storage_grid_size(), - block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, - is_out_tiled=True, - ) - self.sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( - tensor_shape=ttnn.Shape( - [ - 1, - 1, - self.conv1.input_width * 
self.conv1.input_height * self.conv1.batch_size, - nearest_32(self.conv1.in_channels), - ] - ), - parallel_config=self.parallel_config, - tile_size=32, - ) - def upsample(self, x): # Need to reshape into (B, H, W, C) to get correct output from ttnn.upsample x = ttnn.reshape( @@ -332,12 +285,6 @@ def __call__(self, x, residual): ttnn.deallocate(x) ttnn.deallocate(residual) - if self.should_reshard: - if y.is_sharded(): - y = ttnn.reshard(y, self.sharded_memory_config) - else: - y = ttnn.interleaved_to_sharded(y, self.sharded_memory_config) - y = self.conv1(y) y = self.conv2(y) y = self.conv3(y) @@ -357,7 +304,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.p1, device, conv_cache=self.conv_cache, - should_reshard=False, mesh_mapper=mesh_mapper, ) self.downblock2 = UNetDownblock( @@ -368,7 +314,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.p2, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.downblock3 = UNetDownblock( @@ -379,7 +324,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.p3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.downblock4 = UNetDownblock( @@ -390,37 +334,20 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.p4, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) - self.bnc = UNetConv2D(parameters.bnc, parameters.bnb, device, cache=self.conv_cache, mesh_mapper=mesh_mapper) + self.bnc = UNetConv2D( + parameters.bnc, + parameters.bnb, + device, + cache=self.conv_cache, + reshard_if_not_optimal=True, + mesh_mapper=mesh_mapper, + ) self.bnc2 = UNetConv2D( parameters.bnc_2, parameters.bnb_2, device, cache=self.conv_cache, mesh_mapper=mesh_mapper ) - bnc_parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config( - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - batch_size=self.bnc.batch_size, - input_channels=self.bnc.in_channels, - output_height=self.bnc2.input_height, - output_width=self.bnc2.input_width, - output_channels=self.bnc.out_channels, - compute_grid_size=device.compute_with_storage_grid_size(), - block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, - is_out_tiled=True, - ) - self.bnc_sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( - tensor_shape=ttnn.Shape( - [ - 1, - 1, - self.bnc.input_width * self.bnc.input_height * self.bnc.batch_size, - self.bnc.in_channels, - ] - ), - parallel_config=bnc_parallel_config, - tile_size=32, - ) self.upblock1 = UNetUpblock( parameters.c5, @@ -431,7 +358,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b5_3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.upblock2 = UNetUpblock( @@ -443,7 +369,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b6_3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.upblock3 = UNetUpblock( @@ -455,7 +380,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b7_3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.upblock4 = UNetUpblock( @@ -467,7 +391,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b8_3, device, 
conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) @@ -485,6 +408,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, + enable_channels_padding=True, ) self.input_sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( tensor_shape=ttnn.Shape( @@ -502,13 +426,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: ) def bottleneck(self, x): - if x.is_sharded(): - x = ttnn.reshard(x, self.bnc_sharded_memory_config) - else: - x = ttnn.interleaved_to_sharded( - x, - self.bnc_sharded_memory_config, - ) x = self.bnc(x) return self.bnc2(x) diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py index 0f3176775cd..5d6024474d5 100644 --- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py +++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py @@ -454,30 +454,22 @@ def test_conv2d_localrun(device, input_spec): failing_parameters = [ # [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation] # Input is 32MB maps to MM 64 cores, we neeed to avoid sharding this tensor and use dram intrelaved directly with MM - [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 6 - [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1], # 14 - [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 15 - [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1], # 141 - [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 170 - [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 171 - [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 173 - [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 182 - [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 183 - [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 199 - [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 205 - [1, 336, 336, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], # 241 - [1, 336, 336, 48, 48, 5, 5, 1, 1, 2, 2, 336, False, 1], # 245 - [1, 336, 336, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1], # 247 - [1, 528, 528, 17, 17, 5, 5, 1, 1, 2, 2, 528, False, 1], # 292 - [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 293 - [1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 2, False, 1], # 294 - [1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 3, False, 1], # 347 - [1, 696, 696, 56, 56, 3, 3, 2, 2, 1, 1, 3, False, 1], # 348 - [1, 720, 720, 17, 17, 5, 5, 1, 1, 2, 2, 720, False, 1], # 363 - [1, 728, 728, 38, 38, 3, 3, 1, 1, 1, 1, 728, False, 1], # 366 - [1, 7392, 7392, 24, 24, 3, 3, 2, 2, 1, 1, 28, False, 1], # 367 - [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 374 - [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 395 + [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 5 + [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 14 + [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 169 + [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 170 + [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 172 + [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 181 + [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 182 + [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 198 
+ [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 203 + [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 204 + [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 292 + [1, 7392, 7392, 24, 24, 3, 3, 2, 2, 1, 1, 28, False, 1], # 366 + [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 373 + [1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1], # 374 + [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 394 + [1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1], # 395 ] diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py index 6dab6291762..77b51a0a0bd 100644 --- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py +++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py @@ -158,6 +158,7 @@ def run_max_pool( output_channels=in_c, compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + enable_channels_padding=False, is_out_tiled=False, ) sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( @@ -744,6 +745,7 @@ def test_pool_core_nondivis( output_channels=in_c, compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + enable_channels_padding=False, is_out_tiled=True, ) sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 4c530338c60..cfe4c0f143a 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -150,7 +150,7 @@ def run_conv( enable_subblock_padding=False, output_layout=output_layout, ) - if config_override and "act_block_h" in config_override: + if config_override and "act_block_h" in config_override and not auto_shard: conv_config.act_block_h_override = config_override["act_block_h"] if config_override and "act_block_w_div" in config_override: @@ -1520,9 +1520,9 @@ def test_sd_conv_wh( False, ), # fails. mismatch. It passes when input_channels=64. Probably an issue with padding when input_channels % 32 != 0. 
(2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 22 * 32}, False), - (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 22 * 32}, False), - (2, 1, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 22 * 32}, False), + (2, 16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), + (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), + (2, 1, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), ), ) @pytest.mark.parametrize( diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index e1f5e53ccc9..5762f6cb10a 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -9,7 +9,9 @@ #include #include "common/constants.hpp" +#include "common/math.hpp" #include "impl/buffers/buffer_constants.hpp" +#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include "ttnn/operations/core/core.hpp" #include "ttnn/operations/pool/downsample/device/downsample_op.hpp" @@ -19,6 +21,7 @@ #include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp" #include "ttnn/operations/sliding_window/sliding_window.hpp" #include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/types.hpp" using namespace tt; namespace ttnn { @@ -28,25 +31,37 @@ using sliding_window::ParallelConfig; namespace conv2d { -uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor) { +static uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor) { uint32_t divisor = start_divisor; while (num % divisor != 0) divisor = divisor - 1; return divisor; } -uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor) { +static uint32_t find_closest_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { + uint32_t divisor = start_divisor; + while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; + return divisor; +} + +static uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor) { uint32_t divisor = start_divisor; uint32_t padded_num = round_up(num, divisor); - while ((padded_num - num) >= (int)(padded_num / divisor)) { + while ((padded_num - num) >= padded_num / divisor) { divisor = divisor - 1; padded_num = round_up(num, divisor); } return divisor; } -uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { +static uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num1, uint32_t num2, uint32_t start_divisor) { uint32_t divisor = start_divisor; - while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; + uint32_t padded_num1 = round_up(num1, divisor); + uint32_t padded_num2 = round_up(num2, divisor); + while ((padded_num1 - num1) >= (padded_num1 / divisor) || (padded_num2 - num2) >= (padded_num2 / divisor)) { + divisor = divisor - 1; + padded_num1 = round_up(num1, divisor); + padded_num2 = round_up(num2, divisor); + } return divisor; } @@ -84,35 +99,40 @@ ParallelConfig determine_parallel_config( uint32_t output_channels, const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, + bool enable_channels_padding, bool is_out_tiled) { uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1; uint32_t effective_tile_width = is_out_tiled ? 
tt::constants::TILE_WIDTH : 1; uint32_t out_nhw_ntiles = tt::round_up(batch_size * output_height * output_width, tt::constants::TILE_HEIGHT) / effective_tile_height; - uint32_t out_c_ntiles = tt::round_up(output_channels, effective_tile_width) / effective_tile_width; + uint32_t input_channles_ntiles = tt::div_up(input_channels, effective_tile_width); + uint32_t out_channels_ntiles = tt::div_up(output_channels, effective_tile_width); // calculate num_core_nhw and the grid uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; - uint32_t num_cores_nhw = 0; CoreRangeSet grid; if (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - num_cores_nhw = find_closest_largest_divisor(out_nhw_ntiles, max_num_cores); - if (num_cores_nhw < compute_grid_size.x && out_nhw_ntiles > compute_grid_size.x) { - num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, compute_grid_size.x); - } + uint32_t num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, max_num_cores); grid = num_cores_to_corerangeset(num_cores_nhw, compute_grid_size, true); } else if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) { uint32_t start_divisor = block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.x : compute_grid_size.y; - num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); - uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x); + uint32_t num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); + uint32_t start_divisor_c = + block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x; + uint32_t num_cores_c = + enable_channels_padding + ? find_closest_largest_divisor_with_num_padding( + out_channels_ntiles, input_channles_ntiles, start_divisor_c) + : find_closest_largest_divisor(out_channels_ntiles, input_channles_ntiles, start_divisor_c); uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c; uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_c : num_cores_nhw; CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1})); grid = CoreRangeSet({core_range}); } else if (shard_layout == TensorMemoryLayout::WIDTH_SHARDED) { - num_cores_nhw = 1; - uint32_t num_cores_c = find_closest_largest_divisor(std::ceil((float)input_channels / effective_tile_width), max_num_cores); + uint32_t num_cores_c = enable_channels_padding + ? 
find_closest_largest_divisor_with_num_padding(input_channles_ntiles, max_num_cores) + : find_closest_largest_divisor(input_channles_ntiles, max_num_cores); grid = num_cores_to_corerangeset(num_cores_c, compute_grid_size, true); } else { TT_THROW("Conv2d supports Height, Block or Width Sharded Layouts but got {}", shard_layout); @@ -127,6 +147,26 @@ ParallelConfig determine_parallel_config( return pconfig; } +static ParallelConfig determine_output_parallel_config( + const ParallelConfig& input_parallel_config, + const CoreCoord& compute_grid_size, + uint32_t out_channels, + bool is_mm_conv) { + ParallelConfig output_parallel_config = input_parallel_config; + if (input_parallel_config.shard_scheme == ttnn::TensorMemoryLayout::WIDTH_SHARDED && !is_mm_conv) { + uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; + output_parallel_config = { + .grid = num_cores_to_corerangeset( + find_closest_largest_divisor_with_num_padding( + tt::div_up(out_channels, tt::constants::TILE_WIDTH), max_num_cores), + compute_grid_size, + true), + .shard_scheme = ttnn::TensorMemoryLayout::WIDTH_SHARDED, + .shard_orientation = input_parallel_config.shard_orientation}; + } + return output_parallel_config; +} + uint32_t get_num_cores_nhw_from_parallel_config(const ParallelConfig& pconfig) { TT_ASSERT(!pconfig.grid.ranges().empty()); TT_ASSERT( @@ -255,6 +295,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( const ParallelConfig& parallel_config, const OptimizedConvParallelizationConfig& conv_op_parallel_config, uint32_t padded_in_channels, + uint32_t padded_output_height_ntiles, uint32_t act_block_h_override, uint32_t act_block_w_div, uint32_t window_h, @@ -267,13 +308,24 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( act_block_h_override % 32 == 0, "Config Error: act_block_h_override must be a multiple of 32 (tile height)."); } - auto grid_size = parallel_config.grid.bounding_box().grid_size(); + uint32_t act_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; - if(act_block_h_override > 0) { + + if (act_block_h_override > 0) { if (parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { log_info(LogOp, "act_block_h_override is set, but ignored when Width Sharding is used"); } else { - act_block_h_ntiles = act_block_h_override / constants::TILE_HEIGHT; + uint32_t act_block_h_override_ntiles = act_block_h_override / constants::TILE_HEIGHT; + if (padded_output_height_ntiles % act_block_h_override_ntiles == 0) { + act_block_h_ntiles = act_block_h_override_ntiles; + } else { + log_info( + LogOp, + "act_block_h_override {} is not a valid override for padded_output_height_ntiles {}, override will " + "be ignored", + act_block_h_override_ntiles, + padded_output_height_ntiles); + } } } @@ -285,12 +337,13 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( } TT_ASSERT(act_block_w % 32 == 0); uint32_t act_block_w_ntiles = act_block_w / 32; + auto grid_size = parallel_config.grid.bounding_box().grid_size(); uint32_t act_c_num_blocks = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED ? 1 : parallel_config.shard_orientation == ShardOrientation::COL_MAJOR ? 
grid_size.y : grid_size.x;
     uint32_t out_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles;
     uint32_t weight_block_w_ntiles = conv_op_parallel_config.per_core_out_matrix_width_ntiles;
-    //act_block_h_ntiles / block_config.out_subblock_h_ntiles) >= 2
+
     auto [out_subblock_h_ntiles, out_subblock_w_ntiles] =
         determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum, split_reader_enabled);
     return {
@@ -325,15 +378,15 @@ static TensorMemoryLayout select_shard_spec(
     const CoreCoord& compute_grid_size) {
     auto get_core_count_for_sharding = [&](TensorMemoryLayout shard_layout) {
         return determine_parallel_config(
-                   shard_layout,
-                   batch_size,
-                   in_channels,
-                   output_height,
-                   output_width,
-                   out_channels,
-                   compute_grid_size,
-                   shard_orientation)
-            .grid.num_cores();
+                   shard_layout,
+                   batch_size,
+                   in_channels,
+                   output_height,
+                   output_width,
+                   out_channels,
+                   compute_grid_size,
+                   shard_orientation,
+                   !is_mm_conv).grid.num_cores();
     };
     // 1d convs support only height sharding
@@ -350,7 +403,11 @@
     // Prefer block sharding over height sharding but make sure that we got at least
     // some blocking on width dimension as well.
-    if (cc_height > max_cc || (cc_height == max_cc && cc_height <= compute_grid_size.x)) {
+    // Also, for a larger number of cores, prefer block sharding, as it will divide the weights
+    // across the cores.
+    const uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y;
+    const uint32_t tree_quarter_cores = static_cast<uint32_t>(0.75f * max_num_cores);
+    if ((cc_height > max_cc && max_cc < tree_quarter_cores) || (cc_height == max_cc && cc_height <= compute_grid_size.x)) {
         shard_layout = TensorMemoryLayout::HEIGHT_SHARDED;
         max_cc = cc_height;
     }
@@ -364,7 +421,7 @@
     // For large number of input channels prefer width sharding
     // even if it has less cores.
     // For BH we probably need to adjust this, or even better we make block sharding
-    // more configurable rearding l1 memory usage for weights.
+    // more configurable regarding L1 memory usage.
     if (cc_width >= 40 && in_channels > 1280) {
         shard_layout = TensorMemoryLayout::WIDTH_SHARDED;
         log_debug(LogOp, "Switching to WIDTH_SHARDED layout due to large in_channels");
@@ -385,7 +442,8 @@ std::tuple get_conv_padded_input_sh
     uint32_t height,
     uint32_t width,
     uint32_t in_channels,
-    uint32_t out_channels) {
+    uint32_t out_channels,
+    bool is_mm_conv) {
     ttnn::Tensor input_tensor = input_tensor_;  // tensor to return
     bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_);
     bool needs_shard_or_reshard = false;
@@ -457,7 +515,16 @@
     auto block_shard_orientation = conv_config.transpose_shards ?
ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; ParallelConfig optimal_parallel_config = determine_parallel_config( - shard_layout, batch_size, in_channels, height, width, out_channels, device->compute_with_storage_grid_size(), block_shard_orientation, !use_non_tile_height); + shard_layout, + batch_size, + in_channels, + height, + width, + out_channels, + device->compute_with_storage_grid_size(), + block_shard_orientation, + !is_mm_conv, + !use_non_tile_height); if (conv_config.override_sharding_config) { TT_FATAL(conv_config.core_grid.has_value(), "Error"); @@ -478,6 +545,8 @@ std::tuple get_conv_padded_input_sh } if (needs_shard_or_reshard) { uint32_t input_num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); + uint32_t input_num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config); + // TT_ASSERT(input_tensor.get_legacy_shape() == input_tensor.get_shape()); uint32_t tensor_height = input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2]; @@ -488,10 +557,8 @@ std::tuple get_conv_padded_input_sh } uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size); TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); - uint32_t tensor_width = input_tensor.get_shape()[3]; uint32_t input_tensor_width_snapped_to_channels_alignment = - tt::round_up(tensor_width, conv_config.input_channels_alignment); - TT_ASSERT(input_tensor_width_snapped_to_channels_alignment >= tensor_width); + tt::round_up(input_tensor.get_shape()[3], input_num_cores_c * conv_config.input_channels_alignment); auto input_padded_shape = ttnn::Shape(std::array{ 1, @@ -499,10 +566,12 @@ std::tuple get_conv_padded_input_sh input_tensor_height_snapped_to_tile, input_tensor_width_snapped_to_channels_alignment}); // TODO: resolve ttnn::types::Shape and // tt::tt_metal::LegacyShape issue to clean up next line - auto input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( + MemoryConfig input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{ - input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), - parallel_config, round_up_size); + input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), + parallel_config, + round_up_size); + return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height}; } else { return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard, use_non_tile_height}; @@ -533,23 +602,17 @@ std::tuple shard_or_re height, width, in_channels, - out_channels); + out_channels, + is_mm_conv); ParallelConfig parallel_config = { .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, .shard_scheme = input_tensor_sharded_memory_config.memory_layout, .shard_orientation = input_tensor_sharded_memory_config.shard_spec.value().orientation }; - auto shard_layout = input_tensor_sharded_memory_config.memory_layout; - auto output_parallel_config = parallel_config; - if(shard_layout == ttnn::TensorMemoryLayout::WIDTH_SHARDED && !is_mm_conv) { - uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; - output_parallel_config = { - .grid = num_cores_to_corerangeset( find_closest_largest_divisor(tt::div_up(out_channels, tt::constants::TILE_WIDTH),max_num_cores), compute_grid_size, true), - .shard_scheme = ttnn::TensorMemoryLayout::WIDTH_SHARDED, - 
.shard_orientation = parallel_config.shard_orientation - }; - log_debug(tt::LogOp, "Changing width sharded output grid to {}",output_parallel_config.grid); - } + + ParallelConfig output_parallel_config = + determine_output_parallel_config(parallel_config, compute_grid_size, out_channels, is_mm_conv); + if (needs_shard_or_reshard) { if (input_tensor.get_shape()[0] != 1 or input_tensor.get_shape()[1] != 1) { // reshape to [1, 1, N*H*W, C] @@ -672,9 +735,14 @@ std::pair> prepare_conv_weights_biases uint32_t in_channels = weights_shape[1]; uint32_t window_h = weights_shape[2]; uint32_t window_w = weights_shape[3]; - uint32_t out_channel_padding = tt::round_up(out_channels, 32) - out_channels; + + uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config); + uint32_t out_channels_padded = tt::round_up(out_channels, num_cores_channels * tt::constants::TILE_WIDTH); + uint32_t in_channels_padded = tt::round_up(in_channels, num_cores_channels * input_channels_alignment); + uint32_t out_channel_padding = out_channels_padded - out_channels; + tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array( - {tt::round_up(out_channels, 32), tt::round_up(in_channels, input_channels_alignment), window_h, window_w})); + {out_channels_padded, in_channels_padded, window_h, window_w})); if (weights_bias_dtype == DataType::BFLOAT8_B) { TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32); if (bias_tensor.has_value()) { @@ -865,18 +933,33 @@ Result conv2d( } uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1; - auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config( - ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), - output_parallel_config, round_up_size); - auto largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ? output_parallel_config : parallel_config; - - auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( - conv_out_memory_config, get_num_cores_nhw_from_parallel_config(largest_parallel_config), + uint32_t nhw_out = batch_size * output_height * output_width; + uint32_t out_channels_padded = tt::round_up( + out_channels, + get_num_cores_channels_from_parallel_config(output_parallel_config) * tt::constants::TILE_WIDTH); + MemoryConfig conv_out_memory_config = create_sharded_memory_config_from_parallel_config( + ttnn::Shape(std::array{1, 1, nhw_out, out_channels_padded}), + output_parallel_config, + round_up_size); + ParallelConfig largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ? 
output_parallel_config : parallel_config; + + OptimizedConvParallelizationConfig opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( + conv_out_memory_config, + get_num_cores_nhw_from_parallel_config(largest_parallel_config), get_num_cores_channels_from_parallel_config(largest_parallel_config)); - auto opt_conv_op_block_config = determine_per_core_conv_block_config( + + uint32_t in_channels_padded = tt::round_up( + in_channels, + get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment); + + uint32_t nhw_out_padded_ntile = get_num_cores_nhw_from_parallel_config(output_parallel_config) * + conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT; + + OptimizedConvBlockConfig opt_conv_op_block_config = determine_per_core_conv_block_config( parallel_config, opt_conv_op_parallel_config, - tt::round_up(in_channels, conv_config.input_channels_alignment), + in_channels_padded, + nhw_out_padded_ntile, conv_config.act_block_h_override, conv_config.act_block_w_div, kernel_size[0], @@ -1038,26 +1121,6 @@ Result conv2d( } } -template std::tuple get_conv_padded_input_shape_and_mem_config( - Device* device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels); - -template std::tuple get_conv_padded_input_shape_and_mem_config( - MeshDevice * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels); - Result Conv2dOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index 53ba3cb04d8..403bd6d6f7c 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -103,12 +103,6 @@ struct Conv2dConfig { } }; -uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor); - -uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor); - -uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor); - bool use_matmul_for_1x1_conv( const std::array& kernel_size, const std::array& stride, @@ -125,6 +119,7 @@ sliding_window::ParallelConfig determine_parallel_config( uint32_t output_channels, const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, + bool enable_channels_padding, bool is_out_tiled=true); uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelConfig& pconfig); @@ -142,6 +137,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( const sliding_window::ParallelConfig& parallel_config, const OptimizedConvParallelizationConfig& conv_op_parallel_config, uint32_t padded_in_channels, + uint32_t padded_input_height_ntiles, uint32_t act_block_h_override, uint32_t act_block_w_div, uint32_t window_h, @@ -149,17 +145,6 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( bool fp32_accum, bool split_reader_enabled); -template -std::tuple get_conv_padded_input_shape_and_mem_config( - T * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels); - template std::tuple 
shard_or_reshard_tensor_if_required( T* device, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index 4e37296bbaf..c1a3d3a79a8 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -4,6 +4,7 @@ +#include "common/constants.hpp" #include "ttnn/cpp/pybind11/decorators.hpp" #include "conv2d_pybind.hpp" @@ -120,65 +121,6 @@ void py_bind_conv2d(py::module& module) { py::arg("queue_id") = 0} ); - module.def( - "get_conv_padded_input_shape_and_mem_config", - [](ttnn::Device* device, - const ttnn::Tensor& input_tensor, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels) -> std::tuple { - return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( - device, - input_tensor, - conv_config, - batch_size, - height, - width, - in_channels, - out_channels); - }, - py::kw_only(), - py::arg("device"), - py::arg("input_tensor"), - py::arg("conv_config"), - py::arg("batch_size"), - py::arg("height"), - py::arg("width"), - py::arg("in_channels"), - py::arg("out_channels")); - - module.def( - "get_conv_padded_input_shape_and_mem_config", - [](MeshDevice* device, - const ttnn::Tensor& input_tensor, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels) -> std::tuple { - return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( - device, - input_tensor, - conv_config, - batch_size, - height, - width, - in_channels, - out_channels); - }, - py::kw_only(), - py::arg("device"), - py::arg("input_tensor"), - py::arg("conv_config"), - py::arg("batch_size"), - py::arg("height"), - py::arg("width"), - py::arg("in_channels"), - py::arg("out_channels")); module.def( "convert_conv_weight_tensor_to_tiled_layout", @@ -213,9 +155,10 @@ void py_bind_conv2d(py::module& module) { uint32_t output_channels, const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, + bool enable_channels_padding, bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig { return ttnn::operations::conv::conv2d::determine_parallel_config( - shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, is_out_tiled); + shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, enable_channels_padding, is_out_tiled); }, py::arg("shard_layout"), py::arg("batch_size"), @@ -225,6 +168,7 @@ void py_bind_conv2d(py::module& module) { py::arg("output_channels"), py::arg("compute_grid_size"), py::arg("block_shard_orientation"), + py::arg("enable_channels_padding"), py::arg("is_out_tiled") = true); module.def( diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp index f0369bab66b..04cc1562a7b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp @@ -7,6 +7,7 @@ #include #include "conv2d_op.hpp" +#include "common/math.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/detail/util.hpp" @@ -25,9 +26,12 @@ namespace optimized_conv_op_utils { using namespace tt; using namespace tt::tt_metal; -std::pair, std::vector> 
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
index f0369bab66b..04cc1562a7b 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
@@ -7,6 +7,7 @@
 #include
 #include "conv2d_op.hpp"
+#include "common/math.hpp"
 #include "tt_metal/host_api.hpp"
 #include "tt_metal/detail/tt_metal.hpp"
 #include "tt_metal/detail/util.hpp"
@@ -25,9 +26,12 @@ namespace optimized_conv_op_utils {
 using namespace tt;
 using namespace tt::tt_metal;
 
-std::pair<std::vector<uint32_t>, std::vector<uint32_t>> compute_opt_conv_activation_as_mm_shape(const tt::tt_metal::LegacyShape& conv_activation_shape, const ttnn::operations::sliding_window::SlidingWindowConfig& sliding_window_config, uint32_t act_block_h_ntiles) {
-
-    uint32_t filter_h = (uint32_t)sliding_window_config.window_hw.first;   // filter_h
+std::pair<std::vector<uint32_t>, std::vector<uint32_t>> compute_opt_conv_activation_as_mm_shape(
+    const tt::tt_metal::LegacyShape& conv_activation_shape,
+    const ttnn::operations::sliding_window::SlidingWindowConfig& sliding_window_config,
+    uint32_t num_cores_nhw,
+    uint32_t act_block_h_ntiles) {
+    uint32_t filter_h = (uint32_t)sliding_window_config.window_hw.first;   // filter_h
     uint32_t filter_w = (uint32_t)sliding_window_config.window_hw.second;  // filter_W
     auto output_shape = sliding_window_config.get_output_shape();
     uint32_t batch_size = output_shape[0];
@@ -35,11 +39,11 @@ std::pair<std::vector<uint32_t>, std::vector<uint32_t>> compute_opt_conv_activat
     uint32_t conv_output_w = output_shape[2];
 
     // pad height
-    uint32_t num_rows = (uint32_t) batch_size * conv_output_h * conv_output_w;
+    uint32_t num_rows = (uint32_t)batch_size * conv_output_h * conv_output_w;
     uint32_t act_block_h_datums = act_block_h_ntiles * TILE_HEIGHT;
-    uint32_t num_rows_padded = (uint32_t) (std::ceil((double) num_rows / (double) act_block_h_datums ) * act_block_h_datums);
+    uint32_t num_rows_padded = tt::round_up(num_rows, num_cores_nhw * act_block_h_datums);
     uint32_t num_cols = conv_activation_shape[3] * filter_h * filter_w;
-    uint32_t num_cols_padded = round_up(conv_activation_shape[3] * filter_w, TILE_WIDTH) * filter_h;
+    uint32_t num_cols_padded = tt::round_up(conv_activation_shape[3] * filter_w, TILE_WIDTH) * filter_h;
     return {{1, num_rows_padded, num_cols_padded}, {1, num_rows, num_cols}};
 }
 
@@ -107,7 +111,12 @@ void OptimizedConvNew::validate(const std::vector<Tensor>& input_tensors, const
     }
     if (this->memory_config.is_sharded()) {
         uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntiles;
-        auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(input_tensor_a.get_legacy_shape(), sliding_window_config, out_block_h_ntiles);
+        auto [act_matrix_shape, act_matrix_shape_unpadded] =
+            optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(
+                input_tensor_a.get_legacy_shape(),
+                sliding_window_config,
+                parallelization_config.num_cores_nhw,
+                out_block_h_ntiles);
         uint32_t out_width_ntiles = this->compute_output_shapes(input_tensors).at(0)[-1] / TILE_WIDTH;
         if(this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
             TT_FATAL(this->parallelization_config.per_core_out_matrix_width_ntiles == out_width_ntiles, "Error");
@@ -116,9 +125,9 @@ void OptimizedConvNew::validate(const std::vector<Tensor>& input_tensors, const
             // For block sharded, out_width per core is shard width, and this is split along row
             // TODO: We should clean this up and relax constraints on out_subblock h and w
             if (this->memory_config.shard_spec.value().orientation == ShardOrientation::COL_MAJOR) {
-                out_width_ntiles /= this->parallelization_config.grid_size.y;
+                out_width_ntiles = tt::div_up(out_width_ntiles, this->parallelization_config.grid_size.y);
             } else {
-                out_width_ntiles /= this->parallelization_config.grid_size.x;
+                out_width_ntiles = tt::div_up(out_width_ntiles, this->parallelization_config.grid_size.x);
             }
             TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error");
         }
@@ -188,7 +197,12 @@ std::vector<Tensor> OptimizedConvNew::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
             this->dtype, output_layout, input_tensor.device(), mem_config)};
     } else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
-        auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(this->input_tensor_shape, sliding_window_config, this->parallelization_config.per_core_out_matrix_height_ntiles);
+        auto [act_matrix_shape, act_matrix_shape_unpadded] =
+            optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(
+                this->input_tensor_shape,
+                sliding_window_config,
+                this->parallelization_config.num_cores_nhw,
+                this->parallelization_config.per_core_out_matrix_height_ntiles);
         uint32_t act_matrix_height = (uint32_t) act_matrix_shape[1];
         uint32_t act_matrix_height_ntiles = act_matrix_height / TILE_HEIGHT;
         uint32_t total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / this->parallelization_config.per_core_out_matrix_height_ntiles;
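The key change in `compute_opt_conv_activation_as_mm_shape` is that height padding is now a multiple of `num_cores_nhw * act_block_h_datums`, so the padded activation divides evenly across NHW cores instead of merely landing on an act-block boundary. A standalone sketch contrasting the old and new padding, with hypothetical sizes:

```cpp
#include <cstdint>
#include <iostream>

uint32_t round_up(uint32_t n, uint32_t multiple) { return ((n + multiple - 1) / multiple) * multiple; }

int main() {
    constexpr uint32_t TILE_HEIGHT = 32;
    // Hypothetical shapes: batch 8, 33x33 conv output, one-tile-high act blocks, 98 NHW cores.
    uint32_t num_rows = 8 * 33 * 33;  // 8712 output rows
    uint32_t act_block_h_datums = 1 * TILE_HEIGHT;  // 32
    uint32_t num_cores_nhw = 98;

    // Old padding: next multiple of the act block only.
    uint32_t old_padded = round_up(num_rows, act_block_h_datums);                  // 8736
    // New padding: next multiple of (cores * act block), so shards split evenly.
    uint32_t new_padded = round_up(num_rows, num_cores_nhw * act_block_h_datums);  // 9408

    std::cout << old_padded << " vs " << new_padded
              << ", rows per core: " << new_padded / num_cores_nhw << "\n";  // 96
    return 0;
}
```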
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
index e9e8ded5c9b..88151d5a83e 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
@@ -183,7 +183,10 @@ namespace optimized_conv_op_utils {
 using namespace tt;
 using namespace tt::tt_metal;
 
-
-std::pair<std::vector<uint32_t>, std::vector<uint32_t>> compute_opt_conv_activation_as_mm_shape(const tt::tt_metal::LegacyShape& conv_activation_shape, const ttnn::operations::sliding_window::SlidingWindowConfig& sliding_window_config, uint32_t act_block_h_ntiles);
+std::pair<std::vector<uint32_t>, std::vector<uint32_t>> compute_opt_conv_activation_as_mm_shape(
+    const tt::tt_metal::LegacyShape& conv_activation_shape,
+    const ttnn::operations::sliding_window::SlidingWindowConfig& sliding_window_config,
+    uint32_t num_cores_nhw,
+    uint32_t act_block_h_ntiles);
 
 } // optimized_conv_op_utils
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp
index 1e3fcd69e54..de10fb342a7 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp
@@ -2,6 +2,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
+#include "common/math.hpp"
 #include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp"
 #include "ttnn/operations/sliding_window/sliding_window.hpp"
 #include "tt_metal/common/work_split.hpp"
@@ -565,7 +566,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
     // Compute the 2d matrix shape
     auto [act_matrix_shape, act_matrix_shape_unpadded] =
         optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(
-            ashape_with_channels_padded.value, sliding_window_config, out_block_h_ntiles);
+            ashape_with_channels_padded.value,
+            sliding_window_config,
+            parallelization_config.num_cores_nhw,
+            out_block_h_ntiles);
     assert(act_matrix_shape.size() == 3);
     assert(act_matrix_shape[0] == 1);
     uint32_t act_matrix_height = (uint32_t)act_matrix_shape[1];
@@ -887,8 +891,9 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
     TT_FATAL(act_matrix_height_ntiles % per_core_out_matrix_height_ntiles == 0, "Error");
     uint32_t total_active_num_cores_per_weight_slice;
     if (use_non_tile_height) {
-        uint32_t input_height_padded_per_core = shard_shape[0];
-        total_active_num_cores_per_weight_slice = act_matrix_height / parallelization_config.per_core_out_matrix_height;
+        total_active_num_cores_per_weight_slice =
+            tt::round_up(act_matrix_height_unpadded, parallelization_config.num_cores_nhw) /
+            parallelization_config.per_core_out_matrix_height;
     } else {
         total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / per_core_out_matrix_height_ntiles;
     }
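For the `use_non_tile_height` branch, the active-core count is now derived from the unpadded activation height rounded up to the NHW core count, rather than from the tile-padded height. A sketch with hypothetical numbers, assuming the per-core shard height was chosen by splitting that padded height evenly (that derivation is not shown in this diff):

```cpp
#include <cstdint>
#include <iostream>

uint32_t round_up(uint32_t n, uint32_t m) { return ((n + m - 1) / m) * m; }

int main() {
    // Hypothetical non-tile-height shard: 8712 real output rows over 98 cores.
    uint32_t act_matrix_height_unpadded = 8712;
    uint32_t num_cores_nhw = 98;
    // Assume shard height = padded height split evenly across cores: 8722 / 98 = 89 rows.
    uint32_t per_core_out_matrix_height =
        round_up(act_matrix_height_unpadded, num_cores_nhw) / num_cores_nhw;  // 89

    uint32_t total_active_cores =
        round_up(act_matrix_height_unpadded, num_cores_nhw) / per_core_out_matrix_height;  // 98
    std::cout << "active cores: " << total_active_cores << "\n";
    return 0;
}
```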
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
index 8799469070a..53f8e5bceab 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp
@@ -225,7 +225,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
     // Compute the 2d matrix shape
     auto [act_matrix_shape, act_matrix_shape_unpadded] =
         optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape(
-            ashape_with_channels_padded.value, sliding_window_config, out_block_h_ntiles);
+            ashape_with_channels_padded.value,
+            sliding_window_config,
+            parallelization_config.num_cores_nhw,
+            out_block_h_ntiles);
     TT_FATAL(act_matrix_shape.size() == 3, "Error");
     TT_FATAL(act_matrix_shape[0] == 1, "Error");
     uint32_t act_matrix_height = (uint32_t)act_matrix_shape[1];
@@ -364,13 +367,12 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
     uint32_t num_groups = num_blocks_act_h * num_blocks_act_w * num_blocks_weight_w;
     // writer of conv op partially removes padding on the width
     // it removes the padding done for block width but it doesn't remove padding done for tiled width
-    uint32_t output_channels_padded_to_tile_width = round_up(output_channels, TILE_WIDTH);
+    uint32_t output_channels_padded_to_tile_width = round_up(output_channels, input_num_cores * TILE_WIDTH);
     TT_FATAL(
         output_channels_padded_to_tile_width <= weight_matrix_width,
         "output_channels_padded_to_tile_width {} should be less than or equal to weight_matrix_width {}",
         output_channels_padded_to_tile_width,
         weight_matrix_width);
-    uint32_t output_width_num_tiles = output_channels_padded_to_tile_width / TILE_WIDTH;
     uint32_t num_blocks_output_w = (uint32_t)std::ceil((double)output_channels_padded_to_tile_width / (double)weight_block_w_datums);
     uint32_t last_block_width_datums = (output_channels_padded_to_tile_width % weight_block_w_datums == 0)
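The width-sharded factory now pads output channels to a multiple of `input_num_cores * TILE_WIDTH`, so each width shard holds a whole number of tiles instead of only the total being tile-aligned. A quick numeric sketch (values are hypothetical):

```cpp
#include <cstdint>
#include <iostream>

uint32_t round_up(uint32_t n, uint32_t m) { return ((n + m - 1) / m) * m; }

int main() {
    constexpr uint32_t TILE_WIDTH = 32;
    uint32_t output_channels = 576;  // hypothetical: 18 tiles of width 32
    uint32_t input_num_cores = 8;    // width-sharded over 8 cores

    // Old: pad to a tile boundary only -> 576 (already aligned, 2.25 tiles per core).
    uint32_t old_padded = round_up(output_channels, TILE_WIDTH);
    // New: pad so each of the 8 width shards is a whole number of tiles -> 768 (3 tiles per core).
    uint32_t new_padded = round_up(output_channels, input_num_cores * TILE_WIDTH);
    std::cout << old_padded << " -> " << new_padded << "\n";
    return 0;
}
```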
diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
index cf42685fe88..d904d1d7cb1 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
@@ -213,21 +213,28 @@ Result conv_transpose2d(
     //Call Conv2d u_op with Stride = 1, Padding = 0.
     auto conv_out_memory_config = conv2d::create_sharded_memory_config_from_parallel_config(
-        ttnn::Shape(std::array<uint32_t, 4>{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}),
-        output_parallel_config,
-        round_up_size);
+        ttnn::Shape(std::array<uint32_t, 4>{
+            1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}),
+        output_parallel_config,
+        round_up_size);
 
     auto largest_parallel_config =
         output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ?
             output_parallel_config : parallel_config;
 
     auto opt_conv_op_parallel_config = conv2d::determine_conv_op_parallel_config_from_conv_output_mem_config(
         conv_out_memory_config,
         conv2d::get_num_cores_nhw_from_parallel_config(largest_parallel_config),
-        conv2d::get_num_cores_channels_from_parallel_config(largest_parallel_config)
-    );
+        conv2d::get_num_cores_channels_from_parallel_config(largest_parallel_config));
+
+    uint32_t in_channels_padded = tt::round_up(
+        in_channels,
+        conv2d::get_num_cores_channels_from_parallel_config(parallel_config) *
+            conv_config.input_channels_alignment);
+
     auto opt_conv_op_block_config = conv2d::determine_per_core_conv_block_config(
         parallel_config,
         opt_conv_op_parallel_config,
-        tt::round_up(in_channels, conv_config.input_channels_alignment),
+        in_channels_padded,
+        (input_tensor_post_tm.shard_spec().value().shape[0] * conv2d::get_num_cores_nhw_from_parallel_config(parallel_config)) / tt::constants::TILE_HEIGHT,
         conv_config.act_block_h_override,
         conv_config.act_block_w_div,
         kernel_size[0],
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
index d6e38bc6b11..cad518759f0 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1268,9 +1268,8 @@ void Matmul::validate(
     uint32_t K = input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1];
     uint32_t per_core_M = program_config.per_core_M;
     auto shard_shape = input_tensor_a.shard_spec().value().shape;
-    TT_FATAL(
-        div_up(M, per_core_M) == input_tensor_a.shard_spec().value().grid.num_cores(), "Error");
+    TT_FATAL(div_up(M, per_core_M) <= input_tensor_a.shard_spec().value().grid.num_cores(), "Error");
     TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error");
     TT_FATAL(K % program_config.in0_block_w == 0, "Error");
     TT_FATAL(K == (shard_shape[1] / in0_tile_shape[1]), "Error");
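The relaxed matmul check accepts shard grids that are only partially occupied, which is exactly what the divisor-based core selection above can produce. A tiny sketch with hypothetical tile counts:

```cpp
#include <cassert>
#include <cstdint>

uint32_t div_up(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

int main() {
    // Hypothetical height-sharded in0: 196 M-tiles, 2 tiles per core, a 100-core grid.
    uint32_t M_ntiles = 196;
    uint32_t per_core_M = 2;
    uint32_t grid_num_cores = 100;

    // Old check demanded every grid core hold a shard (== fails here: 98 != 100).
    // The relaxed check only requires that the shards fit on the grid.
    assert(div_up(M_ntiles, per_core_M) <= grid_num_cores);  // 98 <= 100
    return 0;
}
```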
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp
index 753164e4b8b..ee862c6ffac 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp
@@ -55,15 +55,16 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id,
             shard_layout = applied_shard_scheme.value();
         }
         parallel_config = conv::conv2d::determine_parallel_config(
-            shard_layout,
-            batch_size,
-            channels,
-            output_shape[1],
-            output_shape[2],
-            channels,
-            input_tensor.device()->compute_with_storage_grid_size(),
-            ShardOrientation::ROW_MAJOR,
-            false);
+            shard_layout,
+            batch_size,
+            channels,
+            output_shape[1],
+            output_shape[2],
+            channels,
+            input_tensor.device()->compute_with_storage_grid_size(),
+            ShardOrientation::ROW_MAJOR,
+            false,
+            false);
         num_cores_nhw = conv::conv2d::get_num_cores_nhw_from_parallel_config(parallel_config);
         num_cores_c = conv::conv2d::get_num_cores_channels_from_parallel_config(parallel_config);
         auto sharded_mem_config = conv::conv2d::create_sharded_memory_config_from_parallel_config(input_tensor_sharded.shape(), parallel_config, is_in_tiled ? tt::constants::TILE_HEIGHT : 1);
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 31c31e7bcd5..9890dbb86ab 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -292,7 +292,7 @@ def auto_register_ttnn_cpp_operations(module):
     Topology,
 )
 
-from ttnn.operations.conv2d import Conv2dConfig, get_conv_padded_input_shape_and_mem_config, get_conv_output_dim
+from ttnn.operations.conv2d import Conv2dConfig, get_conv_output_dim
 from ttnn.operations.conv1d import Conv1d, Conv1dConfig
 
 from ttnn.operations.transformer import SDPAProgramConfig
diff --git a/ttnn/ttnn/operations/conv2d.py b/ttnn/ttnn/operations/conv2d.py
index ca1f329dd69..b46a7e1fbf7 100644
--- a/ttnn/ttnn/operations/conv2d.py
+++ b/ttnn/ttnn/operations/conv2d.py
@@ -21,7 +21,6 @@ def _nearest_32(x):
 
 Conv2dConfig = ttnn._ttnn.operations.conv.Conv2dConfig
-get_conv_padded_input_shape_and_mem_config = ttnn._ttnn.operations.conv.get_conv_padded_input_shape_and_mem_config
 OptimizedConvParallelizationConfig = ttnn._ttnn.operations.conv.OptimizedConvParallelizationConfig
 OptimizedConvBlockConfig = ttnn._ttnn.operations.conv.OptimizedConvBlockConfig