#13204: add interleaved input sharded output matmul test
bbradelTT committed Oct 15, 2024
1 parent e97f09e commit 4cc5332
Showing 1 changed file with 54 additions and 0 deletions.
tests/ttnn/unit_tests/operations/test_matmul.py

@@ -1373,3 +1373,57 @@ def test_alternating_dst_sync_mode_matmul(device, M, K, N):
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
    output_tensor = ttnn.to_torch(output3)
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)


def test_interleaved_input_sharded_output_matmul(device):
    torch.manual_seed(0)
    pcc = 0.99
    # Width-sharded output
    torch_input_tensor_a = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16)
    torch_input_tensor_b = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16)
    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)

    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

    out_mem_config = ttnn.create_sharded_memory_config(
        shape=(32, 256),
        core_grid=ttnn.CoreGrid(x=1, y=8),
        strategy=ttnn.ShardStrategy.WIDTH,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
    )
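    # WIDTH sharding splits the (32, 256) output along its width into
    # eight 32x32 shards, one per core of the 1x8 grid.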

    output1 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
    output_tensor = ttnn.to_torch(output1)
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)

    # Block-sharded output
    out_mem_config = ttnn.create_sharded_memory_config(
        shape=(32, 256),
        core_grid=ttnn.CoreGrid(x=1, y=8),
        strategy=ttnn.ShardStrategy.BLOCK,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
    )
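    # BLOCK sharding partitions the output across both grid dimensions;
    # here the same (32, 256) output is redistributed on the 1x8 grid.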

    output2 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
    output_tensor = ttnn.to_torch(output2)
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)

    # Height-sharded output
    torch_input_tensor_a = torch.randn([1, 1, 256, 32], dtype=torch.bfloat16)
    torch_input_tensor_b = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16)
    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)

    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

    out_mem_config = ttnn.create_sharded_memory_config(
        shape=(256, 32),
        core_grid=ttnn.CoreGrid(x=8, y=1),
        strategy=ttnn.ShardStrategy.HEIGHT,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
    )
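    # HEIGHT sharding splits the (256, 32) output along its height into
    # eight 32x32 shards, one per core of the 8x1 grid.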

    output3 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
    output_tensor = ttnn.to_torch(output3)
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
