Skip to content

Commit

Permalink
[mps] Massage test_full_truncation to work only on the supported dtyp…
Browse files Browse the repository at this point in the history
…es. (pytorch#144877)

Converted a first one to make sure the pattern was the one we wanted -- if we're OK with this, I'll probably adjust all the other failing ones in a batch or two. Let me know.

Pull Request resolved: pytorch#144877
Approved by: https://github.com/jansel, https://github.com/malfet
  • Loading branch information
dcci authored and pytorchmergebot committed Jan 16, 2025
1 parent 3d29de3 commit 1b34665
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/inductor/test_mps_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class MPSBasicTests(TestCase):
test_cat_empty = CommonTemplate.test_cat_empty
test_cat_unbacked_empty_1d = CommonTemplate.test_cat_unbacked_empty_1d
test_floordiv = CommonTemplate.test_floordiv
test_full_truncation = CommonTemplate.test_full_truncation
test_fmod = CommonTemplate.test_fmod
test_fmod_zero_dim = CommonTemplate.test_fmod_zero_dim
test_index_dynamic_shapes = CommonTemplate.test_index_dynamic_shapes
Expand Down
5 changes: 4 additions & 1 deletion test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6554,8 +6554,11 @@ def test_full_truncation(self):
def fn(a):
return a + torch.full_like(a, 7.777)

device_interface = get_interface_for_device(self.device)
for dtype in all_types():
self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),))
ctx = contextlib.nullcontext() if device_interface.is_dtype_supported(dtype) else self.assertRaises(TypeError)
with ctx:
self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),), check_lowp=False)

def test_full_boolean(self):
def fn(n):
Expand Down

0 comments on commit 1b34665

Please sign in to comment.