From 4cc5332967bfb9a40d544a017ea1a1804e66681e Mon Sep 17 00:00:00 2001
From: Borys Bradel
Date: Tue, 15 Oct 2024 17:19:59 +0000
Subject: [PATCH] #13204: add interleaved input sharded output matmul test

---
 .../ttnn/unit_tests/operations/test_matmul.py | 54 +++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py
index f0aad38592be..e7b6652bdc68 100644
--- a/tests/ttnn/unit_tests/operations/test_matmul.py
+++ b/tests/ttnn/unit_tests/operations/test_matmul.py
@@ -1373,3 +1373,57 @@ def test_alternating_dst_sync_mode_matmul(device, M, K, N):
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
     output_tensor = ttnn.to_torch(output3)
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
+
+
+def test_interleaved_input_sharded_output_matmul(device):
+    torch.manual_seed(0)
+    pcc = 0.99
+    # Width sharded
+    torch_input_tensor_a = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16)
+    torch_input_tensor_b = torch.randn([1, 1, 32, 256], dtype=torch.bfloat16)
+    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+
+    out_mem_config = ttnn.create_sharded_memory_config(
+        shape=(32, 256),
+        core_grid=ttnn.CoreGrid(x=1, y=8),
+        strategy=ttnn.ShardStrategy.WIDTH,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+    )
+
+    output1 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
+    output_tensor = ttnn.to_torch(output1)
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
+
+    # Block sharded
+    out_mem_config = ttnn.create_sharded_memory_config(
+        shape=(32, 256),
+        core_grid=ttnn.CoreGrid(x=1, y=8),
+        strategy=ttnn.ShardStrategy.BLOCK,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+    )
+
+    output2 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
+    output_tensor = ttnn.to_torch(output2)
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
+
+    # Height sharded
+    torch_input_tensor_a = torch.randn([1, 1, 256, 32], dtype=torch.bfloat16)
+    torch_input_tensor_b = torch.randn([1, 1, 32, 32], dtype=torch.bfloat16)
+    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+
+    out_mem_config = ttnn.create_sharded_memory_config(
+        shape=(256, 32),
+        core_grid=ttnn.CoreGrid(x=8, y=1),
+        strategy=ttnn.ShardStrategy.HEIGHT,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+    )
+
+    output3 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
+    output_tensor = ttnn.to_torch(output3)
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
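
Note on the shard shapes used above: with width sharding the last dimension of
the output is split across all cores of the grid, and with height sharding the
first dimension is, so each of the 8 cores ends up holding exactly one 32x32
tile in both cases. A minimal plain-Python sketch of that arithmetic, assuming
this conventional reading of ttnn.ShardStrategy (per_core_shard is a
hypothetical helper, not a ttnn API; the BLOCK case is omitted because its
split also depends on the grid layout):

def per_core_shard(shape, num_cores, strategy):
    # Per-core shard shape for the WIDTH/HEIGHT strategies (assumed semantics).
    height, width = shape
    if strategy == "WIDTH":
        return (height, width // num_cores)  # split the last dim across cores
    if strategy == "HEIGHT":
        return (height // num_cores, width)  # split the first dim across cores
    raise ValueError(strategy)

# Width-sharded output: (32, 256) over a 1x8 grid -> one 32x32 tile per core.
assert per_core_shard((32, 256), 8, "WIDTH") == (32, 32)
# Height-sharded output: (256, 32) over an 8x1 grid -> one 32x32 tile per core.
assert per_core_shard((256, 32), 8, "HEIGHT") == (32, 32)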
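
Note on the pcc=0.99 threshold: assert_with_pcc compares the device output to
the torch reference via the Pearson correlation coefficient. A minimal sketch
of that check in plain PyTorch, assuming the usual definition of PCC
(pcc_between is a hypothetical stand-in, not the repo's actual helper):

import torch

def pcc_between(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation over the flattened tensors; rows are the variables.
    flat = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(flat)[0, 1].item()

# A small perturbation keeps the correlation well above the 0.99 threshold.
reference = torch.randn(1, 1, 32, 256)
assert pcc_between(reference, reference + 1e-3 * torch.randn_like(reference)) > 0.99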
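
To run only the new test, standard pytest selection by node id should work,
assuming the device setup and environment used for the rest of the ttnn unit
tests:

import pytest

pytest.main(["tests/ttnn/unit_tests/operations/test_matmul.py::test_interleaved_input_sharded_output_matmul"])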