-
Notifications
You must be signed in to change notification settings - Fork 315
/
cql.py
1275 lines (1144 loc) · 51.8 KB
/
cql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import math
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey, unravel_key
from torch import Tensor
from torchrl.data.tensor_specs import Composite
from torchrl.data.utils import _find_action_space
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, QValueActor
from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_ERROR,
_reduce,
_vmap_func,
default_value_kwargs,
distance_loss,
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
class CQLLoss(LossModule):
"""TorchRL implementation of the continuous CQL loss.
Presented in "Conservative Q-Learning for Offline Reinforcement Learning" https://arxiv.org/abs/2006.04779
Args:
actor_network (ProbabilisticActor): stochastic actor
qvalue_network (TensorDictModule or list of TensorDictModule): Q(s, a) parametric model.
This module typically outputs a ``"state_action_value"`` entry.
If a single instance of `qvalue_network` is provided, it will be duplicated ``N``
times (where ``N=2`` for this loss). If a list of modules is passed, their
parameters will be stacked unless they share the same identity (in which case
the original parameter will be expanded).
.. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters
and all the parameters will be considered as untied.
Keyword args:
loss_function (str, optional): loss function to be used with
the value function loss. Default is `"smooth_l1"`.
alpha_init (float, optional): initial entropy multiplier.
Default is 1.0.
min_alpha (float, optional): min value of alpha.
Default is None (no minimum value).
max_alpha (float, optional): max value of alpha.
Default is None (no maximum value).
action_spec (TensorSpec, optional): the action tensor spec. If not provided
and the target entropy is ``"auto"``, it will be retrieved from
the actor.
fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its
initial value. Otherwise, alpha will be optimized to
match the 'target_entropy' value.
Default is ``False``.
target_entropy (float or str, optional): Target entropy for the
stochastic policy. Default is "auto", where target entropy is
computed as :obj:`-prod(n_actions)`.
delay_actor (bool, optional): Whether to separate the target actor
networks from the actor networks used for data collection.
Default is ``False``.
delay_qvalue (bool, optional): Whether to separate the target Q value
networks from the Q value networks used for data collection.
Default is ``True``.
gamma (float, optional): Discount factor. Default is ``None``.
temperature (float, optional): CQL temperature. Default is 1.0.
min_q_weight (float, optional): Minimum Q weight. Default is 1.0.
max_q_backup (bool, optional): Whether to use the max-min Q backup.
Default is ``False``.
deterministic_backup (bool, optional): Whether to use the deterministic. Default is ``True``.
num_random (int, optional): Number of random actions to sample for the CQL loss.
Default is 10.
with_lagrange (bool, optional): Whether to use the Lagrange multiplier.
Default is ``False``.
lagrange_thresh (float, optional): Lagrange threshold. Default is 0.0.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.cql import CQLLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = CQLLoss(actor, qvalue)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
fields={
alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
This class is compatible with non-tensordict based modules too and can be
used without recurring to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network.
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_qvalue", "loss_alpha", "loss_alpha_prime", "alpha", "entropy"]``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.cql import CQLLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = CQLLoss(actor, qvalue)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, *_ = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
... next_observation=torch.zeros(*batch, n_obs),
... next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
The output keys can also be filtered using the :meth:`CQLLoss.select_out_keys`
method.
Examples:
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
... next_observation=torch.zeros(*batch, n_obs),
... next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
"""
@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.
Attributes:
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"advantage"``.
value (NestedKey): The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
state_action_value (NestedKey): The input tensordict key where the
state action value is expected. Defaults to ``"state_action_value"``.
log_prob (NestedKey): The input tensordict key where the log probability is expected.
Defaults to ``"_log_prob"``.
pred_q1 (NestedKey): The input tensordict key where the predicted Q1 values are expected.
Defaults to ``"pred_q1"``.
pred_q2 (NestedKey): The input tensordict key where the predicted Q2 values are expected.
Defaults to ``"pred_q2"``.
priority (NestedKey): The input tensordict key where the target priority is written to.
Defaults to ``"td_error"``.
cql_q1_loss (NestedKey): The input tensordict key where the CQL Q1 loss is expected.
Defaults to ``"cql_q1_loss"``.
cql_q2_loss (NestedKey): The input tensordict key where the CQL Q2 loss is expected.
Defaults to ``"cql_q2_loss"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Defaults to ``"reward"``.
done (NestedKey): The input tensordict key where the done flag is expected.
Defaults to ``"done"``.
terminated (NestedKey): The input tensordict key where the terminated flag is expected.
Defaults to ``"terminated"``.
"""
action: NestedKey = "action"
value: NestedKey = "state_value"
state_action_value: NestedKey = "state_action_value"
log_prob: NestedKey = "_log_prob"
pred_q1: NestedKey = "pred_q1"
pred_q2: NestedKey = "pred_q2"
priority: NestedKey = "td_error"
cql_q1_loss: NestedKey = "cql_q1_loss"
cql_q2_loss: NestedKey = "cql_q2_loss"
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0
actor_network: TensorDictModule
qvalue_network: TensorDictModule
actor_network_params: TensorDictParams
qvalue_network_params: TensorDictParams
target_actor_network_params: TensorDictParams
target_qvalue_network_params: TensorDictParams
def __init__(
self,
actor_network: ProbabilisticActor,
qvalue_network: TensorDictModule | List[TensorDictModule],
*,
loss_function: str = "smooth_l1",
alpha_init: float = 1.0,
min_alpha: float = None,
max_alpha: float = None,
action_spec=None,
fixed_alpha: bool = False,
target_entropy: Union[str, float] = "auto",
delay_actor: bool = False,
delay_qvalue: bool = True,
gamma: float = None,
temperature: float = 1.0,
min_q_weight: float = 1.0,
max_q_backup: bool = False,
deterministic_backup: bool = True,
num_random: int = 10,
with_lagrange: bool = False,
lagrange_thresh: float = 0.0,
reduction: str = None,
) -> None:
self._out_keys = None
if reduction is None:
reduction = "mean"
super().__init__()
# Actor
self.delay_actor = delay_actor
self.convert_to_functional(
actor_network,
"actor_network",
create_target_params=self.delay_actor,
)
# Q value
self.delay_qvalue = delay_qvalue
self.num_qvalue_nets = 2
self.convert_to_functional(
qvalue_network,
"qvalue_network",
self.num_qvalue_nets,
create_target_params=self.delay_qvalue,
compare_against=list(actor_network.parameters()),
)
self.loss_function = loss_function
try:
device = next(self.parameters()).device
except AttributeError:
device = torch.device("cpu")
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
if bool(min_alpha) ^ bool(max_alpha):
min_alpha = min_alpha if min_alpha else 0.0
if max_alpha == 0:
raise ValueError("max_alpha must be either None or greater than 0.")
max_alpha = max_alpha if max_alpha else 1e9
if min_alpha:
self.register_buffer(
"min_log_alpha", torch.tensor(min_alpha, device=device).log()
)
else:
self.min_log_alpha = None
if max_alpha:
self.register_buffer(
"max_log_alpha", torch.tensor(max_alpha, device=device).log()
)
else:
self.max_log_alpha = None
self.fixed_alpha = fixed_alpha
if fixed_alpha:
self.register_buffer(
"log_alpha", torch.tensor(math.log(alpha_init), device=device)
)
else:
self.register_parameter(
"log_alpha",
torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
)
self._target_entropy = target_entropy
self._action_spec = action_spec
self.target_entropy_buffer = None
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.temperature = temperature
self.min_q_weight = min_q_weight
self.max_q_backup = max_q_backup
self.deterministic_backup = deterministic_backup
self.num_random = num_random
self.with_lagrange = with_lagrange
if self.with_lagrange:
self.target_action_gap = lagrange_thresh
self.register_parameter(
"log_alpha_prime",
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)
self._make_vmap()
self.reduction = reduction
def _make_vmap(self):
self._vmap_qvalue_networkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
@property
def target_entropy(self):
target_entropy = self.target_entropy_buffer
if target_entropy is None:
delattr(self, "target_entropy_buffer")
target_entropy = self._target_entropy
action_spec = self._action_spec
actor_network = self.actor_network
device = next(self.parameters()).device
if target_entropy == "auto":
action_spec = (
action_spec
if action_spec is not None
else getattr(actor_network, "spec", None)
)
if action_spec is None:
raise RuntimeError(
"Cannot infer the dimensionality of the action. Consider providing "
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
if not isinstance(action_spec, Composite):
action_spec = Composite({self.tensor_keys.action: action_spec})
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):
action_container_shape = action_spec[
self.tensor_keys.action[:-1]
].shape
else:
action_container_shape = action_spec.shape
target_entropy = -float(
action_spec[self.tensor_keys.action]
.shape[len(action_container_shape) :]
.numel()
)
self.register_buffer(
"target_entropy_buffer", torch.tensor(target_entropy, device=device)
)
return self.target_entropy_buffer
return target_entropy
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value=self._tensor_keys.value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
terminated=self.tensor_keys.terminated,
)
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
if value_type is None:
value_type = self.default_value_estimator
self.value_type = value_type
# we will take care of computing the next value inside this module
value_net = None
hp = dict(default_value_kwargs(value_type))
hp.update(hyperparams)
if value_type is ValueEstimators.TD1:
self._value_estimator = TD1Estimator(
**hp,
value_network=value_net,
)
elif value_type is ValueEstimators.TD0:
self._value_estimator = TD0Estimator(
**hp,
value_network=value_net,
)
elif value_type is ValueEstimators.GAE:
raise NotImplementedError(
f"Value type {value_type} it not implemented for loss {type(self)}."
)
elif value_type is ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(
**hp,
value_network=value_net,
)
else:
raise NotImplementedError(f"Unknown value type {value_type}")
tensor_keys = {
"value_target": "value_target",
"value": self.tensor_keys.value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
"terminated": self.tensor_keys.terminated,
}
self._value_estimator.set_keys(**tensor_keys)
@property
def in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
("next", self.tensor_keys.terminated),
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.qvalue_network.in_keys,
]
return list(set(keys))
@property
def out_keys(self):
if self._out_keys is None:
keys = [
"loss_actor",
"loss_actor_bc",
"loss_qvalue",
"loss_cql",
"loss_alpha",
"alpha",
"entropy",
]
if self.with_lagrange:
keys.append("loss_alpha_prime")
self._out_keys = keys
return self._out_keys
@out_keys.setter
def out_keys(self, values):
self._out_keys = values
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
q_loss, metadata = self.q_loss(tensordict)
cql_loss, cql_metadata = self.cql_loss(tensordict)
if self.with_lagrange:
alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(tensordict)
metadata.update(alpha_prime_metadata)
loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict)
loss_actor, actor_metadata = self.actor_loss(tensordict)
loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata)
metadata.update(bc_metadata)
metadata.update(cql_metadata)
metadata.update(actor_metadata)
metadata.update(alpha_metadata)
tensordict.set(
self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
)
out = {
"loss_actor": loss_actor,
"loss_actor_bc": loss_actor_bc,
"loss_qvalue": q_loss,
"loss_cql": cql_loss,
"loss_alpha": loss_alpha,
"alpha": self._alpha,
"entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(),
}
if self.with_lagrange:
out["loss_alpha_prime"] = alpha_prime_loss.mean()
return TensorDict(out, [])
@property
@_cache_values
def _cached_detach_qvalue_params(self):
return self.qvalue_network_params.detach()
def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor:
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
dist = self.actor_network.get_dist(
tensordict,
)
a_reparm = dist.rsample()
log_prob = dist.log_prob(a_reparm)
bc_log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action))
bc_actor_loss = self._alpha * log_prob - bc_log_prob
bc_actor_loss = _reduce(bc_actor_loss, reduction=self.reduction)
metadata = {"bc_log_prob": bc_log_prob.mean().detach()}
return bc_actor_loss, metadata
def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
dist = self.actor_network.get_dist(
tensordict,
)
a_reparm = dist.rsample()
log_prob = dist.log_prob(a_reparm)
td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
if td_q is tensordict:
raise RuntimeError
td_q.set(self.tensor_keys.action, a_reparm)
td_q = self._vmap_qvalue_networkN0(
td_q,
self._cached_detach_qvalue_params,
)
min_q_logprob = (
td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
)
if log_prob.shape != min_q_logprob.shape:
raise RuntimeError(
f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
)
metadata = {}
metadata[self.tensor_keys.log_prob] = log_prob.detach()
actor_loss = self._alpha * log_prob - min_q_logprob
actor_loss = _reduce(actor_loss, reduction=self.reduction)
return actor_loss, metadata
def _get_policy_actions(self, data, actor_params, num_actions=10):
batch_size = data.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * num_actions]
in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
def filter_and_repeat(name, x):
if name in in_keys:
return x.repeat_interleave(num_actions, dim=data.ndim - 1)
tensordict = data.named_apply(
filter_and_repeat, batch_size=batch_size, filter_empty=True
)
with torch.no_grad():
with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
self.actor_network
):
dist = self.actor_network.get_dist(tensordict)
action = dist.rsample()
tensordict.set(self.tensor_keys.action, action)
sample_log_prob = dist.log_prob(action)
# tensordict.del_("loc")
# tensordict.del_("scale")
return (
tensordict.select(
*self.actor_network.in_keys, self.tensor_keys.action, strict=False
),
sample_log_prob,
)
def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
tensordict = tensordict.clone(False)
# get actions and log-probs
with torch.no_grad():
with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
self.actor_network
):
next_tensordict = tensordict.get("next").clone(False)
next_dist = self.actor_network.get_dist(next_tensordict)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
next_sample_log_prob = next_dist.log_prob(next_action)
# get q-values
if not self.max_q_backup:
next_tensordict_expand = self._vmap_qvalue_networkN0(
next_tensordict, qval_params
)
next_state_value = next_tensordict_expand.get(
self.tensor_keys.state_action_value
).min(0)[0]
if (
next_state_value.shape[-len(next_sample_log_prob.shape) :]
!= next_sample_log_prob.shape
):
next_sample_log_prob = next_sample_log_prob.unsqueeze(-1)
if not self.deterministic_backup:
next_state_value = next_state_value - _alpha * next_sample_log_prob
if self.max_q_backup:
next_tensordict, _ = self._get_policy_actions(
tensordict.get("next").copy(),
actor_params,
num_actions=self.num_random,
)
next_tensordict_expand = self._vmap_qvalue_networkN0(
next_tensordict, qval_params
)
state_action_value = next_tensordict_expand.get(
self.tensor_keys.state_action_value
)
# take max over actions
state_action_value = state_action_value.reshape(
torch.Size(
[self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
)
).max(-2)[0]
# take min over qvalue nets
next_state_value = state_action_value.min(0)[0]
tensordict.set(
("next", self.value_estimator.tensor_keys.value), next_state_value
)
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
return target_value
def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
# we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first.
target_value = self._get_value_v(
tensordict.copy(),
self._alpha,
self.actor_network_params,
self.target_qvalue_network_params,
)
tensordict_pred_q = tensordict.select(
*self.qvalue_network.in_keys, strict=False
)
q_pred = self._vmap_qvalue_networkN0(
tensordict_pred_q, self.qvalue_network_params
).get(self.tensor_keys.state_action_value)
# write pred values in tensordict for cql loss
tensordict.set(self.tensor_keys.pred_q1, q_pred[0])
tensordict.set(self.tensor_keys.pred_q2, q_pred[1])
q_pred = q_pred.squeeze(-1)
loss_qval = distance_loss(
q_pred,
target_value.expand_as(q_pred),
loss_function=self.loss_function,
).sum(0)
loss_qval = _reduce(loss_qval, reduction=self.reduction)
td_error = (q_pred - target_value).pow(2)
metadata = {"td_error": td_error.detach()}
return loss_qval, metadata
def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
pred_q1 = tensordict.get(self.tensor_keys.pred_q1)
pred_q2 = tensordict.get(self.tensor_keys.pred_q2)
if pred_q1 is None:
raise KeyError(
f"Couldn't find the pred_q1 with key {self.tensor_keys.pred_q1} in the input tensordict. "
"This could be caused by calling cql_loss method before q_loss method."
)
if pred_q2 is None:
raise KeyError(
f"Couldn't find the pred_q2 with key {self.tensor_keys.pred_q2} in the input tensordict. "
"This could be caused by calling cql_loss method before q_loss method."
)
random_actions_tensor = pred_q1.new_empty(
(
*tensordict.shape[:-1],
tensordict.shape[-1] * self.num_random,
tensordict[self.tensor_keys.action].shape[-1],
)
).uniform_(-1, 1)
curr_actions_td, curr_log_pis = self._get_policy_actions(
tensordict.copy(),
self.actor_network_params,
num_actions=self.num_random,
)
new_curr_actions_td, new_log_pis = self._get_policy_actions(
tensordict.get("next").copy(),
self.actor_network_params,
num_actions=self.num_random,
)
# process all in one forward pass
# stack qvalue params
qvalue_params = torch.cat(
[
self.qvalue_network_params,
self.qvalue_network_params,
self.qvalue_network_params,
],
0,
)
# select and stack input params
# q value random action
tensordict_q_random = tensordict.select(
*self.actor_network.in_keys, strict=False
)
batch_size = tensordict_q_random.batch_size
batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random]
in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
def filter_and_repeat(name, x):
if name in in_keys:
return x.repeat_interleave(
self.num_random, dim=tensordict_q_random.ndim - 1
)
tensordict_q_random = tensordict_q_random.named_apply(
filter_and_repeat,
batch_size=batch_size,
filter_empty=True,
)
tensordict_q_random.set(self.tensor_keys.action, random_actions_tensor)
cql_tensordict = torch.cat(
[
tensordict_q_random.expand(
self.num_qvalue_nets, *curr_actions_td.batch_size
),
curr_actions_td.expand(
self.num_qvalue_nets, *curr_actions_td.batch_size
),
new_curr_actions_td.expand(
self.num_qvalue_nets, *curr_actions_td.batch_size
),
],
0,
)
cql_tensordict = cql_tensordict.contiguous()
cql_tensordict_expand = self._vmap_qvalue_network00(
cql_tensordict, qvalue_params
)
# get q values
state_action_value = cql_tensordict_expand.get(
self.tensor_keys.state_action_value
)
# split q values
(q_random, q_curr, q_new,) = state_action_value.split(
[
self.num_qvalue_nets,
self.num_qvalue_nets,
self.num_qvalue_nets,
],
dim=0,
)
# importance sammpled version
random_density = np.log(
0.5 ** curr_actions_td[self.tensor_keys.action].shape[-1]
)
cat_q1 = torch.cat(
[
q_random[0] - random_density,
q_new[0] - new_log_pis.detach().unsqueeze(-1),
q_curr[0] - curr_log_pis.detach().unsqueeze(-1),
],
-1,
)
cat_q2 = torch.cat(
[
q_random[1] - random_density,
q_new[1] - new_log_pis.detach().unsqueeze(-1),
q_curr[1] - curr_log_pis.detach().unsqueeze(-1),
],
-1,
)
min_qf1_loss = (
torch.logsumexp(cat_q1 / self.temperature, dim=-1)
* self.min_q_weight
* self.temperature
)
min_qf2_loss = (
torch.logsumexp(cat_q2 / self.temperature, dim=-1)
* self.min_q_weight
* self.temperature
)
# Subtract the log likelihood of data
cql_q1_loss = min_qf1_loss.flatten() - pred_q1 * self.min_q_weight
cql_q2_loss = min_qf2_loss.flatten() - pred_q2 * self.min_q_weight
# write cql losses in tensordict for alpha prime loss
tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss)
tensordict.set(self.tensor_keys.cql_q2_loss, cql_q2_loss)
cql_q_loss = (cql_q1_loss + cql_q2_loss).mean(-1)
cql_q_loss = _reduce(cql_q_loss, reduction=self.reduction)
return cql_q_loss, {}
def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
cql_q1_loss = tensordict.get(self.tensor_keys.cql_q1_loss)
cql_q2_loss = tensordict.get(self.tensor_keys.cql_q2_loss)
if cql_q1_loss is None:
raise KeyError(
f"Couldn't find the cql_q1_loss with key {self.tensor_keys.cql_q1_loss} in the input tensordict. "
"This could be caused by calling alpha_prime_loss method before cql_loss method."
)
if cql_q2_loss is None:
raise KeyError(
f"Couldn't find the cql_q2_loss with key {self.tensor_keys.cql_q2_loss} in the input tensordict. "
"This could be caused by calling alpha_prime_loss method before cql_loss method."
)
alpha_prime = torch.clamp_max(self.log_alpha_prime.exp(), max=1000000.0)
min_qf1_loss = alpha_prime * (cql_q1_loss.mean() - self.target_action_gap)
min_qf2_loss = alpha_prime * (cql_q2_loss.mean() - self.target_action_gap)
alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
alpha_prime_loss = _reduce(alpha_prime_loss, reduction=self.reduction)
return alpha_prime_loss, {}
def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
    """Compute the loss for the entropy temperature ``alpha``.

    When ``self.target_entropy`` is set the loss is
    ``-log_alpha * (log_prob.detach() + target_entropy)``. The log-prob is
    detached so only ``log_alpha`` receives gradients from this term. When
    ``self.target_entropy`` is ``None``, a zero placeholder with the same
    shape as the log-prob is used instead.

    Args:
        tensordict: must contain the policy log-probability under
            ``self.tensor_keys.log_prob``.

    Returns:
        A ``(loss, metadata)`` tuple (despite the ``Tensor`` annotation):
        ``loss`` is reduced according to ``self.reduction`` and ``metadata``
        is an empty dict.

    Raises:
        KeyError: if the log-probability is missing from ``tensordict``.
    """
    # Explicit ``None`` default: fail with a clear message instead of an
    # AttributeError on ``None`` (or a version-dependent tensordict error).
    log_pi = tensordict.get(self.tensor_keys.log_prob, None)
    if log_pi is None:
        raise KeyError(
            f"Couldn't find the log_prob with key {self.tensor_keys.log_prob} in the input tensordict. "
            "The log-probability of the policy's actions is required to compute the alpha loss."
        )
    if self.target_entropy is not None:
        # we can compute this loss even if log_alpha is not a parameter
        alpha_loss = -self.log_alpha * (log_pi.detach() + self.target_entropy)
    else:
        # placeholder
        alpha_loss = torch.zeros_like(log_pi)
    alpha_loss = _reduce(alpha_loss, reduction=self.reduction)
    return alpha_loss, {}
@property
def _alpha(self):
    """Current entropy coefficient ``exp(log_alpha)``, detached from autograd.

    If a lower and/or upper bound is configured (``self.min_log_alpha`` /
    ``self.max_log_alpha``), ``self.log_alpha`` is first clamped **in place**
    (through ``.data``) to that range. Either bound may be ``None``.
    """
    lower, upper = self.min_log_alpha, self.max_log_alpha
    # Clamp whenever *either* bound is set, so an upper-only bound is honoured
    # too; ``clamp_`` needs at least one non-None bound, hence the guard.
    if lower is not None or upper is not None:
        self.log_alpha.data.clamp_(lower, upper)
    with torch.no_grad():
        alpha = self.log_alpha.exp()
    return alpha
class DiscreteCQLLoss(LossModule):
"""TorchRL implementation of the discrete CQL loss.
This class implements the discrete conservative Q-learning (CQL) loss function, as presented in the paper
"Conservative Q-Learning for Offline Reinforcement Learning" (https://arxiv.org/abs/2006.04779).
Args:
value_network (Union[QValueActor, nn.Module]): The Q-value network used to estimate state-action values.
Keyword Args:
loss_function (Optional[str]): The distance function used to calculate the distance between the predicted
Q-values and the target Q-values. Defaults to ``l2``.
delay_value (bool): Whether to separate the target Q value
networks from the Q value networks used for data collection.
Default is ``True``.
gamma (float, optional): Discount factor. Default is ``None``.
action_space: The action space of the environment. If None, it is inferred from the value network.
Defaults to None.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
Examples:
>>> from torchrl.modules import MLP, QValueActor
>>> from torchrl.data import OneHot
>>> from torchrl.objectives import DiscreteCQLLoss
>>> n_obs, n_act = 4, 3
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
>>> spec = OneHot(n_act)
>>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
>>> loss = DiscreteCQLLoss(actor, action_space=spec)
>>> batch = [10,]
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": spec.rand(batch),
... ("next", "observation"): torch.randn(*batch, n_obs),
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1)
... }, batch)
>>> loss(data)
TensorDict(
fields={
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
td_error: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
This class is compatible with non-tensordict based modules too and can be
used without recurring to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"]``,
and a single loss value is returned.
Examples:
>>> from torchrl.objectives import DiscreteCQLLoss
>>> from torchrl.data import OneHot
>>> from torch import nn
>>> import torch
>>> n_obs = 3
>>> n_action = 4
>>> action_spec = OneHot(n_action)
>>> value_network = nn.Linear(n_obs, n_action) # a simple value model
>>> dcql_loss = DiscreteCQLLoss(value_network, action_space=action_spec)
>>> # define data
>>> observation = torch.randn(n_obs)
>>> next_observation = torch.randn(n_obs)
>>> action = action_spec.rand()
>>> next_reward = torch.randn(1)
>>> next_done = torch.zeros(1, dtype=torch.bool)
>>> next_terminated = torch.zeros(1, dtype=torch.bool)
>>> loss_val = dcql_loss(
... observation=observation,
... next_observation=next_observation,
... next_reward=next_reward,
... next_done=next_done,
... next_terminated=next_terminated,
... action=action)
"""
@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.
Attributes:
value_target (NestedKey): The input tensordict key where the target state value is expected.
Will be used for the underlying value estimator Defaults to ``"value_target"``.
value (NestedKey): The input tensordict key where the chosen action value is expected.
Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``.
action_value (NestedKey): The input tensordict key where the action value is expected.