# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Optional, Sequence, Tuple, Union
import torch
from tensordict import TensorDictBase, unravel_key
from tensordict.nn import (
dispatch,
TensorDictModule,
TensorDictModuleBase,
TensorDictModuleWrapper,
TensorDictSequential,
)
from tensordict.utils import NestedKey
from torch import nn
from torch.distributions import Categorical
from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.data.utils import _process_action_space_spec
from torchrl.modules.models.models import DistributionalDQNnet
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.modules.tensordict_module.probabilistic import (
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
)
from torchrl.modules.tensordict_module.sequence import SafeSequential
class Actor(SafeModule):
"""General class for deterministic actors in RL.
The Actor class comes with default values for the out_keys (``["action"]``).
If the spec is provided but is not a :class:`CompositeSpec` object, it will be
automatically translated into :obj:`spec = CompositeSpec(action=spec)`.
Args:
module (nn.Module): a :class:`torch.nn.Module` used to map the input to
the output parameter space.
in_keys (iterable of str, optional): keys to be read from input
tensordict and passed to the module. If it
contains more than one element, the values will be passed in the
order given by the in_keys iterable.
Defaults to ``["observation"]``.
out_keys (iterable of str): keys to be written to the input tensordict.
The length of out_keys must match the
number of tensors returned by the embedded module. Using "_" as a
key avoids writing the tensor to the output.
Defaults to ``["action"]``.
spec (TensorSpec, optional): Keyword-only argument.
Specs of the output tensor. If the module
outputs multiple output tensors,
spec characterizes the space of the first output tensor.
safe (bool): Keyword-only argument.
If ``True``, the value of the output is checked against the
input spec. Out-of-domain sampling can
occur because of exploration policies or numerical under/overflow
issues. If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
method. Default is ``False``.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import UnboundedContinuousTensorSpec
>>> from torchrl.modules import Actor
>>> torch.manual_seed(0)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> action_spec = UnboundedContinuousTensorSpec(4)
>>> module = torch.nn.Linear(4, 4)
>>> td_module = Actor(
... module=module,
... spec=action_spec,
... )
>>> td_module(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> print(td.get("action"))
tensor([[-1.3635, -0.0340, 0.1476, -1.3911],
[-0.1664, 0.5455, 0.2247, -0.4583],
[-0.2916, 0.2160, 0.5337, -0.5193]], grad_fn=<AddmmBackward0>)
"""
def __init__(
self,
module: nn.Module,
in_keys: Optional[Sequence[NestedKey]] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
*,
spec: Optional[TensorSpec] = None,
**kwargs,
):
if in_keys is None:
in_keys = ["observation"]
if out_keys is None:
out_keys = ["action"]
if (
"action" in out_keys
and spec is not None
and not isinstance(spec, CompositeSpec)
):
spec = CompositeSpec(action=spec)
super().__init__(
module,
in_keys=in_keys,
out_keys=out_keys,
spec=spec,
**kwargs,
)
class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
"""General class for probabilistic actors in RL.
The ProbabilisticActor class comes with default values for the out_keys (``["action"]``).
If the spec is provided but is not a :class:`CompositeSpec` object, it will be
automatically translated into :obj:`spec = CompositeSpec(action=spec)`.
Args:
module (nn.Module): a :class:`torch.nn.Module` used to map the input to
the output parameter space.
in_keys (str or iterable of str or dict): key(s) that will be read from the
input TensorDict and used to build the distribution. Importantly, if it's an
iterable of string or a string, those keys must match the keywords used by
the distribution class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for
the Normal distribution and similar. If in_keys is a dictionary, the keys
are the keys of the distribution and the values are the keys in the
tensordict that will be matched to the corresponding distribution keys.
out_keys (str or iterable of str): keys where the sampled values will be
written. Importantly, if these keys are found in the input TensorDict, the
sampling step will be skipped.
spec (TensorSpec, optional): keyword-only argument containing the specs
of the output tensor. If the module outputs multiple output tensors,
spec characterizes the space of the first output tensor.
safe (bool): keyword-only argument. if ``True``, the value of the output is checked against the
input spec. Out-of-domain sampling can
occur because of exploration policies or numerical under/overflow
issues. If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
method. Default is ``False``.
default_interaction_type (str, optional): keyword-only argument.
Default method to be used to retrieve
the output value. Should be one of: 'mode', 'median', 'mean' or 'random'
(in which case the value is sampled randomly from the distribution). Default
is 'mode'.
Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will
first look for the interaction mode dictated by the `interaction_type()`
global function. If this returns `None` (its default value), then the
`default_interaction_type` of the `ProbabilisticTDModule` instance will be
used. Note that DataCollector instances will use `set_interaction_type` to set
the interaction type to :class:`tensordict.nn.InteractionType.RANDOM` by default.
distribution_class (Type, optional): keyword-only argument.
A :class:`torch.distributions.Distribution` class to
be used for sampling.
Default is :class:`tensordict.nn.distributions.Delta`.
distribution_kwargs (dict, optional): keyword-only argument.
Keyword-argument pairs to be passed to the distribution.
return_log_prob (bool, optional): keyword-only argument.
If ``True``, the log-probability of the
distribution sample will be written in the tensordict with the key
`'sample_log_prob'`. Default is ``False``.
cache_dist (bool, optional): keyword-only argument.
EXPERIMENTAL: if ``True``, the parameters of the
distribution (i.e. the output of the module) will be written to the
tensordict along with the sample. Those parameters can be used to re-compute
the original distribution later on (e.g. to compute the divergence between
the distribution used to sample the action and the updated distribution in
PPO). Default is ``False``.
n_empirical_estimate (int, optional): keyword-only argument.
Number of samples to compute the empirical
mean when it is not available. Defaults to 1000.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, make_functional
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, TanhNormal
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> action_spec = BoundedTensorSpec(shape=torch.Size([4]),
... low=-1, high=1)
>>> module = NormalParamWrapper(torch.nn.Linear(4, 8))
>>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
>>> td_module = ProbabilisticActor(
... module=tensordict_module,
... spec=action_spec,
... in_keys=["loc", "scale"],
... distribution_class=TanhNormal,
... )
>>> params = make_functional(td_module)
>>> td = td_module(td, params=params)
>>> td
TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
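>>> # A minimal sketch of the interaction-type mechanism described above: force
>>> # deterministic (mode) sampling for a single call. This assumes that
>>> # ``InteractionType`` and ``set_interaction_type`` are exposed by your
>>> # version of ``tensordict.nn``.
>>> from tensordict.nn import InteractionType, set_interaction_type
>>> with set_interaction_type(InteractionType.MODE):
...     td_mode = td_module(td.clone(), params=params)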
"""
def __init__(
self,
module: TensorDictModule,
in_keys: Union[NestedKey, Sequence[NestedKey]],
out_keys: Optional[Sequence[NestedKey]] = None,
*,
spec: Optional[TensorSpec] = None,
**kwargs,
):
if out_keys is None:
out_keys = ["action"]
if (
len(out_keys) == 1
and spec is not None
and not isinstance(spec, CompositeSpec)
):
spec = CompositeSpec({out_keys[0]: spec})
super().__init__(
module,
SafeProbabilisticModule(
in_keys=in_keys, out_keys=out_keys, spec=spec, **kwargs
),
)
class ValueOperator(TensorDictModule):
"""General class for value functions in RL.
The ValueOperator class comes with default values for the in_keys and
out_keys arguments (["observation"] and ["state_value"] or
["state_action_value"], respectively and depending on whether the "action"
key is part of the in_keys list).
Args:
module (nn.Module): a :class:`torch.nn.Module` used to map the input to
the output parameter space.
in_keys (iterable of str, optional): keys to be read from input
tensordict and passed to the module. If it
contains more than one element, the values will be passed in the
order given by the in_keys iterable.
Defaults to ``["observation"]``.
out_keys (iterable of str): keys to be written to the input tensordict.
The length of out_keys must match the
number of tensors returned by the embedded module. Using "_" as a
key avoids writing the tensor to the output.
Defaults to ``["state_value"]``, or ``["state_action_value"]`` if "action"
is part of the in_keys list.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> from torch import nn
>>> from torchrl.data import UnboundedContinuousTensorSpec
>>> from torchrl.modules import ValueOperator
>>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,])
>>> class CustomModule(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = torch.nn.Linear(6, 1)
... def forward(self, obs, action):
... return self.linear(torch.cat([obs, action], -1))
>>> module = CustomModule()
>>> td_module = ValueOperator(
... in_keys=["observation", "action"], module=module
... )
>>> params = make_functional(td_module)
>>> td = td_module(td, params=params)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
"""
def __init__(
self,
module: nn.Module,
in_keys: Optional[Sequence[NestedKey]] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
) -> None:
if in_keys is None:
in_keys = ["observation"]
if out_keys is None:
out_keys = (
["state_value"] if "action" not in in_keys else ["state_action_value"]
)
super().__init__(
module=module,
in_keys=in_keys,
out_keys=out_keys,
)
class QValueModule(TensorDictModuleBase):
"""Q-Value TensorDictModule for Q-value policies.
This module processes a tensor containing action values into its argmax
component (i.e. the resulting greedy action), following a given
action space (one-hot, binary or categorical).
It works with both tensordict and regular tensors.
Args:
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This argument is exclusive with ``spec``, since ``spec``
conditions the action_space.
action_value_key (str or tuple of str, optional): The input key
representing the action value. Defaults to ``"action_value"``.
action_mask_key (str or tuple of str, optional): The input key
representing the action mask. Defaults to ``None`` (equivalent to no masking).
out_keys (list of str or tuple of str, optional): The output keys
representing the actions, action values and chosen action value.
Defaults to ``["action", "action_value", "chosen_action_value"]``.
var_nums (int, optional): if ``action_space = "mult-one-hot"``,
this value represents the cardinality of each
action component.
spec (TensorSpec, optional): if provided, the specs of the action (and/or
other outputs). This is exclusive with ``action_space``, as the spec
conditions the action space.
safe (bool): if ``True``, the value of the output is checked against the
input spec. Out-of-domain sampling can
occur because of exploration policies or numerical under/overflow issues.
If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
method. Default is ``False``.
Returns:
if the input is a single tensor, a triplet containing the chosen action,
the values and the value of the chosen action is returned. If a tensordict
is provided, it is updated with these entries at the keys indicated by the
``out_keys`` field.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> action_space = "categorical"
>>> action_value_key = "my_action_value"
>>> actor = QValueModule(action_space, action_value_key=action_value_key)
>>> # This module works with both tensordict and regular tensors:
>>> value = torch.zeros(4)
>>> value[-1] = 1
>>> actor(my_action_value=value)
(tensor(3), tensor([0., 0., 0., 1.]), tensor([1.]))
>>> actor(value)
(tensor(3), tensor([0., 0., 0., 1.]), tensor([1.]))
>>> actor(TensorDict({action_value_key: value}, []))
TensorDict(
fields={
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
my_action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
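>>> # A sketch of the masking behaviour described above, using an illustrative
>>> # "action_mask" key (any boolean tensor with the same shape as the values works):
>>> masked_actor = QValueModule(
...     action_space, action_value_key=action_value_key, action_mask_key="action_mask"
... )
>>> td = TensorDict(
...     {action_value_key: value, "action_mask": torch.tensor([True, True, True, False])},
...     [],
... )
>>> # the highest value (last action) is masked out, so the first remaining action wins
>>> masked_actor(td)["action"]
tensor(0)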
"""
def __init__(
self,
action_space: Optional[str],
action_value_key: Optional[NestedKey] = None,
action_mask_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
var_nums: Optional[int] = None,
spec: Optional[TensorSpec] = None,
safe: bool = False,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, spec = _process_action_space_spec(action_space, spec)
self.action_space = action_space
self.var_nums = var_nums
self.action_func_mapping = {
"one_hot": self._one_hot,
"mult_one_hot": self._mult_one_hot,
"binary": self._binary,
"categorical": self._categorical,
}
self.action_value_func_mapping = {
"categorical": self._categorical_action_value,
}
if action_space not in self.action_func_mapping:
raise ValueError(
f"action_space must be one of {list(self.action_func_mapping.keys())}, got {action_space}"
)
if action_value_key is None:
action_value_key = "action_value"
self.action_mask_key = action_mask_key
in_keys = [action_value_key]
if self.action_mask_key is not None:
in_keys.append(self.action_mask_key)
self.in_keys = in_keys
if out_keys is None:
out_keys = ["action", action_value_key, "chosen_action_value"]
elif action_value_key not in out_keys:
raise RuntimeError(
f"Expected the action-value key to be '{action_value_key}' but got {out_keys[1]} instead."
)
self.out_keys = out_keys
action_key = out_keys[0]
if not isinstance(spec, CompositeSpec):
spec = CompositeSpec({action_key: spec})
super().__init__()
self.register_spec(safe=safe, spec=spec)
register_spec = SafeModule.register_spec
@property
def spec(self) -> CompositeSpec:
return self._spec
@spec.setter
def spec(self, spec: CompositeSpec) -> None:
if not isinstance(spec, CompositeSpec):
raise RuntimeError(
f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance."
)
self._spec = spec
@property
def action_value_key(self):
return self.in_keys[0]
@dispatch(auto_batch_size=False)
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
action_values = tensordict.get(self.action_value_key, None)
if action_values is None:
raise KeyError(
f"Action value key {self.action_value_key} not found in {tensordict}."
)
if self.action_mask_key is not None:
action_mask = tensordict.get(self.action_mask_key, None)
if action_mask is None:
raise KeyError(
f"Action mask key {self.action_mask_key} not found in {tensordict}."
)
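# Masked-out actions receive the lowest representable value so the argmax below can never select them.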
action_values = torch.where(
action_mask, action_values, torch.finfo(action_values.dtype).min
)
action = self.action_func_mapping[self.action_space](action_values)
action_value_func = self.action_value_func_mapping.get(
self.action_space, self._default_action_value
)
chosen_action_value = action_value_func(action_values, action)
tensordict.update(
dict(zip(self.out_keys, (action, action_values, chosen_action_value)))
)
return tensordict
@staticmethod
def _one_hot(value: torch.Tensor) -> torch.Tensor:
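# Compare each value to the row-wise maximum: this yields a one-hot encoding of the greedy action (multi-hot in case of ties).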
out = (value == value.max(dim=-1, keepdim=True)[0]).to(torch.long)
return out
@staticmethod
def _categorical(value: torch.Tensor) -> torch.Tensor:
return torch.argmax(value, dim=-1).to(torch.long)
def _mult_one_hot(
self, value: torch.Tensor, support: torch.Tensor = None
) -> torch.Tensor:
if self.var_nums is None:
raise ValueError(
"var_nums must be provided to the constructor for multi one-hot action spaces."
)
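# Split the flat value tensor into one chunk per categorical sub-space and take the greedy action of each chunk independently.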
values = value.split(self.var_nums, dim=-1)
return torch.cat(
[
self._one_hot(
_value,
)
for _value in values
],
-1,
)
@staticmethod
def _binary(value: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@staticmethod
def _default_action_value(
values: torch.Tensor, action: torch.Tensor
) -> torch.Tensor:
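# With a one-hot (or multi-one-hot) action, multiplying and summing over the last dim picks out the value(s) of the selected action(s).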
return (action * values).sum(-1, True)
@staticmethod
def _categorical_action_value(
values: torch.Tensor, action: torch.Tensor
) -> torch.Tensor:
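# With a categorical action, gather the value stored at the chosen index along the last dim.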
return values.gather(-1, action.unsqueeze(-1))
# if values.ndim == 1:
# return values[action].unsqueeze(-1)
# batch_size = values.size(0)
# return values[range(batch_size), action].unsqueeze(-1)
class DistributionalQValueModule(QValueModule):
"""Distributional Q-Value hook for Q-value policies.
This module processes a tensor containing action value logits into is argmax
component (i.e. the resulting greedy action), following a given
action space (one-hot, binary or categorical).
It works with both tensordict and regular tensors.
The input action value is expected to be the result of a log-softmax
operation.
For more details regarding Distributional DQN, refer to "A Distributional Perspective on Reinforcement Learning",
https://arxiv.org/pdf/1707.06887.pdf
Args:
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This argument is exclusive with ``spec``, since ``spec``
conditions the action_space.
support (torch.Tensor): support of the action values.
action_value_key (str or tuple of str, optional): The input key
representing the action value. Defaults to ``"action_value"``.
action_mask_key (str or tuple of str, optional): The input key
representing the action mask. Defaults to ``None`` (equivalent to no masking).
out_keys (list of str or tuple of str, optional): The output keys
representing the actions and action values.
Defaults to ``["action", "action_value"]``.
var_nums (int, optional): if ``action_space = "mult-one-hot"``,
this value represents the cardinality of each
action component.
spec (TensorSpec, optional): if provided, the specs of the action (and/or
other outputs). This is exclusive with ``action_space``, as the spec
conditions the action space.
safe (bool): if ``True``, the value of the output is checked against the
input spec. Out-of-domain sampling can
occur because of exploration policies or numerical under/overflow issues.
If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
method. Default is ``False``.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> torch.manual_seed(0)
>>> action_space = "categorical"
>>> action_value_key = "my_action_value"
>>> support = torch.tensor([-1, 0.0, 1.0]) # the action value is between -1 and 1
>>> actor = DistributionalQValueModule(action_space, support=support, action_value_key=action_value_key)
>>> # This module works with both tensordict and regular tensors:
>>> value = torch.full((3, 4), -100)
>>> # the first bin (-1) of the first action is high: there's a high chance that it has a low value
>>> value[0, 0] = 0
>>> # the second bin (0) of the second action is high: there's a high chance that it has an intermediate value
>>> value[1, 1] = 0
>>> # the third bin (1) of the third action is high: there's a high chance that it has a high value
>>> value[2, 2] = 0
>>> actor(my_action_value=value)
(tensor(2), tensor([[ 0, -100, -100, -100],
[-100, 0, -100, -100],
[-100, -100, 0, -100]]))
>>> actor(value)
(tensor(2), tensor([[ 0, -100, -100, -100],
[-100, 0, -100, -100],
[-100, -100, 0, -100]]))
>>> actor(TensorDict({action_value_key: value}, []))
TensorDict(
fields={
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
my_action_value: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""
def __init__(
self,
action_space: Optional[str],
support: torch.Tensor,
action_value_key: Optional[NestedKey] = None,
action_mask_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
var_nums: Optional[int] = None,
spec: TensorSpec = None,
safe: bool = False,
):
if action_value_key is None:
action_value_key = "action_value"
if out_keys is None:
out_keys = ["action", action_value_key]
super().__init__(
action_space=action_space,
action_value_key=action_value_key,
action_mask_key=action_mask_key,
out_keys=out_keys,
var_nums=var_nums,
spec=spec,
safe=safe,
)
self.register_buffer("support", support)
@dispatch(auto_batch_size=False)
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
action_values = tensordict.get(self.action_value_key, None)
if action_values is None:
raise KeyError(
f"Action value key {self.action_value_key} not found in {tensordict}."
)
if self.action_mask_key is not None:
action_mask = tensordict.get(self.action_mask_key, None)
if action_mask is None:
raise KeyError(
f"Action mask key {self.action_mask_key} not found in {tensordict}."
)
action_values = torch.where(
action_mask, action_values, torch.finfo(action_values.dtype).min
)
action = self.action_func_mapping[self.action_space](action_values)
tensordict.update(
dict(
zip(
self.out_keys,
(
action,
action_values,
),
)
)
)
return tensordict
def _support_expected(
self, log_softmax_values: torch.Tensor, support=None
) -> torch.Tensor:
if support is None:
support = self.support
support = support.to(log_softmax_values.device)
if log_softmax_values.shape[-2] != support.shape[-1]:
raise RuntimeError(
"Support length and number of atoms in module output should match, "
f"got self.support.shape={support.shape} and module(...).shape={log_softmax_values.shape}"
)
if (log_softmax_values > 0).any():
raise ValueError(
f"input to QValueHook must be log-softmax values (which are expected to be non-positive numbers). "
f"got a maximum value of {log_softmax_values.max():4.4f}"
)
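# Expected value per action: exponentiate the log-softmax to get probabilities, weight by the support atoms, and sum over the atom dimension.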
return (log_softmax_values.exp() * support.unsqueeze(-1)).sum(-2)
def _one_hot(self, value: torch.Tensor, support=None) -> torch.Tensor:
if support is None:
support = self.support
if not isinstance(value, torch.Tensor):
raise TypeError(f"got value of type {value.__class__.__name__}")
if not isinstance(support, torch.Tensor):
raise TypeError(f"got support of type {support.__class__.__name__}")
value = self._support_expected(value)
out = (value == value.max(dim=-1, keepdim=True)[0]).to(torch.long)
return out
def _mult_one_hot(self, value: torch.Tensor, support=None) -> torch.Tensor:
if support is None:
support = self.support
values = value.split(self.var_nums, dim=-1)
return torch.cat(
[
self._one_hot(_value, _support)
for _value, _support in zip(values, support)
],
-1,
)
def _categorical(
self,
value: torch.Tensor,
) -> torch.Tensor:
value = self._support_expected(
value,
)
return torch.argmax(value, dim=-1).to(torch.long)
def _binary(self, value: torch.Tensor) -> torch.Tensor:
raise NotImplementedError(
"'binary' is currently not supported for DistributionalQValueModule."
)
class QValueHook:
"""Q-Value hook for Q-value policies.
Given the output of a regular nn.Module, representing the values of the
different discrete actions available,
a QValueHook will transform these values into their argmax component (i.e.
the resulting greedy action).
Args:
action_space (str): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
var_nums (int, optional): if ``action_space = "mult-one-hot"``,
this value represents the cardinality of each
action component.
action_value_key (str or tuple of str, optional): to be used when hooked on
a TensorDictModule. The input key representing the action value. Defaults
to ``"action_value"``.
action_mask_key (str or tuple of str, optional): The input key
representing the action mask. Defaults to ``None`` (equivalent to no masking).
out_keys (list of str or tuple of str, optional): to be used when hooked on
a TensorDictModule. The output keys representing the actions, action values
and chosen action value. Defaults to ``["action", "action_value", "chosen_action_value"]``.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> module = nn.Linear(4, 4)
>>> hook = QValueHook("one_hot")
>>> module.register_forward_hook(hook)
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
"""
def __init__(
self,
action_space: str,
var_nums: Optional[int] = None,
action_value_key: Optional[NestedKey] = None,
action_mask_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, _ = _process_action_space_spec(action_space, None)
self.qvalue_model = QValueModule(
action_space=action_space,
var_nums=var_nums,
action_value_key=action_value_key,
action_mask_key=action_mask_key,
out_keys=out_keys,
)
action_value_key = self.qvalue_model.in_keys[0]
if isinstance(action_value_key, tuple):
action_value_key = "_".join(action_value_key)
# uses "dispatch" to get and return tensors
self.action_value_key = action_value_key
def __call__(
self, net: nn.Module, observation: torch.Tensor, values: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
kwargs = {self.action_value_key: values}
return self.qvalue_model(**kwargs)
class DistributionalQValueHook(QValueHook):
"""Distributional Q-Value hook for Q-value policies.
Given the output of a mapping operator, representing the log-probabilities of the
different action value bins available,
a DistributionalQValueHook will transform these values into their argmax
component using the provided support.
For more details regarding Distributional DQN, refer to "A Distributional Perspective on Reinforcement Learning",
https://arxiv.org/pdf/1707.06887.pdf
Args:
action_space (str): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
action_value_key (str or tuple of str, optional): to be used when hooked on
a TensorDictModule. The input key representing the action value. Defaults
to ``"action_value"``.
action_mask_key (str or tuple of str, optional): The input key
representing the action mask. Defaults to ``None`` (equivalent to no masking).
support (torch.Tensor): support of the action values.
var_nums (int, optional): if ``action_space = "mult-one-hot"``, this
value represents the cardinality of each
action component.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> class CustomDistributionalQval(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(4, nbins*4)
...
... def forward(self, x):
... return self.linear(x).view(-1, nbins, 4).log_softmax(-2)
...
>>> module = CustomDistributionalQval()
>>> params = make_functional(module)
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins))
>>> module.register_forward_hook(hook)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> qvalue_actor(td, params=params)
>>> print(td)
TensorDict(
fields={
action: Tensor(torch.Size([5, 4]), dtype=torch.int64),
action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32),
observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
"""
def __init__(
self,
action_space: str,
support: torch.Tensor,
var_nums: Optional[int] = None,
action_value_key: Optional[NestedKey] = None,
action_mask_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, _ = _process_action_space_spec(action_space, None)
self.qvalue_model = DistributionalQValueModule(
action_space=action_space,
var_nums=var_nums,
support=support,
action_value_key=action_value_key,
action_mask_key=action_mask_key,
out_keys=out_keys,
)
action_value_key = self.qvalue_model.in_keys[0]
if isinstance(action_value_key, tuple):
action_value_key = "_".join(action_value_key)
# uses "dispatch" to get and return tensors
self.action_value_key = action_value_key
class QValueActor(SafeSequential):
"""A Q-Value actor class.
This class appends a :class:`~.QValueModule` after the input module
such that the action values are used to select an action.
Args:
module (nn.Module): a :class:`torch.nn.Module` used to map the input to
the output parameter space. If the class provided is not compatible
with :class:`tensordict.nn.TensorDictModuleBase`, it will be
wrapped in a :class:`tensordict.nn.TensorDictModule` with
``in_keys`` indicated by the following keyword argument.
Keyword Args:
in_keys (iterable of str, optional): If the class provided is not
compatible with :class:`tensordict.nn.TensorDictModuleBase`, this
list of keys indicates what observations need to be passed to the
wrapped module to get the action values.
Defaults to ``["observation"]``.
spec (TensorSpec, optional): Keyword-only argument.
Specs of the output tensor. If the module
outputs multiple output tensors,
spec characterizes the space of the first output tensor.
safe (bool): Keyword-only argument.
If ``True``, the value of the output is checked against the
input spec. Out-of-domain sampling can
occur because of exploration policies or numerical under/overflow
issues. If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
method. Default is ``False``.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This argument is exclusive with ``spec``, since ``spec``
conditions the action_space.
action_value_key (str or tuple of str, optional): if the input module
is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must
match one of its output keys. Otherwise, this string represents
the name of the action-value entry in the output tensordict.
action_mask_key (str or tuple of str, optional): The input key
representing the action mask. Defaults to ``None`` (equivalent to no masking).
.. note::
``out_keys`` cannot be passed. If the module is a :class:`tensordict.nn.TensorDictModule`
instance, the out_keys will be updated accordingly. For regular
:class:`torch.nn.Module` instances, the triplet ``["action", action_value_key, "chosen_action_value"]``
will be used.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> # with a regular nn.Module
>>> module = nn.Linear(4, 4)
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
>>> # with a TensorDictModule
>>> td = TensorDict({'obs': torch.randn(5, 4)}, [5])
>>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"])
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
obs: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
"""
def __init__(
self,
module,
*,
in_keys=None,
spec=None,
safe=False,
action_space: Optional[str] = None,
action_value_key=None,
action_mask_key: Optional[NestedKey] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, spec = _process_action_space_spec(action_space, spec)
self.action_space = action_space
self.action_value_key = action_value_key
if action_value_key is None:
action_value_key = "action_value"