Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 19, 2024
1 parent 2e7b886 commit 6084e7b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
19 changes: 9 additions & 10 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,9 @@ def __repr__(self):
def __eq__(self, other):
if other is None:

minval, maxval = _minmax_dtype(self.low.dtype, self.low.device)
minval, maxval = _minmax_dtype(self.low.dtype)
minval = torch.as_tensor(minval).to(self.low.device, self.low.dtype)
maxval = torch.as_tensor(maxval).to(self.low.device, self.low.dtype)
if (
torch.isclose(self.low, minval).all()
and torch.isclose(self.high, maxval).all()
Expand Down Expand Up @@ -1905,7 +1907,9 @@ def __eq__(self, other):
== other
)
if isinstance(other, BoundedTensorSpec):
minval, maxval = _minmax_dtype(self.dtype, self.device)
minval, maxval = _minmax_dtype(self.dtype)
minval = torch.as_tensor(minval).to(self.device, self.dtype)
maxval = torch.as_tensor(maxval).to(self.device, self.dtype)
return (
BoundedTensorSpec(
shape=self.shape,
Expand Down Expand Up @@ -4352,16 +4356,11 @@ def __contains__(self, item):
return False


def _minmax_dtype(dtype, device=None):
def _minmax_dtype(dtype):
if dtype is torch.bool:
return torch.tensor(False, device=device), torch.tensor(True, device=device)
return False, True
if dtype.is_floating_point:
info = torch.finfo(dtype)
else:
info = torch.iinfo(dtype)
if device is None:
return torch.as_tensor(info.min), torch.as_tensor(info.max)
else:
return torch.as_tensor(info.min).to(device), torch.as_tensor(info.max).to(
device
)
return info.min, info.max
8 changes: 5 additions & 3 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,9 @@ def _gym_to_torchrl_spec_transform(
high = torch.tensor(spec.high, device=device, dtype=dtype)
is_unbounded = low.isinf().all() and high.isinf().all()

minval, maxval = _minmax_dtype(dtype, device)
minval, maxval = _minmax_dtype(dtype)
minval = torch.as_tensor(minval).to(low.device, dtype)
maxval = torch.as_tensor(maxval).to(low.device, dtype)
is_unbounded = is_unbounded or (
torch.isclose(low, torch.tensor(minval, dtype=dtype)).all()
and torch.isclose(high, torch.tensor(maxval, dtype=dtype)).all()
Expand Down Expand Up @@ -428,15 +430,15 @@ def _torchrl_to_gym_spec_transform(
if isinstance(spec, OneHotDiscreteTensorSpec):
return gym_spaces.discrete.Discrete(spec.n)
if isinstance(spec, UnboundedContinuousTensorSpec):
minval, maxval = _minmax_dtype(spec.dtype, spec.device)
minval, maxval = _minmax_dtype(spec.dtype)
return gym_spaces.Box(
low=minval,
high=maxval,
shape=shape,
dtype=torch_to_numpy_dtype_dict[spec.dtype],
)
if isinstance(spec, UnboundedDiscreteTensorSpec):
minval, maxval = _minmax_dtype(spec.dtype, spec.device)
minval, maxval = _minmax_dtype(spec.dtype)
return gym_spaces.Box(
low=minval,
high=maxval,
Expand Down

0 comments on commit 6084e7b

Please sign in to comment.