diff --git a/models/experimental/functional_unet/tests/test_unet_perf.py b/models/experimental/functional_unet/tests/test_unet_perf.py index b97a24157726..2d35a23a1883 100644 --- a/models/experimental/functional_unet/tests/test_unet_perf.py +++ b/models/experimental/functional_unet/tests/test_unet_perf.py @@ -44,7 +44,7 @@ def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: flo inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" post_processed_results = run_device_perf( - command, subdir="unet_shallow", num_iterations=1, cols=cols, batch_size=total_batch + command, subdir="unet_shallow", num_iterations=3, cols=cols, batch_size=total_batch ) expected_perf_cols = {inference_time_key: expected_device_perf_fps} expected_results = check_device_perf( diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index cc5a66404d0e..215399ea23ba 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -88,6 +88,7 @@ def __init__( activation_dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, output_layout=ttnn.TILE_LAYOUT, + reshard_if_not_optimal=False, mesh_mapper=None, ): self.device = device @@ -125,6 +126,7 @@ def __init__( activation=activation, output_layout=output_layout, input_channels_alignment=conv.input_channels_alignment if "input_channels_alignment" in conv else 32, + reshard_if_not_optimal=reshard_if_not_optimal, ) config_override = conv.conv_blocking_and_parallelization_config_override if config_override and "act_block_h" in config_override: @@ -192,11 +194,11 @@ def __init__( pool, device, conv_cache={}, - should_reshard=False, mesh_mapper=None, - nhw_core_override=-1, ): - self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache, mesh_mapper=mesh_mapper) + self.conv1 = UNetConv2D( + conv1, bn=bn1, device=device, cache=conv_cache, reshard_if_not_optimal=True, mesh_mapper=mesh_mapper + ) self.conv2 = UNetConv2D( conv2, bn=bn2, @@ -205,8 +207,6 @@ def __init__( mesh_mapper=mesh_mapper, ) self.pool1 = UNetMaxPool2D(pool, conv2.out_channels, device=device) - if should_reshard: - self.conv1.conv_config.reshard_if_not_optimal = True def __call__(self, x): assert list(x.shape) == [ @@ -233,17 +233,13 @@ def __init__( bn3, device, conv_cache={}, - should_reshard=False, mesh_mapper=None, ): self.device = device - self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, mesh_mapper=mesh_mapper) + self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, reshard_if_not_optimal=True, mesh_mapper=mesh_mapper) self.conv2 = UNetConv2D(conv2, bn2, device, conv_cache, mesh_mapper=mesh_mapper) self.conv3 = UNetConv2D(conv3, bn3, device, conv_cache, mesh_mapper=mesh_mapper) - if should_reshard: - self.conv1.conv_config.reshard_if_not_optimal = True - def upsample(self, x): # Need to reshape into (B, H, W, C) to get correct output from ttnn.upsample x = ttnn.reshape( @@ -308,7 +304,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.p1, device, conv_cache=self.conv_cache, - should_reshard=False, mesh_mapper=mesh_mapper, ) self.downblock2 = UNetDownblock( @@ -319,7 +314,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.p2, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.downblock3 = UNetDownblock( @@ -330,7 +324,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.p3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.downblock4 = UNetDownblock( @@ -341,12 +334,17 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.p4, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) - self.bnc = UNetConv2D(parameters.bnc, parameters.bnb, device, cache=self.conv_cache, mesh_mapper=mesh_mapper) - self.bnc.conv_config.reshard_if_not_optimal = True + self.bnc = UNetConv2D( + parameters.bnc, + parameters.bnb, + device, + cache=self.conv_cache, + reshard_if_not_optimal=True, + mesh_mapper=mesh_mapper, + ) self.bnc2 = UNetConv2D( parameters.bnc_2, parameters.bnb_2, device, cache=self.conv_cache, mesh_mapper=mesh_mapper ) @@ -360,7 +358,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b5_3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.upblock2 = UNetUpblock( @@ -372,7 +369,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b6_3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.upblock3 = UNetUpblock( @@ -384,7 +380,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b7_3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, ) self.upblock4 = UNetUpblock( @@ -396,7 +391,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b8_3, device, conv_cache=self.conv_cache, - should_reshard=True, mesh_mapper=mesh_mapper, )