Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 19, 2024
1 parent 6084e7b commit b033acc
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def test_equality_bounded(self):
minimum, maximum + 1, torch.Size((1,)), device, dtype
)
assert ts != ts_other
if torch.has_cuda:
if torch.cuda.device_count():
ts_other = BoundedTensorSpec(
minimum, maximum, torch.Size((1,)), "cuda:0", dtype
)
Expand Down Expand Up @@ -795,7 +795,7 @@ def test_equality_onehot(self):
)
assert ts != ts_other

if torch.has_cuda:
if torch.cuda.device_count():
ts_other = OneHotDiscreteTensorSpec(
n=n, device="cuda:0", dtype=dtype, use_register=use_register
)
Expand Down Expand Up @@ -825,7 +825,7 @@ def test_equality_unbounded(self):
ts_same = UnboundedContinuousTensorSpec(device=device, dtype=dtype)
assert ts == ts_same

if torch.has_cuda:
if torch.cuda.device_count():
ts_other = UnboundedContinuousTensorSpec(device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down Expand Up @@ -860,7 +860,7 @@ def test_equality_ndbounded(self):
)
assert ts != ts_other

if torch.has_cuda:
if torch.cuda.device_count():
ts_other = BoundedTensorSpec(
low=minimum, high=maximum, device="cuda:0", dtype=dtype
)
Expand Down Expand Up @@ -890,7 +890,7 @@ def test_equality_discrete(self):
ts_other = DiscreteTensorSpec(n=n + 1, shape=shape, device=device, dtype=dtype)
assert ts != ts_other

if torch.has_cuda:
if torch.cuda.device_count():
ts_other = DiscreteTensorSpec(
n=n, shape=shape, device="cuda:0", dtype=dtype
)
Expand Down Expand Up @@ -934,7 +934,7 @@ def test_equality_ndunbounded(self, shape):
)
assert ts != ts_other

if torch.has_cuda:
if torch.cuda.device_count():
ts_other = UnboundedContinuousTensorSpec(
shape=shape, device="cuda:0", dtype=dtype
)
Expand All @@ -948,7 +948,8 @@ def test_equality_ndunbounded(self, shape):
ts_other = TestEquality._ts_make_all_fields_equal(
BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts
)
assert ts != ts_other
# Unbounded and bounded without space are technically the same
assert ts == ts_other

def test_equality_binary(self):
n = 5
Expand All @@ -963,7 +964,7 @@ def test_equality_binary(self):
ts_other = BinaryDiscreteTensorSpec(n=n + 5, device=device, dtype=dtype)
assert ts != ts_other

if torch.has_cuda:
if torch.cuda.device_count():
ts_other = BinaryDiscreteTensorSpec(n=n, device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down Expand Up @@ -1003,7 +1004,7 @@ def test_equality_multi_onehot(self, nvec):
)
assert ts != ts_other

if torch.has_cuda:
if torch.cuda.device_count():
ts_other = MultiOneHotDiscreteTensorSpec(
nvec=nvec, device="cuda:0", dtype=dtype
)
Expand Down Expand Up @@ -1041,7 +1042,7 @@ def test_equality_multi_discrete(self, nvec):
ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other

if torch.has_cuda:
if torch.cuda.device_count():
ts_other = MultiDiscreteTensorSpec(nvec=nvec, device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down

0 comments on commit b033acc

Please sign in to comment.