Skip to content

Commit

Permalink
Increase tolerance bound for FP16
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Apr 26, 2024
1 parent de5499c commit 62878a3
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/py/dynamo/conversion/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def forward(self, query, key, value):
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(SDPA(), inputs, precision=torch.float16)
self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)

@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_causal(self, query_shape, key_shape):
Expand All @@ -38,7 +38,7 @@ def forward(self, query, key, value):
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(SDPA(), inputs, precision=torch.float16)
self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)


@unittest.skipIf(
Expand Down Expand Up @@ -69,6 +69,8 @@ def forward(self, query, key, value):
self.run_test(
SDPA(),
inputs,
rtol=1e-2,
atol=1e-2,
precision=torch.float16,
enable_passes=True,
)
Expand Down Expand Up @@ -99,6 +101,8 @@ def forward(self, query, key, value):
self.run_test(
SDPA(),
inputs,
rtol=1e-2,
atol=1e-2,
precision=torch.float16,
enable_passes=True,
)
Expand Down

0 comments on commit 62878a3

Please sign in to comment.