diff --git a/test/test_specs.py b/test/test_specs.py
index ea3cbfe069d..39b09798ac2 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -3816,6 +3816,58 @@ def test_non_tensor(self):
         assert not isinstance(non_tensor, MultiOneHot)
 
 
+class TestSpecEnumerate:
+    def test_discrete(self):
+        spec = DiscreteTensorSpec(n=5, shape=(3,))
+        assert (
+            spec.enumerate()
+            == torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
+        ).all()
+        assert spec.is_in(spec.enumerate())
+
+    def test_one_hot(self):
+        spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5))
+        assert (
+            spec.enumerate()
+            == torch.tensor(
+                [
+                    [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0]],
+                    [[0, 1, 0, 0, 0], [0, 1, 0, 0, 0]],
+                    [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
+                    [[0, 0, 0, 1, 0], [0, 0, 0, 1, 0]],
+                    [[0, 0, 0, 0, 1], [0, 0, 0, 0, 1]],
+                ],
+                dtype=torch.bool,
+            )
+        ).all()
+        assert spec.is_in(spec.enumerate())
+
+    def test_multi_discrete(self):
+        spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))
+        enum = spec.enumerate()
+        assert spec.is_in(enum)
+        assert enum.shape == torch.Size([60, 2, 3])
+
+    def test_multi_onehot(self):
+        spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12))
+        enum = spec.enumerate()
+        assert spec.is_in(enum)
+        assert enum.shape == torch.Size([60, 2, 12])
+
+    def test_composite(self):
+        c = CompositeSpec(
+            {
+                "a": OneHotDiscreteTensorSpec(n=5, shape=(3, 5)),
+                ("b", "c"): DiscreteTensorSpec(n=4, shape=(3,)),
+            },
+            shape=[3],
+        )
+        c_enum = c.enumerate()
+        assert c.is_in(c_enum)
+        assert c_enum.shape == torch.Size((20, 3))
+        assert c_enum["b"].shape == torch.Size((20, 3))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 98a32de5715..3590d76d2ce 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -834,6 +844,16 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
         """
         return self.is_in(item)
 
+    @abc.abstractmethod
+    def enumerate(self) -> Any:
+        """Returns all the samples that can be obtained from the TensorSpec.
+
+        The samples will be stacked along the first dimension.
+
+        This method is only implemented for discrete specs.
+        """
+        ...
+
     def project(
         self, val: torch.Tensor | TensorDictBase
     ) -> torch.Tensor | TensorDictBase:
@@ -1271,6 +1281,11 @@ def __eq__(self, other):
                 return False
         return True
 
+    def enumerate(self) -> torch.Tensor | TensorDictBase:
+        return torch.stack(
+            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+        )
+
     def __len__(self):
         return self.shape[0]
 
@@ -1732,6 +1747,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val
 
+    def enumerate(self) -> torch.Tensor:
+        return (
+            torch.eye(self.n, dtype=self.dtype, device=self.device)
+            .expand(*self.shape, self.n)
+            .permute(-2, *range(self.ndimension() - 1), -1)
+        )
+
     def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
         if not isinstance(index, torch.Tensor):
             raise ValueError(
@@ -2056,6 +2078,11 @@ def __init__(
             domain=domain,
         )
 
+    def enumerate(self) -> Any:
+        raise NotImplementedError(
+            f"enumerate is not implemented for spec of class {type(self).__name__}."
+        )
+
     def __eq__(self, other):
         return (
             type(other) == type(self)
@@ -2375,6 +2402,9 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def enumerate(self) -> Any:
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -2611,6 +2641,9 @@ def is_in(self, val: torch.Tensor) -> bool:
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)
 
+    def enumerate(self) -> Any:
+        raise NotImplementedError("enumerate cannot be called with continuous specs.")
+
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
             shape = shape[0]
@@ -2775,6 +2808,18 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self) -> torch.Tensor:
+        nvec = self.nvec
+        enum_disc = self.to_categorical_spec().enumerate()
+        enums = torch.cat(
+            [
+                torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
+                for nv, enum_unb in zip(nvec, enum_disc.unbind(-1))
+            ],
+            -1,
+        )
+        return enums
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
@@ -3208,6 +3253,12 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self) -> torch.Tensor:
+        arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+        if self.ndim:
+            arange = arange.view(-1, *(1,) * self.ndim)
+        return arange.expand(self.n, *self.shape)
+
     @property
     def n(self):
         return self.space.n
@@ -3715,6 +3766,29 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton
 
+    def enumerate(self) -> torch.Tensor:
+        if self.mask is not None:
+            raise RuntimeError(
+                "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
+            )
+        if self.nvec._base.ndim == 1:
+            nvec = self.nvec._base
+        else:
+            # we have to use unique() to isolate the nvec
+            nvec = self.nvec.view(-1, self.nvec.shape[-1]).unique(dim=0).squeeze(0)
+            if nvec.ndim > 1:
+                raise ValueError(
+                    f"Cannot call enumerate on heterogeneous nvecs: unique nvecs={nvec}."
+                )
+        arange = torch.meshgrid(
+            *[torch.arange(n, device=self.device, dtype=self.dtype) for n in nvec],
+            indexing="ij",
+        )
+        arange = torch.stack([arange_.reshape(-1) for arange_ in arange], dim=-1)
+        arange = arange.view(arange.shape[0], *(1,) * (self.ndim - 1), self.shape[-1])
+        arange = arange.expand(arange.shape[0], *self.shape)
+        return arange
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
@@ -3932,6 +4006,8 @@ def to_one_hot(
 
     def to_one_hot_spec(self) -> MultiOneHot:
         """Converts the spec to the equivalent one-hot spec."""
+        if self.ndim > 1:
+            return torch.stack([spec.to_one_hot_spec() for spec in self.unbind(0)])
         nvec = [_space.n for _space in self.space]
         return MultiOneHot(
             nvec,
@@ -4606,6 +4682,33 @@ def clone(self) -> Composite:
             shape=self.shape,
         )
 
+    def enumerate(self) -> TensorDictBase:
+        # We are going to use meshgrid to create samples of all the subspecs in here
+        # but first let's get rid of the batch size, we'll put it back later
+        self_without_batch = self
+        while self_without_batch.ndim:
+            self_without_batch = self_without_batch[0]
+        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        if samples:
+            idx_rep = torch.meshgrid(
+                *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
+            )
+            idx_rep = tuple(idx.reshape(-1) for idx in idx_rep)
+            samples = {
+                key: sample[idx]
+                for ((key, sample), idx) in zip(samples.items(), idx_rep)
+            }
+            samples = TensorDict(
+                samples, batch_size=idx_rep[0].shape[:1], device=self.device
+            )
+            # Expand
+            if self.ndim:
+                samples = samples.reshape(-1, *(1,) * self.ndim)
+                samples = samples.expand(samples.shape[0], *self.shape)
+        else:
+            samples = TensorDict(batch_size=self.shape, device=self.device)
+        return samples
+
     def empty(self):
         """Create a spec like self, but with no entries."""
         try:
@@ -4856,6 +4959,12 @@ def update(self, dict) -> None:
             self[key] = item
         return self
 
+    def enumerate(self) -> TensorDictBase:
+        dim = self.stack_dim
+        return LazyStackedTensorDict.maybe_dense_stack(
+            [spec.enumerate() for spec in self._specs], dim + 1
+        )
+
     def __eq__(self, other):
         if not isinstance(other, StackedComposite):
             return False
@@ -5150,7 +5259,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
 
 
 @TensorSpec.implements_for_spec(torch.stack)
-def _stack_specs(list_of_spec, dim, out=None):
+def _stack_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
@@ -5187,7 +5296,7 @@ def _stack_specs(list_of_spec, dim=0, out=None):
 
 
 @Composite.implements_for_spec(torch.stack)
-def _stack_composite_specs(list_of_spec, dim, out=None):
+def _stack_composite_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
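
As a quick local sanity check, the snippet below (illustrative only, not part of the patch) exercises the new API the same way test_discrete does; it assumes a torchrl build with this diff applied and reuses the spec class names from the test file.

    # Hypothetical check mirroring TestSpecEnumerate.test_discrete above.
    import torch
    from torchrl.data import DiscreteTensorSpec

    spec = DiscreteTensorSpec(n=3, shape=(2,))
    enum = spec.enumerate()  # one stacked sample per outcome, shape (3, 2)
    # e.g. tensor([[0, 0],
    #              [1, 1],
    #              [2, 2]])
    assert spec.is_in(enum)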