From 9c2e869331e8e65f760579b6bf7b179ded0745a3 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sun, 4 Aug 2024 16:29:45 -0400
Subject: [PATCH 1/2] Update

[ghstack-poisoned]
---
 test/test_specs.py           |  47 ++++++++++++++
 torchrl/data/tensor_specs.py | 120 +++++++++++++++++++++++++++++++++--
 2 files changed, 161 insertions(+), 6 deletions(-)

diff --git a/test/test_specs.py b/test/test_specs.py
index 2d597d770f0..2a47d2680b9 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -3740,6 +3740,53 @@ def test_device_ordinal():
     assert spec.device == torch.device("cuda:0")


+class TestSpecEnumerate:
+    def test_discrete(self):
+        spec = DiscreteTensorSpec(n=5, shape=(3,))
+        assert (
+            spec.enumerate()
+            == torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
+        ).all()
+
+    def test_one_hot(self):
+        spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5))
+        assert (
+            spec.enumerate()
+            == torch.tensor(
+                [
+                    [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0]],
+                    [[0, 1, 0, 0, 0], [0, 1, 0, 0, 0]],
+                    [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
+                    [[0, 0, 0, 1, 0], [0, 0, 0, 1, 0]],
+                    [[0, 0, 0, 0, 1], [0, 0, 0, 0, 1]],
+                ],
+                dtype=torch.bool,
+            )
+        ).all()
+
+    def test_multi_discrete(self):
+        spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))
+        enum = spec.enumerate()
+        assert enum.shape == torch.Size([60, 2, 3])
+
+    def test_multi_onehot(self):
+        spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12))
+        enum = spec.enumerate()
+        assert enum.shape == torch.Size([60, 2, 12])
+
+    def test_composite(self):
+        c = CompositeSpec(
+            {
+                "a": OneHotDiscreteTensorSpec(n=5, shape=(3, 5)),
+                ("b", "c"): DiscreteTensorSpec(n=4, shape=(3,)),
+            },
+            shape=[3],
+        )
+        c_enum = c.enumerate()
+        assert c_enum.shape == torch.Size((20, 3))
+        assert c_enum["b"].shape == torch.Size((20, 3))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 7c787b3ccfc..2afd6b1f3d6 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -756,6 +756,16 @@ def contains(self, item):
         """
         return self.is_in(item)

+    @abc.abstractmethod
+    def enumerate(self):
+        """Returns all the samples that can be obtained from the TensorSpec.
+
+        The samples will be stacked along the first dimension.
+
+        This method is only implemented for discrete specs.
+        """
+        ...
+
     def project(self, val: torch.Tensor) -> torch.Tensor:
         """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic.

@@ -1152,6 +1162,11 @@ def __eq__(self, other):
             return False
         return True

+    def enumerate(self):
+        return torch.stack(
+            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+        )
+
     def __len__(self):
         return self.shape[0]

@@ -1601,6 +1616,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val

+    def enumerate(self):
+        return (
+            torch.eye(self.n, dtype=self.dtype, device=self.device)
+            .expand(*self.shape, self.n)
+            .permute(-2, *range(self.ndimension() - 1), -1)
+        )
+
     def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
         if not isinstance(index, torch.Tensor):
             raise ValueError(
@@ -1832,6 +1854,11 @@ def __init__(
             domain=domain,
         )

+    def enumerate(self):
+        raise NotImplementedError(
+            f"enumerate is not implemented for spec of class {type(self).__name__}."
+        )
+
     def __eq__(self, other):
         return (
             type(other) == type(self)
@@ -2107,6 +2134,9 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )

+    def enumerate(self):
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -2273,6 +2303,9 @@ def is_in(self, val: torch.Tensor) -> bool:
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)

+    def enumerate(self):
+        raise NotImplementedError("enumerate cannot be called with continuous specs.")
+
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
             shape = shape[0]
@@ -2361,8 +2394,6 @@ class UnboundedDiscreteTensorSpec(TensorSpec):
         (should be an integer dtype such as long, uint8 etc.)
     """

-    # SPEC_HANDLED_FUNCTIONS = {}
-
     def __init__(
         self,
         shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
@@ -2409,6 +2440,9 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
             return self
         return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype)

+    def enumerate(self):
+        raise NotImplementedError("Cannot enumerate an unbounded tensor spec.")
+
     def clone(self) -> UnboundedDiscreteTensorSpec:
         return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)

@@ -2553,8 +2587,6 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):


     """

-    # SPEC_HANDLED_FUNCTIONS = {}
-
     def __init__(
         self,
         nvec: Sequence[int],
@@ -2586,6 +2618,18 @@ def __init__(
         )
         self.update_mask(mask)

+    def enumerate(self):
+        nvec = self.nvec
+        enum_disc = self.to_categorical_spec().enumerate()
+        enums = torch.cat(
+            [
+                torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
+                for nv, enum_unb in zip(nvec, enum_disc.unbind(-1))
+            ],
+            -1,
+        )
+        return enums
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.

@@ -2975,6 +3019,12 @@ def __init__(
         )
         self.update_mask(mask)

+    def enumerate(self):
+        arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+        if self.ndim:
+            arange = arange.view(-1, *(1,) * self.ndim)
+        return arange.expand(self.n, *self.shape)
+
     @property
     def n(self):
         return self.space.n
@@ -3428,6 +3478,29 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton

+    def enumerate(self):
+        if self.mask is not None:
+            raise RuntimeError(
+                "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
+            )
+        if self.nvec._base.ndim == 1:
+            nvec = self.nvec._base
+        else:
+            # we have to use unique() to isolate the nvec
+            nvec = self.nvec.view(-1, self.nvec.shape[-1]).unique(dim=0).squeeze(0)
+            if nvec.ndim > 1:
+                raise ValueError(
+                    f"Cannot call enumerate on heterogeneous nvecs: unique nvecs={nvec}."
+                )
+        arange = torch.meshgrid(
+            *[torch.arange(n, device=self.device, dtype=self.dtype) for n in nvec],
+            indexing="ij",
+        )
+        arange = torch.stack([arange_.reshape(-1) for arange_ in arange], dim=-1)
+        arange = arange.view(arange.shape[0], *(1,) * (self.ndim - 1), self.shape[-1])
+        arange = arange.expand(arange.shape[0], *self.shape)
+        return arange
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.

@@ -3646,6 +3719,8 @@ def to_one_hot(

     def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec:
         """Converts the spec to the equivalent one-hot spec."""
+        if self.ndim > 1:
+            return torch.stack([spec.to_one_hot_spec() for spec in self.unbind(0)])
         nvec = [_space.n for _space in self.space]
         return MultiOneHotDiscreteTensorSpec(
             nvec,
@@ -4297,6 +4372,33 @@ def clone(self) -> CompositeSpec:
             shape=self.shape,
         )

+    def enumerate(self):
+        # We are going to use meshgrid to create samples of all the subspecs in here
+        # but first let's get rid of the batch size, we'll put it back later
+        self_without_batch = self
+        while self_without_batch.ndim:
+            self_without_batch = self_without_batch[0]
+        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        if samples:
+            idx_rep = torch.meshgrid(
+                *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
+            )
+            idx_rep = tuple(idx.reshape(-1) for idx in idx_rep)
+            samples = {
+                key: sample[idx]
+                for ((key, sample), idx) in zip(samples.items(), idx_rep)
+            }
+            samples = TensorDict(
+                samples, batch_size=idx_rep[0].shape[:1], device=self.device
+            )
+            # Expand
+            if self.ndim:
+                samples = samples.reshape(-1, *(1,) * self.ndim)
+                samples = samples.expand(samples.shape[0], *self.shape)
+        else:
+            samples = TensorDict(batch_size=self.shape, device=self.device)
+        return samples
+
     def empty(self):
         """Create a spec like self, but with no entries."""
         try:
@@ -4547,6 +4649,12 @@ def update(self, dict) -> None:
             self[key] = item
         return self

+    def enumerate(self):
+        dim = self.stack_dim
+        return LazyStackedTensorDict.maybe_dense_stack(
+            [spec.enumerate() for spec in self._specs], dim + 1
+        )
+
     def __eq__(self, other):
         if not isinstance(other, LazyStackedCompositeSpec):
             return False
@@ -4842,7 +4950,7 @@ def rand(self, shape=None) -> TensorDictBase:

 # for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]:
 @TensorSpec.implements_for_spec(torch.stack)
-def _stack_specs(list_of_spec, dim, out=None):
+def _stack_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
@@ -4879,7 +4987,7 @@ def _stack_specs(list_of_spec, dim, out=None):


 @CompositeSpec.implements_for_spec(torch.stack)
-def _stack_composite_specs(list_of_spec, dim, out=None):
+def _stack_composite_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "

From 323f4d789ebb9f7bf713a0952c97aae08e7f0ecb Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 8 Nov 2024 14:40:56 +0000
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 test/test_specs.py           |  5 +++++
 torchrl/data/tensor_specs.py | 22 +++++++++++-----------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/test/test_specs.py b/test/test_specs.py
index 1a7dd41621e..39b09798ac2 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -3823,6 +3823,7 @@ def test_discrete(self):
             spec.enumerate()
             == torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
         ).all()
+        assert spec.is_in(spec.enumerate())

     def test_one_hot(self):
         spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5))
@@ -3839,15 +3840,18 @@ def test_one_hot(self):
                 dtype=torch.bool,
             )
         ).all()
+        assert spec.is_in(spec.enumerate())

     def test_multi_discrete(self):
         spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))
         enum = spec.enumerate()
+        assert spec.is_in(enum)
         assert enum.shape == torch.Size([60, 2, 3])

     def test_multi_onehot(self):
         spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12))
         enum = spec.enumerate()
+        assert spec.is_in(enum)
         assert enum.shape == torch.Size([60, 2, 12])

     def test_composite(self):
@@ -3859,6 +3863,7 @@ def test_composite(self):
             shape=[3],
         )
         c_enum = c.enumerate()
+        assert c.is_in(c_enum)
         assert c_enum.shape == torch.Size((20, 3))
         assert c_enum["b"].shape == torch.Size((20, 3))

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index b641b808cf3..3590d76d2ce 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -835,7 +835,7 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
         return self.is_in(item)

     @abc.abstractmethod
-    def enumerate(self):
+    def enumerate(self) -> Any:
         """Returns all the samples that can be obtained from the TensorSpec.

         The samples will be stacked along the first dimension.
@@ -1281,7 +1281,7 @@ def __eq__(self, other):
             return False
         return True

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor | TensorDictBase:
         return torch.stack(
             [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
         )
@@ -1747,7 +1747,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor:
         return (
             torch.eye(self.n, dtype=self.dtype, device=self.device)
             .expand(*self.shape, self.n)
@@ -2078,7 +2078,7 @@ def __init__(
             domain=domain,
         )

-    def enumerate(self):
+    def enumerate(self) -> Any:
         raise NotImplementedError(
             f"enumerate is not implemented for spec of class {type(self).__name__}."
         )
@@ -2402,7 +2402,7 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )

-    def enumerate(self):
+    def enumerate(self) -> Any:
         raise NotImplementedError("Cannot enumerate a NonTensorSpec.")

     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
@@ -2641,7 +2641,7 @@ def is_in(self, val: torch.Tensor) -> bool:
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)

-    def enumerate(self):
+    def enumerate(self) -> Any:
         raise NotImplementedError("enumerate cannot be called with continuous specs.")

     def expand(self, *shape):
@@ -2808,7 +2808,7 @@ def __init__(
         )
         self.update_mask(mask)

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor:
         nvec = self.nvec
         enum_disc = self.to_categorical_spec().enumerate()
         enums = torch.cat(
@@ -3253,7 +3253,7 @@ def __init__(
         )
         self.update_mask(mask)

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor:
         arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
         if self.ndim:
             arange = arange.view(-1, *(1,) * self.ndim)
@@ -3766,7 +3766,7 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor:
         if self.mask is not None:
             raise RuntimeError(
                 "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
@@ -4682,7 +4682,7 @@ def clone(self) -> Composite:
             shape=self.shape,
         )

-    def enumerate(self):
+    def enumerate(self) -> TensorDictBase:
         # We are going to use meshgrid to create samples of all the subspecs in here
         # but first let's get rid of the batch size, we'll put it back later
         self_without_batch = self
@@ -4959,7 +4959,7 @@ def update(self, dict) -> None:
             self[key] = item
         return self

-    def enumerate(self):
+    def enumerate(self) -> TensorDictBase:
         dim = self.stack_dim
         return LazyStackedTensorDict.maybe_dense_stack(
             [spec.enumerate() for spec in self._specs], dim + 1
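
Usage sketch for the leaf-spec enumerate() methods added above, mirroring the shapes asserted in TestSpecEnumerate. It assumes the patch is applied and that these spec classes are importable from torchrl.data, as they are in test_specs.py:

    from torchrl.data import (
        DiscreteTensorSpec,
        MultiDiscreteTensorSpec,
        OneHotDiscreteTensorSpec,
    )

    # Categorical spec: one entry per admissible value, stacked along a new
    # leading dimension and expanded to the spec's shape (3,).
    spec = DiscreteTensorSpec(n=5, shape=(3,))
    values = spec.enumerate()
    print(values.shape)  # torch.Size([5, 3])
    assert spec.is_in(values)

    # One-hot spec: the n possible one-hot vectors, expanded to the batch shape (2, 5).
    one_hot = OneHotDiscreteTensorSpec(n=5, shape=(2, 5))
    print(one_hot.enumerate().shape)  # torch.Size([5, 2, 5])

    # Multi-categorical spec: the cartesian product of the nvec entries (3 * 4 * 5 = 60).
    multi = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))
    print(multi.enumerate().shape)  # torch.Size([60, 2, 3])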
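
Similarly, a sketch of CompositeSpec.enumerate() under the same assumptions, reproducing the test_composite case: the result is a TensorDict whose leading dimension indexes every combination of the discrete leaves (5 one-hot values times 4 categorical values = 20), expanded to the composite's batch shape (3,):

    from torchrl.data import CompositeSpec, DiscreteTensorSpec, OneHotDiscreteTensorSpec

    composite = CompositeSpec(
        {
            "a": OneHotDiscreteTensorSpec(n=5, shape=(3, 5)),
            ("b", "c"): DiscreteTensorSpec(n=4, shape=(3,)),
        },
        shape=[3],
    )
    enum = composite.enumerate()
    print(enum.shape)       # torch.Size([20, 3])
    print(enum["b"].shape)  # torch.Size([20, 3]), the sub-tensordict holding "c"
    # Every enumerated sample is a valid member of the spec.
    assert composite.is_in(enum)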