Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 8, 2024
1 parent d6a75ee commit 323f4d7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
5 changes: 5 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3823,6 +3823,7 @@ def test_discrete(self):
spec.enumerate()
== torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
).all()
assert spec.is_in(spec.enumerate())

def test_one_hot(self):
spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5))
Expand All @@ -3839,15 +3840,18 @@ def test_one_hot(self):
dtype=torch.bool,
)
).all()
assert spec.is_in(spec.enumerate())

def test_multi_discrete(self):
spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))
enum = spec.enumerate()
assert spec.is_in(enum)
assert enum.shape == torch.Size([60, 2, 3])

def test_multi_onehot(self):
spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12))
enum = spec.enumerate()
assert spec.is_in(enum)
assert enum.shape == torch.Size([60, 2, 12])

def test_composite(self):
Expand All @@ -3859,6 +3863,7 @@ def test_composite(self):
shape=[3],
)
c_enum = c.enumerate()
assert c.is_in(c_enum)
assert c_enum.shape == torch.Size((20, 3))
assert c_enum["b"].shape == torch.Size((20, 3))

Expand Down
22 changes: 11 additions & 11 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
return self.is_in(item)

@abc.abstractmethod
def enumerate(self):
def enumerate(self) -> Any:
"""Returns all the samples that can be obtained from the TensorSpec.
The samples will be stacked along the first dimension.
Expand Down Expand Up @@ -1281,7 +1281,7 @@ def __eq__(self, other):
return False
return True

def enumerate(self):
def enumerate(self) -> torch.Tensor | TensorDictBase:
return torch.stack(
[spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
)
Expand Down Expand Up @@ -1747,7 +1747,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
return np.array(vals).reshape(tuple(val.shape))
return val

def enumerate(self):
def enumerate(self) -> torch.Tensor:
return (
torch.eye(self.n, dtype=self.dtype, device=self.device)
.expand(*self.shape, self.n)
Expand Down Expand Up @@ -2078,7 +2078,7 @@ def __init__(
domain=domain,
)

def enumerate(self):
def enumerate(self) -> Any:
raise NotImplementedError(
f"enumerate is not implemented for spec of class {type(self).__name__}."
)
Expand Down Expand Up @@ -2402,7 +2402,7 @@ def __init__(
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
)

def enumerate(self):
def enumerate(self) -> Any:
raise NotImplementedError("Cannot enumerate a NonTensorSpec.")

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
Expand Down Expand Up @@ -2641,7 +2641,7 @@ def is_in(self, val: torch.Tensor) -> bool:
def _project(self, val: torch.Tensor) -> torch.Tensor:
return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)

def enumerate(self):
def enumerate(self) -> Any:
raise NotImplementedError("enumerate cannot be called with continuous specs.")

def expand(self, *shape):
Expand Down Expand Up @@ -2808,7 +2808,7 @@ def __init__(
)
self.update_mask(mask)

def enumerate(self):
def enumerate(self) -> torch.Tensor:
nvec = self.nvec
enum_disc = self.to_categorical_spec().enumerate()
enums = torch.cat(
Expand Down Expand Up @@ -3253,7 +3253,7 @@ def __init__(
)
self.update_mask(mask)

def enumerate(self):
def enumerate(self) -> torch.Tensor:
arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
if self.ndim:
arange = arange.view(-1, *(1,) * self.ndim)
Expand Down Expand Up @@ -3766,7 +3766,7 @@ def __init__(
self.update_mask(mask)
self.remove_singleton = remove_singleton

def enumerate(self):
def enumerate(self) -> torch.Tensor:
if self.mask is not None:
raise RuntimeError(
"Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
Expand Down Expand Up @@ -4682,7 +4682,7 @@ def clone(self) -> Composite:
shape=self.shape,
)

def enumerate(self):
def enumerate(self) -> TensorDictBase:
# We are going to use meshgrid to create samples of all the subspecs in here
# but first let's get rid of the batch size, we'll put it back later
self_without_batch = self
Expand Down Expand Up @@ -4959,7 +4959,7 @@ def update(self, dict) -> None:
self[key] = item
return self

def enumerate(self):
def enumerate(self) -> TensorDictBase:
dim = self.stack_dim
return LazyStackedTensorDict.maybe_dense_stack(
[spec.enumerate() for spec in self._specs], dim + 1
Expand Down

0 comments on commit 323f4d7

Please sign in to comment.