-
Notifications
You must be signed in to change notification settings - Fork 313
/
tensor_specs.py
5822 lines (5039 loc) · 203 KB
/
tensor_specs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import abc
import enum
import math
import warnings
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass
from functools import wraps
from textwrap import indent
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
overload,
Sequence,
Tuple,
TypeVar,
Union,
)
import numpy as np
import tensordict
import torch
from tensordict import (
is_tensor_collection,
LazyStackedTensorDict,
NonTensorData,
TensorDict,
TensorDictBase,
unravel_key,
)
from tensordict.base import NO_DEFAULT
from tensordict.utils import _getitem_batch_size, NestedKey
from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for
# Anything accepted where a device is expected: a ``torch.device``, a device
# string (e.g. ``"cpu"``, ``"cuda:0"``) or an integer device ordinal.
DEVICE_TYPING = Union[torch.device, str, int]
# Index types accepted by ``TensorSpec.index``.
INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]
# Index types understood by ``_shape_indexing``: a single index, or a tuple
# mixing any of them.
SHAPE_INDEX_TYPING = Union[
    int,
    range,
    List[int],
    np.ndarray,
    slice,
    None,
    torch.Tensor,
    type(...),
    Tuple[
        int,
        range,
        List[int],
        np.ndarray,
        slice,
        None,
        torch.Tensor,
        type(...),
        Tuple[Any],
    ],
]
# By default, we do not check that an obs is in the domain. This should be done when validating the env beforehand.
# The ``CHECK_SPEC_ENCODE`` environment variable (read once, at import time)
# turns the check on in ``TensorSpec.encode`` and ``TensorSpec.to_numpy``.
_CHECK_SPEC_ENCODE = get_binary_env_var("CHECK_SPEC_ENCODE")
# NOTE(review): not used in the visible part of the file; presumably the
# default shape for specs built without an explicit one — confirm at call sites.
_DEFAULT_SHAPE = torch.Size((1,))
# Error message for the undefined device of an empty Composite spec (raised
# outside the visible part of the file).
DEVICE_ERR_MSG = "device of empty Composite is not defined."
# NOTE(review): this is a single shared exception *instance*; every ``raise``
# of it overwrites its ``__traceback__``.
NOT_IMPLEMENTED_ERROR = NotImplementedError(
    "method is not currently implemented."
    " If you are interested in this feature please submit"
    " an issue at https://github.com/pytorch/rl/issues"
)
def _size(list_of_ints):
    """Build a ``torch.Size`` from an iterable of integer-like values.

    Each entry is coerced with :func:`int` first, so numpy integer scalars
    (e.g. ``np.int64``) cannot slip into the resulting size.
    See https://github.com/pytorch/pytorch/issues/127194.
    """
    coerced = tuple(map(int, list_of_ints))
    return torch.Size(coerced)
# Akin to TD's NO_DEFAULT but won't raise a KeyError when found in a TD or used as default
class _NoDefault(enum.IntEnum):
    """Integer-valued sentinel enum.

    As an :class:`enum.IntEnum`, its members compare equal to plain ints
    (``_NoDefault.ONE == 1``). That is presumably what lets it live inside a
    TensorDict or serve as a default without special casing — confirm against
    callers.
    """

    ZERO = 0
    ONE = 1


# Module-level sentinel value, presumably used to mark "no default provided".
NO_DEFAULT_RL = _NoDefault.ONE
def _default_dtype_and_device(
    dtype: Union[None, torch.dtype],
    device: Union[None, str, int, torch.device],
    allow_none_device: bool = False,
) -> Tuple[torch.dtype, torch.device | None]:
    """Fill in a missing dtype and/or device.

    Args:
        dtype: requested dtype. ``None`` falls back to ``torch.get_default_dtype()``.
        device: requested device. When given, it is normalized through
            ``_make_ordinal_device(torch.device(device))``. When ``None``, it is
            replaced by the device a freshly created tensor lands on, unless
            ``allow_none_device`` is set.
        allow_none_device: if ``True``, a ``None`` device is returned as ``None``.

    Returns:
        a ``(dtype, device)`` tuple.
    """
    resolved_dtype = torch.get_default_dtype() if dtype is None else dtype
    if device is None:
        resolved_device = None if allow_none_device else torch.zeros(()).device
    else:
        resolved_device = _make_ordinal_device(torch.device(device))
    return resolved_dtype, resolved_device
def _validate_idx(shape: list[int], idx: int, axis: int = 0):
    """Check that ``idx`` addresses an existing position along ``shape[axis]``.

    Negative indices count from the end, as in Python. A negative dimension
    size (variable-length dimension) disables the check.

    Args:
        shape (list[int]): shape being indexed.
        idx (int): index to validate; may be negative.
        axis (int): dimension of ``shape`` that ``idx`` applies to.

    Raises:
        IndexError: if ``idx`` falls outside ``[-size, size)``.
    """
    size = shape[axis]
    if size < 0:
        return
    if idx >= size or (idx < 0 and -idx > size):
        raise IndexError(
            f"index {idx} is out of bounds for axis {axis} with size {shape[axis]}"
        )
def _validate_iterable(
    idx: Iterable[Any], expected_type: type, iterable_classname: str
):
    """Raise an IndexError if the iterable contains a type different from the expected type or Iterable.

    Nested iterables are validated recursively. Strings and bytes are treated
    as leaves and not recursed into: a one-character string iterates to itself,
    so recursing would never terminate.

    Args:
        idx (Iterable[Any]): Iterable, may contain nested iterables
        expected_type (type): Required item type in the Iterable (e.g. int)
        iterable_classname (str): Iterable type as a string (e.g. 'List'). Logging purpose only.

    Raises:
        IndexError: if a (non-iterable or string) leaf is not an ``expected_type``.
    """
    for item in idx:
        if isinstance(item, Iterable) and not isinstance(
            item, (str, bytes, bytearray)
        ):
            _validate_iterable(item, expected_type, iterable_classname)
        elif not isinstance(item, expected_type):
            raise IndexError(
                f"{iterable_classname} indexing expects {expected_type} indices"
            )
def _slice_indexing(shape: list[int], idx: slice) -> List[int]:
    """Given an input shape and a slice index, returns the new indexed shape.

    The slice applies to the leading dimension only and follows Python / NumPy
    slicing rules: default bounds depend on the sign of the step, and
    out-of-range bounds are clamped. A negative (variable-size) leading
    dimension cannot be resolved statically, so it stays ``-1``.

    Args:
        shape (list[int]): Input shape
        idx (slice): Index

    Returns:
        Indexed shape

    Raises:
        ValueError: if the slice step is zero.

    Examples:
        >>> _slice_indexing([3, 4], slice(None, 2))
        [2, 4]
        >>> list(torch.rand(3, 4)[:2].shape)
        [2, 4]
        >>> _slice_indexing([5], slice(None, None, -1))
        [5]
    """
    if idx.step == 0:
        raise ValueError("slice step cannot be zero")
    # Slicing an empty shape returns the shape
    if len(shape) == 0:
        return shape
    dim = shape[0]
    if dim < 0:
        # Variable-size dimension: the slice length is unknown as well.
        n_items = -1
    else:
        # ``slice.indices`` applies Python's own defaults and clamping rules.
        n_items = len(range(*idx.indices(dim)))
    return [n_items] + list(shape[1:])
def _shape_indexing(
    shape: Union[list[int], torch.Size, Tuple[int]], idx: SHAPE_INDEX_TYPING
) -> List[int]:
    """Given an input shape and an index, returns the size of the resulting indexed spec.

    This function includes indexing checks and may raise IndexErrors.
    Negative (variable-size) dimensions are not bounds-checked.

    Args:
        shape (list[int], torch.Size, Tuple[int]): Input shape
        idx (SHAPE_INDEX_TYPING): Index

    Returns:
        Shape of the resulting spec, as a list of ints.

    Raises:
        IndexError: if the index is out of bounds, indexes too many dimensions,
            or contains an unsupported type.

    Examples:
        >>> idx = (2, ..., None)
        >>> Categorical(2, shape=(3, 4))[idx].shape
        torch.Size([4, 1])
        >>> _shape_indexing([3, 4], idx)
        [4, 1]
    """
    if not isinstance(shape, list):
        shape = list(shape)
    if idx is Ellipsis or (
        isinstance(idx, slice) and (idx.step is idx.start is idx.stop is None)
    ):
        return shape
    if idx is None:
        return [1] + shape
    if len(shape) == 0 and (
        isinstance(idx, (int, range)) or (isinstance(idx, list) and len(idx) > 0)
    ):
        raise IndexError(
            f"cannot use integer indices on 0-dimensional shape. `{idx}` received"
        )
    if isinstance(idx, int):
        _validate_idx(shape, idx)
        return shape[1:]
    if isinstance(idx, range):
        # Check the smallest and largest index the range actually produces:
        # ``range.stop`` is exclusive and can overshoot when ``step != 1``.
        if len(idx) > 0 and shape[0] >= 0:
            lowest, highest = sorted((idx[0], idx[-1]))
            if highest >= shape[0] or lowest < -shape[0]:
                raise IndexError(
                    f"index out of bounds for axis 0 with size {shape[0]}"
                )
        return [len(idx)] + shape[1:]
    if isinstance(idx, slice):
        return _slice_indexing(shape, idx)
    if isinstance(idx, tuple):
        # Supports int, None, slice and ellipsis indices
        # Index on the current shape dimension
        shape_idx = 0
        none_dims = 0
        ellipsis = False
        prev_is_list = False
        shape_len = len(shape)
        for item_idx, item in enumerate(idx):
            if item is None:
                shape = shape[:shape_idx] + [1] + shape[shape_idx:]
                shape_idx += 1
                none_dims += 1
            elif isinstance(item, int):
                _validate_idx(shape, item, shape_idx)
                del shape[shape_idx]
            elif isinstance(item, slice):
                shape[shape_idx] = _slice_indexing([shape[shape_idx]], item)[0]
                shape_idx += 1
            elif item is Ellipsis:
                if ellipsis:
                    raise IndexError("an index can only have a single ellipsis (`...`)")
                # Move to the end of the shape, subtracted by the number of future indices impacting the dimensions (i.e. all except None and ...)
                shape_idx = len(shape) - len(
                    [i for i in idx[item_idx + 1 :] if not (i is None or i is Ellipsis)]
                )
                ellipsis = True
            elif any(
                isinstance(item, _type)
                for _type in [list, tuple, range, np.ndarray, torch.Tensor]
            ):
                # Nested tuples are handled as a list. Numpy behavior
                if isinstance(item, tuple):
                    item = list(item)
                # Consecutive list indices are advanced-indexed together and
                # collapse into a single dimension.
                if prev_is_list and isinstance(item, list):
                    del shape[shape_idx]
                    continue
                if isinstance(item, list):
                    prev_is_list = True
                if shape_idx >= len(shape):
                    raise IndexError("Raise IndexError: too many indices for array")
                res = _shape_indexing([shape[shape_idx]], item)
                shape = shape[:shape_idx] + res + shape[shape_idx + 1 :]
                shape_idx += len(res)
            else:
                raise IndexError(
                    f"tuple indexing only supports integers, ranges, slices (`:`), ellipsis (`...`), new axis (`None`), tuples, list, tensor and ndarray indices. {str(type(item))} received"
                )
        # ``idx`` is never rebound above, so this counts the caller's own
        # index entries.
        if len(idx) - none_dims - int(ellipsis) > shape_len:
            raise IndexError(
                f"shape is {shape_len}-dimensional, but {len(idx) - none_dims - int(ellipsis)} dimensions were indexed"
            )
        return shape
    if isinstance(idx, list):
        # int indexing only
        _validate_iterable(idx, int, "list")
        for item in np.array(idx).reshape(-1):
            _validate_idx(shape, item, 0)
        return list(np.array(idx).shape) + shape[1:]
    if isinstance(idx, np.ndarray) or isinstance(idx, torch.Tensor):
        # Out of bounds check
        for item in idx.reshape(-1):
            _validate_idx(shape, item)
        return list(_getitem_batch_size(shape, idx))
class invertible_dict(dict):
    """A dictionary that also tracks its inverse mapping (value -> key).

    Assigning to an existing key, or assigning a value that is already
    present, raises an exception so that the mapping stays one-to-one.

    Examples:
        >>> my_dict = invertible_dict(a=3, b=2)
        >>> inv_dict = my_dict.invert()
        >>> assert {2, 3} == set(inv_dict.keys())
        >>> assert my_dict.inverse() == {3: "a", 2: "b"}
    """

    def __init__(self, *args, inv_dict=None, **kwargs):
        super().__init__(*args, **kwargs)
        if inv_dict is None:
            # Index the entries given to the constructor too, so that
            # ``inverse()`` and the duplicate-value check in ``__setitem__``
            # account for them (``dict.__init__`` bypasses ``__setitem__``).
            inv_dict = {v: k for k, v in self.items()}
        self.inv_dict = inv_dict

    def __setitem__(self, k, v):
        if v in self.inv_dict or k in self:
            raise Exception("overwriting in invertible_dict is not permitted")
        self.inv_dict[v] = k
        return super().__setitem__(k, v)

    def update(self, d):
        # Bulk updates would bypass the one-to-one checks of ``__setitem__``.
        raise NotImplementedError(
            "invertible_dict does not support update(); set items one at a time."
        )

    def invert(self):
        """Return a new :class:`invertible_dict` mapping values back to keys."""
        d = invertible_dict()
        for k, value in self.items():
            d[value] = k
        return d

    def inverse(self):
        """Return the internal value -> key mapping (not a copy)."""
        return self.inv_dict
class Box:
    """Base class for the value domain ("box") of a spec.

    Subclasses describe concrete domains (continuous bounds, categories,
    binary values, lists of boxes) and override :meth:`to` and, where
    meaningful, :meth:`__iter__`.
    """

    def __iter__(self):
        raise NotImplementedError

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Box:
        """Cast the box to a dtype or device; implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def clone(self) -> Box:
        """Return an independent (deep) copy of the box."""
        return deepcopy(self)
@dataclass(repr=False)
class ContinuousBox(Box):
    """A continuous box of values, in between a minimum (self.low) and a maximum (self.high).

    The bounds are kept on CPU in ``_low`` / ``_high`` and moved to
    :attr:`device` when read through the :attr:`low` / :attr:`high`
    properties. Assigning a bound records the assigned tensor's device.

    Args:
        _low (torch.Tensor): lower bound.
        _high (torch.Tensor): upper bound.
        device (torch.device, optional): device the bounds are exposed on.
            If ``None``, it is taken from the low bound in ``__post_init__``.
    """

    _low: torch.Tensor
    _high: torch.Tensor
    device: torch.device | None = None

    # We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used.
    @property
    def low(self):
        return self._low.to(self.device)

    @property
    def high(self):
        return self._high.to(self.device)

    def unbind(self, dim: int = 0):
        """Split the box along ``dim`` into a tuple of boxes (see :func:`torch.unbind`)."""
        return tuple(
            type(self)(low, high, self.device)
            for (low, high) in zip(self.low.unbind(dim), self.high.unbind(dim))
        )

    @low.setter
    def low(self, value):
        self.device = value.device
        self._low = value.cpu()

    @high.setter
    def high(self, value):
        self.device = value.device
        self._high = value.cpu()

    def __post_init__(self):
        # Route the constructor arguments through the setters: this clones
        # them, records their device and stores CPU copies.
        self.low = self.low.clone()
        self.high = self.high.clone()

    def __iter__(self):
        yield self.low
        yield self.high

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox:
        return self.__class__(self.low.to(dest), self.high.to(dest))

    def clone(self) -> ContinuousBox:
        return self.__class__(self.low.clone(), self.high.clone())

    def __repr__(self):
        low, high = self.low, self.high
        min_str = indent(
            f"\nlow=Tensor(shape={low.shape}, device={low.device}, dtype={low.dtype}, contiguous={low.is_contiguous()})",
            " " * 4,
        )
        max_str = indent(
            f"\nhigh=Tensor(shape={high.shape}, device={high.device}, dtype={high.dtype}, contiguous={high.is_contiguous()})",
            " " * 4,
        )
        return f"{self.__class__.__name__}({min_str},{max_str})"

    def __eq__(self, other):
        if other is None:
            # ``None`` matches an "unbounded" box. That is either bounds equal
            # to the extremes reported by ``_minmax_dtype`` (presumably the
            # dtype's min/max), or bounds that are entirely non-finite.
            minval, maxval = _minmax_dtype(self.low.dtype)
            minval = torch.as_tensor(minval).to(self.low.device, self.low.dtype)
            maxval = torch.as_tensor(maxval).to(self.low.device, self.low.dtype)
            if (
                torch.isclose(self.low, minval).all()
                and torch.isclose(self.high, maxval).all()
            ):
                return True
            if (
                not torch.isfinite(self.low).any()
                and not torch.isfinite(self.high).any()
            ):
                return True
            return False
        return (
            type(self) is type(other)
            and self.low.dtype == other.low.dtype
            and self.high.dtype == other.high.dtype
            and self.device == other.device
            and torch.isclose(self.low, other.low).all()
            and torch.isclose(self.high, other.high).all()
        )
@dataclass(repr=False)
class CategoricalBox(Box):
    """A box of discrete, categorical values: ``n`` possible categories."""

    n: int
    # Class-level registry shared by all instances. It is deliberately left
    # un-annotated so that the dataclass machinery does not turn it into a field.
    register = invertible_dict()

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox:
        # A categorical box holds no tensors, so "casting" it is a deep copy.
        return deepcopy(self)

    def __repr__(self):
        return "{}(n={})".format(type(self).__name__, self.n)
class DiscreteBox(CategoricalBox):
    """Deprecated alias of :class:`CategoricalBox`; prefer the latter in new code."""
@dataclass(repr=False)
class BoxList(Box):
    """A box of discrete values, stored as a list of sub-boxes."""

    boxes: List

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> BoxList:
        moved = [sub_box.to(dest) for sub_box in self.boxes]
        return BoxList(moved)

    def __iter__(self):
        yield from self.boxes

    def __repr__(self):
        return f"{type(self).__name__}(boxes={self.boxes})"

    def __len__(self):
        return len(self.boxes)

    @staticmethod
    def from_nvec(nvec: torch.Tensor):
        """Build (possibly nested) boxes from a tensor of category counts.

        A 0-d ``nvec`` yields a single :class:`CategoricalBox`. Otherwise one
        sub-box is built per entry along the last dimension, recursively.
        """
        if nvec.ndim != 0:
            return BoxList([BoxList.from_nvec(count) for count in nvec.unbind(-1)])
        return CategoricalBox(nvec.item())
@dataclass(repr=False)
class BinaryBox(Box):
    """A box of ``n`` binary values."""

    n: int

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> BinaryBox:
        # A binary box holds no tensors, so "casting" it is a deep copy.
        return deepcopy(self)

    def __repr__(self):
        return f"{self.__class__.__name__}(n={self.n})"
@dataclass(repr=False)
class TensorSpec:
"""Parent class of the tensor meta-data containers.
TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class,
or sometimes to simulate simple behaviors by generating random data within a defined space.
TensorSpecs are primarily used in environments to specify their input/output structure without needing to
execute the environment (or starting it). They can also be used to instantiate shared buffers to pass
data from worker to worker.
TensorSpecs are dataclasses that always share the following fields: `shape`, `space, `dtype` and `device`.
As such, TensorSpecs possess some common behavior with :class:`~torch.Tensor` and :class:`~tensordict.TensorDict`:
they can be reshaped, indexed, squeezed, unsqueezed, moved to another device etc.
Args:
shape (torch.Size): size of the tensor. The shape includes the batch dimensions as well as the feature
dimension. A negative shape (``-1``) means that the dimension has a variable number of elements.
space (Box): Box instance describing what kind of values can be expected.
device (torch.device): device of the tensor.
dtype (torch.dtype): dtype of the tensor.
.. note:: A spec can be constructed from a :class:`~tensordict.TensorDict` using the :func:`~torchrl.envs.utils.make_composite_from_td`
function. This function makes a low-assumption educated guess on the specs that may correspond to the input
tensordict and can help to build specs automatically without an in-depth knowledge of the `TensorSpec` API.
"""
shape: torch.Size
space: Union[None, Box]
device: torch.device | None = None
dtype: torch.dtype = torch.float
domain: str = ""
SPEC_HANDLED_FUNCTIONS = {}
@classmethod
def implements_for_spec(cls, torch_function: Callable) -> Callable:
    """Register a torch function override for TensorSpec.

    Returns a decorator. The decorated function is stored in
    ``cls.SPEC_HANDLED_FUNCTIONS`` under ``torch_function`` and returned
    unchanged.
    """

    def _register(func):
        cls.SPEC_HANDLED_FUNCTIONS[torch_function] = func
        return func

    return wraps(torch_function)(_register)
@property
def device(self) -> torch.device:
    """The device of the spec.

    Only :class:`Composite` specs can have a ``None`` device; every leaf spec
    must carry a non-null device.
    """
    return self._device

@device.setter
def device(self, device: torch.device | None) -> None:
    # Normalize the device through ``_make_ordinal_device`` before storing it.
    normalized = _make_ordinal_device(device)
    self._device = normalized
def clear_device_(self) -> T:
    """Erase the device of the spec, in place.

    Leaf specs must always have a device, so for them this is a no-op that
    returns ``self``. :class:`Composite` specs override it to erase their device.
    """
    unchanged = self
    return unchanged
def encode(
    self,
    val: np.ndarray | torch.Tensor | TensorDictBase,
    *,
    ignore_device: bool = False,
) -> torch.Tensor | TensorDictBase:
    """Encodes a value given the specified spec, and return the corresponding tensor.

    This method is to be used in environments that return a value (eg, a numpy array) that can be
    easily mapped to the TorchRL required domain.
    If the value is already a tensor, the spec will not change its value and return it as-is
    (no dtype, device or shape conversion is applied).

    Args:
        val (np.ndarray or torch.Tensor): value to be encoded as tensor.

    Keyword Args:
        ignore_device (bool, optional): if ``True``, the spec device will
            be ignored. This is used to group tensor casting within a call
            to ``TensorDict(..., device="cuda")`` which is faster.

    Returns:
        torch.Tensor matching the required tensor specs.

    Raises:
        RuntimeError: if a non-tensor value cannot be reshaped to the spec shape.
        AssertionError: if the ``CHECK_SPEC_ENCODE`` check is enabled and the
            value is not in the spec domain.
    """
    if not isinstance(val, torch.Tensor):
        if isinstance(val, list):
            if len(val) == 1:
                # gym used to return lists of images since 0.26.0
                # We convert these lists in np.array or take the first element
                # if there is just one.
                # See https://github.com/pytorch/rl/pull/403/commits/73d77d033152c61d96126ccd10a2817fecd285a1
                val = val[0]
            else:
                val = np.array(val)
        if isinstance(val, np.ndarray) and not all(
            stride > 0 for stride in val.strides
        ):
            # Arrays with non-positive strides (flipped or broadcast views)
            # are copied into a fresh buffer before conversion; torch cannot
            # wrap negative-stride numpy memory.
            val = val.copy()
        if not ignore_device:
            val = torch.as_tensor(val, device=self.device, dtype=self.dtype)
        else:
            val = torch.as_tensor(val, dtype=self.dtype)
        if val.shape != self.shape:
            # Any value with a compatible number of elements is reshaped,
            # e.g. one missing a trailing singleton dimension.
            try:
                val = val.reshape(self.shape)
            except Exception as err:
                raise RuntimeError(
                    f"Shape mismatch: the value has shape {val.shape} which "
                    f"is incompatible with the spec shape {self.shape}."
                ) from err
    if _CHECK_SPEC_ENCODE:
        self.assert_is_in(val)
    return val
def __ne__(self, other):
    # Always the exact negation of ``__eq__``.
    equal = self == other
    return not equal
def __setattr__(self, key, value):
    # Any assigned ``shape`` is normalized to a torch.Size of plain ints.
    super().__setattr__(key, _size(value) if key == "shape" else value)
def to_numpy(
    self, val: torch.Tensor | TensorDictBase, safe: bool | None = None
) -> np.ndarray | dict:
    """Returns the ``np.ndarray`` correspondent of an input tensor.

    This is intended to be the inverse operation of :meth:`.encode`.

    Args:
        val (torch.Tensor): tensor to be converted to numpy.
        safe (bool, optional): whether to check the value against the domain
            of the spec first. Defaults to the value of the
            ``CHECK_SPEC_ENCODE`` environment variable.

    Returns:
        a np.ndarray.
    """
    check = _CHECK_SPEC_ENCODE if safe is None else safe
    if check:
        self.assert_is_in(val)
    return val.detach().cpu().numpy()
@property
def ndim(self) -> int:
    """Number of dimensions of the spec shape (same as :meth:`ndimension`)."""
    return self.ndimension()

def ndimension(self) -> int:
    """Number of dimensions of the spec shape, i.e. ``len(spec.shape)``."""
    shape = self.shape
    return len(shape)
@property
def _safe_shape(self) -> torch.Size:
    """Returns a shape where all heterogeneous values are replaced by one (to be expandable)."""
    # Variable-size dimensions are stored as negative numbers; map them to 1.
    return _size([1 if dim < 0 else int(dim) for dim in self.shape])
@abc.abstractmethod
def index(
    self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
) -> torch.Tensor | TensorDictBase:
    """Index ``tensor_to_index`` with ``index``; implemented by subclasses.

    Args:
        index (int, torch.Tensor, slice or list): the index to apply.
        tensor_to_index: the tensor (or tensordict) to be indexed.

    Returns:
        the indexed tensor.
    """
    ...
@overload
def expand(self, shape: torch.Size):
    ...

@abc.abstractmethod
def expand(self, *shape: int) -> T:
    """Return a new spec broadcast to the given shape; implemented by subclasses.

    Args:
        *shape (tuple or iterable of int): target shape. It must be at least
            as long as the current shape, and its trailing entries may only
            differ from the current ones where the current dimension is a
            singleton.
    """
    ...
def squeeze(self, dim: int | None = None) -> T:
    """Return a spec with singleton dimensions removed.

    Args:
        dim (int or None): if given, only this dimension is squeezed;
            otherwise every dimension of size ``1`` is removed.

    Returns:
        ``self`` when nothing is squeezed, otherwise a new spec of the same
        class rebuilt from ``shape``, ``device`` and ``dtype``.
    """
    new_shape = _squeezed_shape(self.shape, dim)
    if new_shape is not None:
        return self.__class__(shape=new_shape, device=self.device, dtype=self.dtype)
    return self
def unsqueeze(self, dim: int) -> T:
    """Return a new spec with an extra singleton dimension inserted at ``dim``.

    Args:
        dim (int): position of the new dimension.
    """
    new_shape = _unsqueezed_shape(self.shape, dim)
    spec_cls = self.__class__
    return spec_cls(shape=new_shape, device=self.device, dtype=self.dtype)
def make_neg_dim(self, dim: int) -> T:
    """Converts a specific dimension to ``-1`` (variable size), in place.

    Args:
        dim (int): dimension to convert; negative values count from the end.

    Returns:
        the spec itself, so that calls can be chained (matching the ``-> T``
        annotation).

    Raises:
        ValueError: if ``dim`` is out of range for the spec's number of dimensions.
    """
    if dim < 0:
        dim = self.ndim + dim
    if dim < 0 or dim > self.ndim - 1:
        raise ValueError(f"dim={dim} is out of bound for ndim={self.ndim}")
    self.shape = _size([s if i != dim else -1 for i, s in enumerate(self.shape)])
    return self
@overload
def reshape(self, shape) -> T:
    ...

def reshape(self, *shape) -> T:
    """Reshapes a ``TensorSpec``.

    Accepts the target shape either unpacked (``spec.reshape(2, 3)``) or as a
    single sequence (``spec.reshape((2, 3))``).
    Check :func:`~torch.reshape` for more information on this method.
    """
    if len(shape) == 1 and not isinstance(shape[0], int):
        (packed,) = shape
        return self.reshape(*packed)
    return self._reshape(shape)

view = reshape
@abc.abstractmethod
def _reshape(self, shape: torch.Size) -> T:
    """Subclass hook that performs the actual reshape to ``shape``."""
    ...
def unflatten(self, dim: int, sizes: Tuple[int]) -> T:
    """Unflattens a ``TensorSpec``.

    Check :func:`~torch.unflatten` for more information on this method.
    """
    return self._unflatten(dim, sizes)

def _unflatten(self, dim: int, sizes: Tuple[int]) -> T:
    # A "meta" tensor carries only the shape (no storage), so torch itself
    # computes and validates the unflattened shape.
    meta = torch.zeros(self.shape, device="meta")
    return self._reshape(meta.unflatten(dim, sizes).shape)
def flatten(self, start_dim: int = 0, end_dim: int = -1) -> T:
    """Flattens a ``TensorSpec``.

    Check :func:`~torch.flatten` for more information on this method.

    Args:
        start_dim (int, optional): first dimension to flatten. Defaults to
            ``0``, like :func:`torch.flatten`.
        end_dim (int, optional): last dimension to flatten. Defaults to ``-1``,
            like :func:`torch.flatten`.
    """
    return self._flatten(start_dim, end_dim)

def _flatten(self, start_dim, end_dim):
    # A "meta" tensor carries only the shape (no storage), so torch itself
    # computes and validates the flattened shape.
    shape = torch.zeros(self.shape, device="meta").flatten(start_dim, end_dim).shape
    return self._reshape(shape)
@abc.abstractmethod
def _project(
    self, val: torch.Tensor | TensorDictBase
) -> torch.Tensor | TensorDictBase:
    """Subclass hook that maps an out-of-domain value back into the spec domain."""
    raise NotImplementedError(type(self))
@abc.abstractmethod
def is_in(self, val: torch.Tensor | TensorDictBase) -> bool:
    """Tell whether ``val`` could have been produced by this spec.

    Subclasses check that ``val`` lies within the limits of the ``space``
    attribute (the box), and that its ``dtype``, ``device``, ``shape`` and any
    other metadata match the spec. ``False`` is returned as soon as one of
    these checks fails.

    Args:
        val (torch.Tensor): value to be checked.

    Returns:
        ``True`` if the value belongs to the TensorSpec box, ``False`` otherwise.
    """
    ...
def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
    """Alias of :meth:`is_in`: whether ``item`` could have been generated by the spec."""
    membership = self.is_in(item)
    return membership
@abc.abstractmethod
def enumerate(self) -> Any:
    """List every sample the spec can produce, stacked along the first dimension.

    Only discrete specs implement this.
    """
    ...
def project(
    self, val: torch.Tensor | TensorDictBase
) -> torch.Tensor | TensorDictBase:
    """Map ``val`` back into the spec box when it falls outside of it.

    Values already in the box are returned unchanged; otherwise the
    subclass-specific ``_project`` heuristic is applied.

    Args:
        val (torch.Tensor): tensor to be mapped to the box.

    Returns:
        a torch.Tensor belonging to the TensorSpec box.
    """
    if self.is_in(val):
        return val
    return self._project(val)
def assert_is_in(self, value: torch.Tensor) -> None:
    """Raise if ``value`` does not belong to the spec box.

    Args:
        value (torch.Tensor): value to be checked.

    Raises:
        AssertionError: if ``self.is_in(value)`` is falsy.
    """
    if self.is_in(value):
        return
    raise AssertionError(
        f"Encoding failed because value is not in space. "
        f"Consider calling project(val) first. value was = {value} "
        f"and spec was {self}."
    )
def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None:
    """Raise a ``TypeError`` if ``value.dtype`` differs from the spec ``dtype``.

    Args:
        value (torch.Tensor): tensor whose dtype has to be checked.
        key (str, optional): for specs with keys, selects the sub-spec to
            check against. It is unused by this base implementation.
    """
    # torch dtypes are singletons, so an identity comparison suffices.
    if value.dtype is self.dtype:
        return
    raise TypeError(
        f"value.dtype={value.dtype} but"
        f" {self.__class__.__name__}.dtype={self.dtype}"
    )
@abc.abstractmethod
def rand(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
    """Draw a random value from the space defined by the spec.

    Bounded spaces are sampled uniformly; unbounded boxes draw normal values
    instead.

    Args:
        shape (torch.Size): leading (batch) shape of the random tensor.

    Returns:
        a random tensor sampled in the TensorSpec box.
    """
    ...
def sample(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
    """Alias of :meth:`rand`: a random value from the spec's space."""
    draw = self.rand
    return draw(shape=shape)
def zero(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
    """Return a zero-filled tensor with the spec's dtype and device.

    .. note:: ``0`` is not guaranteed to belong to the spec domain, and no
      exception is raised when it does not. The intended use is allocating
      empty data buffers, not producing meaningful data.

    Args:
        shape (torch.Size): leading (batch) shape prepended to the spec shape.

    Returns:
        a zero-filled tensor. Variable-size dimensions of the spec become ``1``.
    """
    batch = _size([]) if shape is None else shape
    full_shape = (*batch, *self._safe_shape)
    return torch.zeros(full_shape, dtype=self.dtype, device=self.device)
def zeros(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
    """Alias of :meth:`zero`."""
    make_zero = self.zero
    return make_zero(shape=shape)
def one(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
    """Return a one-filled tensor with the spec's dtype and device.

    .. note:: ``1`` is not guaranteed to belong to the spec domain, and no
      exception is raised when it does not. The intended use is allocating
      data buffers, not producing meaningful data.

    Args:
        shape (torch.Size): leading (batch) shape of the one-tensor.

    Returns:
        a one-filled tensor built from :meth:`zero`.
    """
    if self.dtype != torch.bool:
        return self.zero(shape) + 1
    # Adding 1 would promote a boolean tensor to an integer dtype; negating
    # the zeros keeps it boolean.
    return ~self.zero(shape=shape)
def ones(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
    """Alias of :meth:`one`."""
    make_one = self.one
    return make_one(shape=shape)
@abc.abstractmethod
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec":
    """Cast the spec to another device or dtype; implemented by subclasses.

    The same spec is returned when no change is needed.
    """
    ...
def cpu(self):
    """Casts the TensorSpec to 'cpu' device."""
    return self.to("cpu")

def cuda(self, device=None):
    """Casts the TensorSpec to a 'cuda' device.

    Args:
        device (int, str or torch.device, optional): target CUDA device.
            An integer (or a digit string) is read as a device ordinal.
            A ``torch.device`` or a string such as ``"cuda:1"`` is used as-is.
            Defaults to ``None``, i.e. plain ``"cuda"``.

    Raises:
        ValueError: if ``device`` designates a non-CUDA device.
    """
    if device is None:
        return self.to("cuda")
    if isinstance(device, torch.device) or (
        isinstance(device, str) and not device.isdigit()
    ):
        # Formatting these as f"cuda:{device}" would yield e.g. "cuda:cuda:0".
        device = torch.device(device)
        if device.type != "cuda":
            raise ValueError(f"Expected a cuda device, got {device}.")
        return self.to(device)
    return self.to(f"cuda:{device}")
@abc.abstractmethod
def clone(self) -> "TensorSpec":
    """Return an independent copy of the spec; implemented by subclasses."""
    ...
def __repr__(self):
    # One indented "name=value" line per dataclass field, comma-separated.
    pad = " " * 4
    fields = (
        ("shape", self.shape),
        ("space", self.space),
        ("device", self.device),
        ("dtype", self.dtype),
        ("domain", self.domain),
    )
    sub_string = ",\n".join(indent(f"{name}={value!s}", pad) for name, value in fields)
    return f"{self.__class__.__name__}(\n{sub_string})"
@classmethod
def __torch_function__(
cls,
func: Callable,