Commit

#15171: PR feedback
Pavle Josipovic committed Nov 25, 2024
1 parent 632ec64 commit 4d0a28a
Showing 2 changed files with 15 additions and 21 deletions.
@@ -44,7 +44,7 @@ def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: flo
 
     inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
     post_processed_results = run_device_perf(
-        command, subdir="unet_shallow", num_iterations=1, cols=cols, batch_size=total_batch
+        command, subdir="unet_shallow", num_iterations=3, cols=cols, batch_size=total_batch
     )
     expected_perf_cols = {inference_time_key: expected_device_perf_fps}
     expected_results = check_device_perf(
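For orientation, here is a minimal sketch of how the calls touched by this hunk fit together in a device-perf test. The helper import path, the cols argument, and the tolerance passed to check_device_perf are illustrative assumptions, not values taken from this commit.

# Sketch only, not the repository test; import path and margin are assumed.
from models.perf.device_perf_utils import check_device_perf, run_device_perf


def measure_unet_device_fps(command, cols, total_batch, expected_device_perf_fps):
    inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
    # num_iterations=3 (raised from 1 by this commit) averages the reported device
    # kernel throughput over three runs before it is checked against the expected FPS.
    post_processed_results = run_device_perf(
        command, subdir="unet_shallow", num_iterations=3, cols=cols, batch_size=total_batch
    )
    expected_perf_cols = {inference_time_key: expected_device_perf_fps}
    return check_device_perf(post_processed_results, 0.015, expected_perf_cols, assert_on_fail=True)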
34 changes: 14 additions & 20 deletions models/experimental/functional_unet/tt/unet_shallow_ttnn.py
@@ -88,6 +88,7 @@ def __init__(
         activation_dtype=ttnn.bfloat8_b,
         weights_dtype=ttnn.bfloat8_b,
         output_layout=ttnn.TILE_LAYOUT,
+        reshard_if_not_optimal=False,
         mesh_mapper=None,
     ):
         self.device = device
@@ -125,6 +126,7 @@ def __init__(
             activation=activation,
             output_layout=output_layout,
             input_channels_alignment=conv.input_channels_alignment if "input_channels_alignment" in conv else 32,
+            reshard_if_not_optimal=reshard_if_not_optimal,
         )
         config_override = conv.conv_blocking_and_parallelization_config_override
         if config_override and "act_block_h" in config_override:
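The keyword arguments in this hunk (activation, output_layout, input_channels_alignment, and now reshard_if_not_optimal) feed the convolution config that the wrapper builds. A condensed sketch of that construction follows; it assumes the enclosing call is ttnn.Conv2dConfig, and the exact field set depends on the ttnn version in use.

import ttnn


# Condensed sketch of the assumed enclosing call; the real wrapper passes additional fields.
def build_conv_config(activation_dtype, weights_dtype, activation, output_layout,
                      input_channels_alignment, reshard_if_not_optimal):
    return ttnn.Conv2dConfig(
        dtype=activation_dtype,
        weights_dtype=weights_dtype,
        activation=activation,
        output_layout=output_layout,
        input_channels_alignment=input_channels_alignment,
        # New in this commit: the flag arrives via UNetConv2D.__init__ instead of
        # being toggled on conv_config after construction.
        reshard_if_not_optimal=reshard_if_not_optimal,
    )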
@@ -192,11 +194,11 @@ def __init__(
         pool,
         device,
         conv_cache={},
-        should_reshard=False,
         mesh_mapper=None,
         nhw_core_override=-1,
     ):
-        self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache, mesh_mapper=mesh_mapper)
+        self.conv1 = UNetConv2D(
+            conv1, bn=bn1, device=device, cache=conv_cache, reshard_if_not_optimal=True, mesh_mapper=mesh_mapper
+        )
         self.conv2 = UNetConv2D(
             conv2,
             bn=bn2,
@@ -205,8 +207,6 @@ def __init__(
             mesh_mapper=mesh_mapper,
         )
         self.pool1 = UNetMaxPool2D(pool, conv2.out_channels, device=device)
-        if should_reshard:
-            self.conv1.conv_config.reshard_if_not_optimal = True
 
     def __call__(self, x):
         assert list(x.shape) == [
@@ -233,17 +233,13 @@ def __init__(
         bn3,
         device,
         conv_cache={},
-        should_reshard=False,
         mesh_mapper=None,
     ):
         self.device = device
-        self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, mesh_mapper=mesh_mapper)
+        self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, reshard_if_not_optimal=True, mesh_mapper=mesh_mapper)
         self.conv2 = UNetConv2D(conv2, bn2, device, conv_cache, mesh_mapper=mesh_mapper)
         self.conv3 = UNetConv2D(conv3, bn3, device, conv_cache, mesh_mapper=mesh_mapper)
 
-        if should_reshard:
-            self.conv1.conv_config.reshard_if_not_optimal = True
-
     def upsample(self, x):
         # Need to reshape into (B, H, W, C) to get correct output from ttnn.upsample
         x = ttnn.reshape(
@@ -308,7 +304,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
             parameters.p1,
             device,
             conv_cache=self.conv_cache,
-            should_reshard=False,
             mesh_mapper=mesh_mapper,
         )
         self.downblock2 = UNetDownblock(
@@ -319,7 +314,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
             parameters.p2,
             device,
             conv_cache=self.conv_cache,
-            should_reshard=True,
             mesh_mapper=mesh_mapper,
         )
         self.downblock3 = UNetDownblock(
@@ -330,7 +324,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
             parameters.p3,
             device,
             conv_cache=self.conv_cache,
-            should_reshard=True,
             mesh_mapper=mesh_mapper,
         )
         self.downblock4 = UNetDownblock(
@@ -341,12 +334,17 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
             parameters.p4,
             device,
             conv_cache=self.conv_cache,
-            should_reshard=True,
             mesh_mapper=mesh_mapper,
         )
 
-        self.bnc = UNetConv2D(parameters.bnc, parameters.bnb, device, cache=self.conv_cache, mesh_mapper=mesh_mapper)
-        self.bnc.conv_config.reshard_if_not_optimal = True
+        self.bnc = UNetConv2D(
+            parameters.bnc,
+            parameters.bnb,
+            device,
+            cache=self.conv_cache,
+            reshard_if_not_optimal=True,
+            mesh_mapper=mesh_mapper,
+        )
         self.bnc2 = UNetConv2D(
             parameters.bnc_2, parameters.bnb_2, device, cache=self.conv_cache, mesh_mapper=mesh_mapper
         )
@@ -360,7 +358,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
             parameters.b5_3,
             device,
             conv_cache=self.conv_cache,
-            should_reshard=True,
             mesh_mapper=mesh_mapper,
         )
         self.upblock2 = UNetUpblock(
@@ -372,7 +369,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
             parameters.b6_3,
             device,
             conv_cache=self.conv_cache,
-            should_reshard=True,
             mesh_mapper=mesh_mapper,
         )
         self.upblock3 = UNetUpblock(
@@ -384,7 +380,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
             parameters.b7_3,
             device,
             conv_cache=self.conv_cache,
-            should_reshard=True,
             mesh_mapper=mesh_mapper,
         )
         self.upblock4 = UNetUpblock(
@@ -396,7 +391,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
             parameters.b8_3,
             device,
             conv_cache=self.conv_cache,
-            should_reshard=True,
             mesh_mapper=mesh_mapper,
         )
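Taken together, the change removes the should_reshard flag (and the post-construction conv_config mutation it guarded) in favour of passing reshard_if_not_optimal directly to the first convolution of each down/up block and to the bottleneck convolution. A toy sketch of the before/after pattern, using stand-in classes rather than the real modules:

from dataclasses import dataclass


# Stand-in for the conv configuration; only the relevant field is modelled.
@dataclass
class ConvConfig:
    reshard_if_not_optimal: bool = False


class Conv2DWrapper:
    # After this commit: the flag is fixed once, when the config is built.
    def __init__(self, reshard_if_not_optimal: bool = False):
        self.conv_config = ConvConfig(reshard_if_not_optimal=reshard_if_not_optimal)


class Downblock:
    # Before: __init__ took should_reshard and, when set, flipped
    # self.conv1.conv_config.reshard_if_not_optimal afterwards.
    # After: the decision is made per convolution at the call site.
    def __init__(self):
        self.conv1 = Conv2DWrapper(reshard_if_not_optimal=True)
        self.conv2 = Conv2DWrapper()

One behavioural note that follows from the diff: downblock1 previously passed should_reshard=False, so with this change its first convolution now also requests resharding.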


0 comments on commit 4d0a28a
