#6991: Update ttnn unit test files for non-nan
umadevimcw committed Apr 12, 2024
1 parent fc89f66 commit a8df27d
Showing 5 changed files with 16 additions and 20 deletions.
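For context (not part of the commit): torch.acosh and torch.logit return NaN for out-of-domain inputs, and a NaN in either tensor makes the PCC comparison these tests rely on meaningless. That appears to be the "non-nan" issue the commit addresses by bounding the random inputs for acosh and logit and removing the Issue #6991 skips from the reduction tests. A minimal sketch:

import torch

# acosh is defined only for x >= 1 and logit only for x in (0, 1);
# unrestricted random inputs therefore produce NaNs in the reference output.
print(torch.acosh(torch.tensor([0.5, 2.0])))   # first element is nan
print(torch.logit(torch.tensor([-0.1, 0.5])))  # first element is nan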
3 changes: 0 additions & 3 deletions tests/ttnn/unit_tests/operations/test_mean.py
@@ -18,9 +18,6 @@
 def test_mean(device, batch_size, h, w, dim):
     torch.manual_seed(0)

-    if is_wormhole_b0() and dim == -2:
-        pytest.skip("Issue #6991: Wormhole B0: mean operation fails for dim=-2")
-
     torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16)
     torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=True, dtype=torch.bfloat16)
4 changes: 0 additions & 4 deletions tests/ttnn/unit_tests/operations/test_min.py
@@ -16,8 +16,6 @@
 @pytest.mark.parametrize("w", [32, 64])
 @pytest.mark.parametrize("dim", [-1, -2])
 def test_min(device, batch_size, h, w, dim):
-    if is_wormhole_b0() and dim == -2:
-        pytest.skip("Issue #6991: PCC mismatch for dim=-2")
     torch.manual_seed(0)

     torch_input_tensor = torch_random((batch_size, h, w), -100, 100, dtype=torch.bfloat16)
@@ -37,8 +35,6 @@ def test_min(device, batch_size, h, w, dim):
 @pytest.mark.parametrize("h", [32, 64])
 @pytest.mark.parametrize("w", [32, 64])
 def test_min_global(device, batch_size, h, w):
-    if is_wormhole_b0():
-        pytest.skip("Issue #6991: PCC mismatch")
     torch.manual_seed(0)

     torch_input_tensor = torch_random((batch_size, h, w), -100, 100, dtype=torch.bfloat16)
4 changes: 0 additions & 4 deletions tests/ttnn/unit_tests/operations/test_reduction.py
@@ -17,8 +17,6 @@
 @pytest.mark.parametrize("dim", [-1, -2])
 def test_std(device, batch_size, h, w, dim):
     torch.manual_seed(0)
-    if is_wormhole_b0() and dim == -2:
-        pytest.skip("Issue #6991: PCC mismatch for dim=-2")

     torch_input_tensor = torch.randn((batch_size, h, w), dtype=torch.bfloat16)
     torch_output_tensor = torch.std(torch_input_tensor, dim=dim, keepdim=True)
@@ -39,8 +37,6 @@ def test_std(device, batch_size, h, w, dim):
 @pytest.mark.parametrize("dim", [-1, -2])
 def test_var(device, batch_size, h, w, dim):
     torch.manual_seed(0)
-    if is_wormhole_b0() and dim == -2:
-        pytest.skip("Issue #6991: PCC mismatch for dim=-2")

     torch_input_tensor = torch.randn((batch_size, h, w), dtype=torch.bfloat16)
     torch_output_tensor = torch.var(torch_input_tensor, dim=dim, keepdim=True)
4 changes: 0 additions & 4 deletions tests/ttnn/unit_tests/operations/test_sum.py
@@ -17,8 +17,6 @@
 @pytest.mark.parametrize("dim", [-1, -2, (2, 1)])
 def test_sum(device, batch_size, h, w, dim):
     torch.manual_seed(0)
-    if is_wormhole_b0():
-        pytest.skip("Issue #6991: PCC mismatch")

     torch_input_tensor = torch_random((batch_size, h, w), -100, 100, dtype=torch.bfloat16)
     torch_output_tensor = torch.sum(torch_input_tensor, dim=dim, keepdim=True)
@@ -38,8 +36,6 @@ def test_sum(device, batch_size, h, w, dim):
 @pytest.mark.parametrize("w", [32, 64])
 def test_sum_global(device, batch_size, h, w):
     torch.manual_seed(0)
-    if is_wormhole_b0():
-        pytest.skip("Issue #6991: PCC mismatch")

     torch_input_tensor = torch_random((batch_size, h, w), -100, 100, dtype=torch.bfloat16)
     torch_output_tensor = torch.sum(torch_input_tensor)
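"PCC mismatch" in the removed skips refers to the Pearson correlation check these tests run between the torch reference and the ttnn output. A rough sketch of that metric follows; the tests themselves use the assert_with_pcc helper, not this function.

import torch

def pearson_cc(a: torch.Tensor, b: torch.Tensor) -> float:
    # Rough stand-in for the metric behind assert_with_pcc (not the real helper):
    # Pearson correlation between the flattened reference and device outputs.
    a = a.flatten().to(torch.float32)
    b = b.flatten().to(torch.float32)
    a, b = a - a.mean(), b - b.mean()
    return float((a * b).sum() / (a.norm() * b.norm()))

# The removed skips guarded cases where this value fell below the expected
# threshold (test_unary.py uses pcc=0.9999) on Wormhole B0.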
21 changes: 16 additions & 5 deletions tests/ttnn/unit_tests/operations/test_unary.py
@@ -157,9 +157,8 @@ def test_cosh(device, h, w):

 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
-@skip_for_wormhole_b0("Issue #6991: Failing on wormhole_b0 PCC issue")
 def test_acosh(device, h, w):
-    run_unary_test(device, h, w, ttnn.acosh, torch.acosh)
+    run_unary_test_with_range(device, h, w, 1, 100, ttnn.acosh, torch.acosh)


 @pytest.mark.parametrize("h", [64])
@@ -212,9 +211,21 @@ def run_unary_test_with_float(device, h, w, scalar, ttnn_function, torch_function
     assert_with_pcc(torch_output_tensor, output_tensor, pcc)


-@pytest.mark.parametrize("scalar", [1, 2])
+def run_unary_test_with_range(device, h, w, scalar, low, high, ttnn_function, torch_function, pcc=0.9999):
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16) * (high - low) + low
+    torch_output_tensor = torch_function(torch_input_tensor, scalar)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn_function(input_tensor, scalar)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor, pcc)
+
+
+@pytest.mark.parametrize("scalar", [1, 1e-6])
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
-@skip_for_wormhole_b0("Issue #6991: Failing on wormhole_b0 PCC issue")
 def test_logit(device, h, w, scalar):
-    run_unary_test_with_float(device, h, w, scalar, ttnn.logit, torch.logit)
+    run_unary_test_with_range(device, h, w, scalar, ttnn.logit, torch.logit)
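Not part of the diff: a small illustration of the range trick the new run_unary_test_with_range helper uses. torch.rand samples from [0, 1), so scaling by (high - low) and shifting by low maps every sample into [low, high) (up to bfloat16 rounding); the acosh test passes low=1 and high=100, which keeps all inputs inside acosh's domain and the outputs NaN-free.

import torch

# Map uniform [0, 1) samples into [low, high); with low=1, high=100 every value
# lies in acosh's domain, so the reference output contains no NaNs.
low, high = 1.0, 100.0
x = torch.rand((64, 128), dtype=torch.bfloat16) * (high - low) + low
print(torch.acosh(x.float()).isnan().any().item())  # False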
