From b033acc01b4ad590918a96e1100d56828ba14cce Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 19 Jan 2024 10:21:54 +0000 Subject: [PATCH] amend --- test/test_specs.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 2b4c8959f4f..cc97be11918 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -759,7 +759,7 @@ def test_equality_bounded(self): minimum, maximum + 1, torch.Size((1,)), device, dtype ) assert ts != ts_other - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = BoundedTensorSpec( minimum, maximum, torch.Size((1,)), "cuda:0", dtype ) @@ -795,7 +795,7 @@ def test_equality_onehot(self): ) assert ts != ts_other - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = OneHotDiscreteTensorSpec( n=n, device="cuda:0", dtype=dtype, use_register=use_register ) @@ -825,7 +825,7 @@ def test_equality_unbounded(self): ts_same = UnboundedContinuousTensorSpec(device=device, dtype=dtype) assert ts == ts_same - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = UnboundedContinuousTensorSpec(device="cuda:0", dtype=dtype) assert ts != ts_other @@ -860,7 +860,7 @@ def test_equality_ndbounded(self): ) assert ts != ts_other - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = BoundedTensorSpec( low=minimum, high=maximum, device="cuda:0", dtype=dtype ) @@ -890,7 +890,7 @@ def test_equality_discrete(self): ts_other = DiscreteTensorSpec(n=n + 1, shape=shape, device=device, dtype=dtype) assert ts != ts_other - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = DiscreteTensorSpec( n=n, shape=shape, device="cuda:0", dtype=dtype ) @@ -934,7 +934,7 @@ def test_equality_ndunbounded(self, shape): ) assert ts != ts_other - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = UnboundedContinuousTensorSpec( shape=shape, device="cuda:0", dtype=dtype ) @@ -948,7 +948,8 @@ def test_equality_ndunbounded(self, shape): ts_other = TestEquality._ts_make_all_fields_equal( BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts ) - assert ts != ts_other + # Unbounded and bounded without space are technically the same + assert ts == ts_other def test_equality_binary(self): n = 5 @@ -963,7 +964,7 @@ def test_equality_binary(self): ts_other = BinaryDiscreteTensorSpec(n=n + 5, device=device, dtype=dtype) assert ts != ts_other - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = BinaryDiscreteTensorSpec(n=n, device="cuda:0", dtype=dtype) assert ts != ts_other @@ -1003,7 +1004,7 @@ def test_equality_multi_onehot(self, nvec): ) assert ts != ts_other - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = MultiOneHotDiscreteTensorSpec( nvec=nvec, device="cuda:0", dtype=dtype ) @@ -1041,7 +1042,7 @@ def test_equality_multi_discrete(self, nvec): ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other - if torch.has_cuda: + if torch.cuda.device_count(): ts_other = MultiDiscreteTensorSpec(nvec=nvec, device="cuda:0", dtype=dtype) assert ts != ts_other