diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py index cf684164a6..41775a4fcc 100644 --- a/tests/py/dynamo/conversion/test_attention.py +++ b/tests/py/dynamo/conversion/test_attention.py @@ -23,7 +23,7 @@ def forward(self, query, key, value): key = torch.rand(key_shape, dtype=torch.float16) value = torch.rand(key_shape, dtype=torch.float16) inputs.extend([query, key, value]) - self.run_test(SDPA(), inputs, precision=torch.float16) + self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16) @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) def test_sdpa_causal(self, query_shape, key_shape): @@ -38,7 +38,7 @@ def forward(self, query, key, value): key = torch.rand(key_shape, dtype=torch.float16) value = torch.rand(key_shape, dtype=torch.float16) inputs.extend([query, key, value]) - self.run_test(SDPA(), inputs, precision=torch.float16) + self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16) @unittest.skipIf( @@ -69,6 +69,8 @@ def forward(self, query, key, value): self.run_test( SDPA(), inputs, + rtol=1e-2, + atol=1e-2, precision=torch.float16, enable_passes=True, ) @@ -99,6 +101,8 @@ def forward(self, query, key, value): self.run_test( SDPA(), inputs, + rtol=1e-2, + atol=1e-2, precision=torch.float16, enable_passes=True, )