-
Notifications
You must be signed in to change notification settings - Fork 7
/
policy.py
1155 lines (962 loc) · 36.7 KB
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import numpy as np
import torch
from torch import nn
import pytorch_utils as ptu
from collections import OrderedDict
from typing import Union, Callable
import math
from torch.distributions import Distribution as TorchDistribution
from torch.distributions import Normal as TorchNormal
from torch.distributions import Independent as TorchIndependent
from torch.distributions import Bernoulli as TorchBernoulli
from torch.distributions import Categorical, OneHotCategorical, kl_divergence
from torch.nn import functional as F
import torchvision
import abc
# Clamp range applied to log standard deviations predicted by TanhGaussianPolicy.
LOG_SIG_MAX, LOG_SIG_MIN = 2, -20
# Additive smoothing for mixture weights in GaussianMixtureFullDistribution,
# keeping every component's probability strictly positive.
epsilon = 1e-3
class Distribution(TorchDistribution):
    """torch ``Distribution`` extended with a few helpers used by the policies."""

    def sample_and_logprob(self):
        """Return ``(sample, log_prob(sample))`` using ``self.sample()``."""
        draw = self.sample()
        return draw, self.log_prob(draw)

    def rsample_and_logprob(self):
        """Return ``(sample, log_prob(sample))`` using ``self.rsample()``."""
        draw = self.rsample()
        return draw, self.log_prob(draw)

    def mle_estimate(self):
        """Point estimate of the distribution; the mean unless overridden."""
        return self.mean

    def get_diagnostics(self):
        """Statistics to log for this distribution; none by default."""
        return {}
class TorchDistributionWrapper(Distribution):
    """Expose an arbitrary torch distribution through this module's ``Distribution`` API.

    Every query (shapes, moments, sampling, densities) is forwarded to
    ``self.distribution``. ``TorchDistribution.__init__`` is never called, so the
    shape properties below are served from the wrapped object.
    """

    def __init__(self, distribution: TorchDistribution):
        self.distribution = distribution

    # --- shape and constraint metadata ------------------------------------
    @property
    def batch_shape(self):
        inner = self.distribution
        return inner.batch_shape

    @property
    def event_shape(self):
        inner = self.distribution
        return inner.event_shape

    @property
    def arg_constraints(self):
        inner = self.distribution
        return inner.arg_constraints

    @property
    def support(self):
        inner = self.distribution
        return inner.support

    # --- moments -----------------------------------------------------------
    @property
    def mean(self):
        inner = self.distribution
        return inner.mean

    @property
    def variance(self):
        inner = self.distribution
        return inner.variance

    @property
    def stddev(self):
        inner = self.distribution
        return inner.stddev

    # --- sampling ----------------------------------------------------------
    def sample(self, sample_size=torch.Size()):
        """Non-reparameterized sample with leading shape ``sample_size``."""
        return self.distribution.sample(sample_shape=sample_size)

    def rsample(self, sample_size=torch.Size()):
        """Reparameterized sample with leading shape ``sample_size``."""
        return self.distribution.rsample(sample_shape=sample_size)

    # --- densities and related queries ------------------------------------
    def log_prob(self, value):
        return self.distribution.log_prob(value)

    def cdf(self, value):
        return self.distribution.cdf(value)

    def icdf(self, value):
        return self.distribution.icdf(value)

    def enumerate_support(self, expand=True):
        return self.distribution.enumerate_support(expand=expand)

    def entropy(self):
        return self.distribution.entropy()

    def perplexity(self):
        return self.distribution.perplexity()

    def __repr__(self):
        return 'Wrapped ' + repr(self.distribution)
class Independent(Distribution, TorchIndependent):
    """torch ``Independent`` that also reports its base distribution's diagnostics."""

    def get_diagnostics(self):
        """Forward to ``self.base_dist.get_diagnostics()``.

        NOTE(review): a plain torch base distribution (e.g. the ``TorchNormal`` that
        MultivariateDiagonalNormal wraps) has no ``get_diagnostics``; this presumably
        expects a base distribution from this module.
        """
        base = self.base_dist
        return base.get_diagnostics()
class MultivariateDiagonalNormal(TorchDistributionWrapper):
    """Gaussian with diagonal covariance whose trailing dimension(s) form the event.

    Implemented as ``Independent(Normal(loc, scale_diag), reinterpreted_batch_ndims)``,
    so ``log_prob`` sums over the last ``reinterpreted_batch_ndims`` dimensions.
    """
    from torch.distributions import constraints
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}

    def __init__(self, loc, scale_diag, reinterpreted_batch_ndims=1):
        base = TorchNormal(loc, scale_diag)
        super().__init__(
            Independent(base, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
        )

    def get_diagnostics(self):
        """Summary statistics of the mean and standard deviation for logging."""
        stats = OrderedDict()
        for label, tensor in (('mean', self.mean), ('std', self.distribution.stddev)):
            stats.update(ptu.create_stats_ordered_dict(label, ptu.get_numpy(tensor)))
        return stats

    def __repr__(self):
        return repr(self.distribution.base_dist)
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    ``log_prob`` returns one value per sample: the Gaussian log-density is reduced
    over the last dimension (via MultivariateDiagonalNormal), and the tanh
    change-of-variables correction is reduced over the same dimension.

    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
            NOTE(review): stored but not read by any method of this class.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = MultivariateDiagonalNormal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples (new leading dim of size ``n``); optionally also return z."""
        # ``Distribution.sample_n`` is deprecated in torch; ``sample((n,))`` is the
        # supported equivalent.
        z = self.normal.sample(torch.Size((n,)))
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def _log_prob_from_pre_tanh(self, pre_tanh_value):
        """
        Adapted from
        https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L73
        This formula is mathematically equivalent to log(1 - tanh(x)^2).
        Derivation:
        log(1 - tanh(x)^2)
         = log(sech(x)^2)
         = 2 * log(sech(x))
         = 2 * log(2e^-x / (e^-2x + 1))
         = 2 * (log(2) - x - log(e^-2x + 1))
         = 2 * (log(2) - x - softplus(-2x))
        :param pre_tanh_value: z = arctanh(x), the Gaussian sample before squashing
        :return: log-density of tanh(z), reduced over the last dimension
        """
        log_prob = self.normal.log_prob(pre_tanh_value)
        # Reduce over the last (event) dimension to match ``self.normal``, which
        # also reduces only the last dimension; this stays correct for the
        # [n, batch, dim] samples produced by ``sample_n``.
        correction = - 2. * (
            math.log(2.)
            - pre_tanh_value
            - torch.nn.functional.softplus(-2. * pre_tanh_value)
        ).sum(dim=-1)
        return log_prob + correction

    def log_prob(self, value, pre_tanh_value=None):
        """Log-density of ``value``; recovers atanh(value) when ``pre_tanh_value`` is None."""
        if pre_tanh_value is None:
            # atanh diverges at +-1: keep the value strictly inside (-1, 1).
            value = torch.clamp(value, -0.999999, 0.999999)
            pre_tanh_value = torch.log(1+value) / 2 - torch.log(1-value) / 2
        return self._log_prob_from_pre_tanh(pre_tanh_value)

    def rsample_with_pretanh(self):
        """Reparameterized sample: returns ``(tanh(z), z)`` with z = mean + std * eps."""
        # Standard-normal noise shaped like the broadcast of mean and std, created on
        # the mean's device/dtype instead of a global device setting.
        noise_shape = torch.broadcast_shapes(
            self.normal_mean.shape, self.normal_std.shape)
        eps = torch.randn(
            noise_shape,
            dtype=self.normal_mean.dtype,
            device=self.normal_mean.device,
        )
        z = self.normal_mean + self.normal_std * eps
        return torch.tanh(z), z

    def sample(self):
        """
        Gradients will and should *not* pass through this operation.
        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        value, _ = self.rsample_with_pretanh()
        return value.detach()

    def rsample(self):
        """
        Sampling in the reparameterization case.
        """
        value, _ = self.rsample_with_pretanh()
        return value

    def sample_and_logprob(self):
        """Detached sample and its log-prob (no gradient through either)."""
        value, pre_tanh_value = self.rsample_with_pretanh()
        value, pre_tanh_value = value.detach(), pre_tanh_value.detach()
        log_p = self.log_prob(value, pre_tanh_value)
        return value, log_p

    def rsample_and_logprob(self):
        """Reparameterized sample and its log-prob (gradients flow through both)."""
        value, pre_tanh_value = self.rsample_with_pretanh()
        log_p = self.log_prob(value, pre_tanh_value)
        return value, log_p

    def rsample_logprob_and_pretanh(self):
        """Like ``rsample_and_logprob`` but also returns the pre-tanh sample."""
        value, pre_tanh_value = self.rsample_with_pretanh()
        log_p = self.log_prob(value, pre_tanh_value)
        return value, log_p, pre_tanh_value

    @property
    def mean(self):
        # tanh of the Gaussian mean (not the true mean of tanh(Z)).
        return torch.tanh(self.normal_mean)

    @property
    def stddev(self):
        # Std of the underlying Gaussian, not of the squashed variable.
        return self.normal_std

    def get_diagnostics(self):
        stats = OrderedDict()
        stats.update(ptu.create_stats_ordered_dict(
            'mean',
            ptu.get_numpy(self.mean),
        ))
        stats.update(ptu.create_stats_ordered_dict(
            'normal/std',
            ptu.get_numpy(self.normal_std)
        ))
        stats.update(ptu.create_stats_ordered_dict(
            'normal/log_std',
            ptu.get_numpy(torch.log(self.normal_std)),
        ))
        return stats
def torch_ify(np_array_or_other):
    """Turn a numpy array into a torch tensor (via ``ptu.from_numpy``); pass anything else through."""
    if not isinstance(np_array_or_other, np.ndarray):
        return np_array_or_other
    return ptu.from_numpy(np_array_or_other)
def np_ify(tensor_or_other):
    """Turn a torch tensor into a numpy array (via ``ptu.get_numpy``); pass anything else through."""
    # ``torch.autograd.Variable`` is a deprecated alias whose isinstance check just
    # tests for ``torch.Tensor``; check the tensor type directly.
    if isinstance(tensor_or_other, torch.Tensor):
        return ptu.get_numpy(tensor_or_other)
    return tensor_or_other
def elem_or_tuple_to_numpy(elem_or_tuple):
    """Apply ``np_ify`` to a single value, or element-wise to a tuple (returning a tuple)."""
    if not isinstance(elem_or_tuple, tuple):
        return np_ify(elem_or_tuple)
    return tuple(map(np_ify, elem_or_tuple))
class LayerNorm(nn.Module):
    """
    Simple 1D layer normalization over the last dimension.

    Normalizes by ``x.std`` (torch's default, unbiased estimator) plus ``eps``,
    then optionally multiplies by a learned per-feature scale (off by default)
    and adds a learned per-feature shift (on by default).
    """

    def __init__(self, features, center=True, scale=False, eps=1e-6):
        super().__init__()
        self.center = center
        self.scale = scale
        self.eps = eps
        self.scale_param = nn.Parameter(torch.ones(features)) if self.scale else None
        self.center_param = nn.Parameter(torch.zeros(features)) if self.center else None

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        normed = centered / (x.std(-1, keepdim=True) + self.eps)
        if self.scale:
            normed = normed * self.scale_param
        if self.center:
            normed = normed + self.center_param
        return normed
class Mlp(nn.Module):
    """Fully connected network: hidden ``Linear`` + activation layers, then a final ``Linear``.

    Hidden layers are registered as ``fc0``, ``fc1``, ... (and ``layer_norm0``, ...
    when ``layer_norm`` is true) and are also kept in the plain lists ``self.fcs`` and
    ``self.layer_norms``. The output layer ``last_fc`` has weights drawn from
    U(-init_w, init_w) and a zero bias.

    NOTE(review): ``layer_norm_kwargs`` is accepted but never forwarded to LayerNorm.
    """

    def __init__(
            self,
            hidden_sizes,
            output_size,
            input_size,
            init_w=3e-3,
            hidden_activation=F.relu,
            output_activation=ptu.identity,
            hidden_init=ptu.fanin_init,
            b_init_value=0.,
            layer_norm=False,
            layer_norm_kwargs=None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_activation = hidden_activation
        self.output_activation = output_activation
        self.layer_norm = layer_norm
        self.fcs = []
        self.layer_norms = []

        prev_width = input_size
        for idx, width in enumerate(hidden_sizes):
            layer = nn.Linear(prev_width, width)
            hidden_init(layer.weight)
            layer.bias.data.fill_(b_init_value)
            setattr(self, "fc{}".format(idx), layer)
            self.fcs.append(layer)
            if self.layer_norm:
                norm = LayerNorm(width)
                setattr(self, "layer_norm{}".format(idx), norm)
                self.layer_norms.append(norm)
            prev_width = width

        self.last_fc = nn.Linear(prev_width, output_size)
        self.last_fc.weight.data.uniform_(-init_w, init_w)
        self.last_fc.bias.data.fill_(0)

    def forward(self, input, return_preactivations=False):
        """Run the network; optionally also return the output layer's pre-activation."""
        h = input
        n_hidden = len(self.fcs)
        for idx, layer in enumerate(self.fcs):
            h = layer(h)
            # NOTE: normalization is skipped after the last hidden layer, so the final
            # ``layer_norm{n-1}`` module is created in __init__ but never applied.
            if self.layer_norm and idx < n_hidden - 1:
                h = self.layer_norms[idx](h)
            h = self.hidden_activation(h)
        preactivation = self.last_fc(h)
        output = self.output_activation(preactivation)
        return (output, preactivation) if return_preactivations else output
class DistributionGenerator(nn.Module, metaclass=abc.ABCMeta):
    """Module whose ``forward`` maps its inputs to a ``Distribution``.

    Subclasses must override ``forward``; the base version raises.
    """

    def forward(self, *inputs, **kwargs) -> Distribution:
        raise NotImplementedError
class MultiInputSequential(nn.Sequential):
    """``nn.Sequential`` whose stages can pass several values to each other.

    If a stage returns a tuple, it is unpacked into the next stage's positional
    arguments; any other value is passed as a single argument.
    """

    def forward(self, *inputs):
        value = inputs
        for stage in self._modules.values():
            value = stage(*value) if isinstance(value, tuple) else stage(value)
        return value
class ModuleToDistributionGenerator(
    MultiInputSequential,
    DistributionGenerator,
    metaclass=abc.ABCMeta
):
    """Sequential module stack whose final output subclasses wrap in a Distribution."""
class Beta(ModuleToDistributionGenerator):
    """Map inputs to a Beta distribution.

    The wrapped module stack must return ``(alpha, beta)`` concentration tensors.
    """

    def forward(self, *input):
        # Inside this module the name ``Beta`` refers to this generator class (an
        # nn.Module), not a distribution, so build torch's Beta explicitly and wrap
        # it to expose this module's Distribution helpers.
        from torch.distributions import Beta as TorchBeta
        alpha, beta = super().forward(*input)
        return TorchDistributionWrapper(TorchBeta(alpha, beta))
class Gaussian(ModuleToDistributionGenerator):
    """Map inputs to a MultivariateDiagonalNormal.

    If ``std`` is given, the module stack outputs only the mean and ``std`` is used
    as a fixed scale; otherwise the stack must return ``(mean, log_std)``.
    """

    def __init__(self, module, std=None, reinterpreted_batch_ndims=1):
        super().__init__(module)
        self.std = std
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    def forward(self, *input):
        # Compare with None explicitly: truthiness of a multi-element tensor std
        # raises, and a falsy scalar would silently switch to the log_std branch.
        if self.std is not None:
            mean = super().forward(*input)
            std = self.std
        else:
            mean, log_std = super().forward(*input)
            std = log_std.exp()
        return MultivariateDiagonalNormal(
            mean, std, reinterpreted_batch_ndims=self.reinterpreted_batch_ndims)
class Bernoulli(Distribution, TorchBernoulli):
    """torch Bernoulli that reports statistics of its success probabilities."""

    def get_diagnostics(self):
        probs_np = ptu.get_numpy(self.probs)
        return OrderedDict(ptu.create_stats_ordered_dict('probability', probs_np))
class BernoulliGenerator(ModuleToDistributionGenerator):
    """Wrap the module stack's output (passed as ``probs``) in a ``Bernoulli``."""

    def forward(self, *input):
        return Bernoulli(super().forward(*input))
class IndependentGenerator(ModuleToDistributionGenerator):
    """Wrap the stack's output distribution in ``Independent``.

    The last ``reinterpreted_batch_ndims`` batch dims are treated as event dims.
    """

    def __init__(self, *args, reinterpreted_batch_ndims=1):
        super().__init__(*args)
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    def forward(self, *input):
        base = super().forward(*input)
        ndims = self.reinterpreted_batch_ndims
        return Independent(base, reinterpreted_batch_ndims=ndims)
class GaussianMixtureDistribution(Distribution):
    """Mixture of diagonal Gaussians with one set of mixture weights per batch row.

    Shapes implied by the indexing below: ``normal_means`` / ``normal_stds`` are
    [batch, dim, num_gaussians] and ``weights`` is [batch, num_gaussians, 1].
    """
    def __init__(self, normal_means, normal_stds, weights):
        self.num_gaussians = weights.shape[1]
        self.normal_means = normal_means
        self.normal_stds = normal_stds
        self.normal = MultivariateDiagonalNormal(normal_means, normal_stds)
        # One diagonal Gaussian per component, each over [batch, dim].
        self.normals = [MultivariateDiagonalNormal(normal_means[:, :, i], normal_stds[:, :, i]) for i in range(self.num_gaussians)]
        self.weights = weights
        self.categorical = OneHotCategorical(self.weights[:, :, 0])

    def log_prob(self, value, ):
        """Log-density of ``value`` (presumably [batch, dim]) under the mixture.

        Each component is a MultivariateDiagonalNormal whose ``log_prob`` already
        sums over the last dimension, so the stacked component log-densities are
        [batch, num_gaussians]. Returns a [batch] tensor.
        """
        component_log_p = torch.stack(
            [normal.log_prob(value) for normal in self.normals], dim=-1)
        log_weights = torch.log(self.weights[:, :, 0])
        # log sum_k w_k * p_k(value), computed stably over the component axis.
        return torch.logsumexp(log_weights + component_log_p, dim=1)

    def sample(self):
        """Non-differentiable sample: draw every component, then select one per row."""
        z = self.normal.sample().detach()
        c = self.categorical.sample()[:, :, None]
        s = torch.matmul(z, c)
        return torch.squeeze(s, 2)

    def rsample(self):
        """Sample with reparameterized component draws (the component choice is not)."""
        z = (
            self.normal_means +
            self.normal_stds *
            MultivariateDiagonalNormal(
                ptu.zeros(self.normal_means.size()),
                ptu.ones(self.normal_stds.size())
            ).sample()
        )
        z.requires_grad_()
        c = self.categorical.sample()[:, :, None]
        s = torch.matmul(z, c)
        return torch.squeeze(s, 2)

    def mle_estimate(self):
        """Return the mean of the most likely component.
        This often computes the mode of the distribution, but not always.
        """
        c = ptu.zeros(self.weights.shape[:2])
        ind = torch.argmax(self.weights, dim=1) # [:, 0]
        c.scatter_(1, ind, 1)
        s = torch.matmul(self.normal_means, c[:, :, None])
        return torch.squeeze(s, 2)

    def __repr__(self):
        s = "GaussianMixture(normal_means=%s, normal_stds=%s, weights=%s)"
        return s % (self.normal_means, self.normal_stds, self.weights)
class GaussianMixtureFullDistribution(Distribution):
    """Mixture of diagonal Gaussians with separate mixture weights per dimension.

    Shapes inferred from the indexing below (TODO confirm against callers):
    ``normal_means`` / ``normal_stds`` look like [batch, dim, num_gaussians] and
    ``weights`` like [batch, dim, num_gaussians]. ``sample``/``rsample`` pick a
    component independently for each entry along dim 1.
    """
    def __init__(self, normal_means, normal_stds, weights):
        self.num_gaussians = weights.shape[-1]
        self.normal_means = normal_means
        self.normal_stds = normal_stds
        self.normal = MultivariateDiagonalNormal(normal_means, normal_stds)
        self.normals = [MultivariateDiagonalNormal(normal_means[:, :, i], normal_stds[:, :, i]) for i in range(self.num_gaussians)]
        # Smooth by the module-level ``epsilon`` and rescale so that, if the input
        # weights sum to 1 over the last dim (assumption, TODO confirm), the result
        # still sums to 1 with every component strictly positive.
        self.weights = (weights + epsilon) / (1 + epsilon * self.num_gaussians)
        assert (self.weights > 0).all()
        self.categorical = Categorical(self.weights)

    def log_prob(self, value, ):
        # Not implemented: this always raises NotImplementedError (the computation
        # before it may itself fail first on a shape mismatch), and the final
        # ``return`` is unreachable.
        log_p = [self.normals[i].log_prob(value) for i in range(self.num_gaussians)]
        log_p = torch.stack(log_p, -1)
        log_weights = torch.log(self.weights)
        lp = log_weights + log_p
        m = lp.max(dim=2, keepdim=True)[0] # log-sum-exp numerical stability trick
        log_p_mixture = m + torch.log(torch.exp(lp - m).sum(dim=2, keepdim=True))
        raise NotImplementedError("from Vitchyr: idk what the point is of "
                                  "this class, so I didn't both updating "
                                  "this, but log_prob should return something "
                                  "of shape [batch_size] and not [batch_size, "
                                  "1] to be in accordance with the "
                                  "torch.distributions.Distribution "
                                  "interface.")
        return torch.squeeze(log_p_mixture, 2)

    def sample(self):
        # Draw all components (no gradient), then gather the sampled component
        # index along the last axis for each entry.
        z = self.normal.sample().detach()
        c = self.categorical.sample()[:, :, None]
        s = torch.gather(z, dim=2, index=c)
        return s[:, :, 0]

    def rsample(self):
        # z = mean + std * eps keeps a gradient path to means/stds; the categorical
        # component selection itself is not differentiable.
        z = (
            self.normal_means +
            self.normal_stds *
            MultivariateDiagonalNormal(
                ptu.zeros(self.normal_means.size()),
                ptu.ones(self.normal_stds.size())
            ).sample()
        )
        z.requires_grad_()
        c = self.categorical.sample()[:, :, None]
        s = torch.gather(z, dim=2, index=c)
        return s[:, :, 0]

    def mle_estimate(self):
        """Return the mean of the most likely component.
        This often computes the mode of the distribution, but not always.
        """
        # argmax over the component axis, then gather the matching means.
        ind = torch.argmax(self.weights, dim=2)[:, :, None]
        means = torch.gather(self.normal_means, dim=2, index=ind)
        return torch.squeeze(means, 2)

    def __repr__(self):
        s = "GaussianMixture(normal_means=%s, normal_stds=%s, weights=%s)"
        return s % (self.normal_means, self.normal_stds, self.weights)
class GaussianMixture(ModuleToDistributionGenerator):
    """Build a GaussianMixtureDistribution from the stack's ``(means, stds, weights)``."""

    def forward(self, *input):
        means, stds, mix_weights = super().forward(*input)
        return GaussianMixtureDistribution(means, stds, mix_weights)
class GaussianMixtureFull(ModuleToDistributionGenerator):
    """Build a GaussianMixtureFullDistribution from the stack's ``(means, stds, weights)``."""

    def forward(self, *input):
        means, stds, mix_weights = super().forward(*input)
        return GaussianMixtureFullDistribution(means, stds, mix_weights)
class TanhGaussian(ModuleToDistributionGenerator):
    """Build a TanhNormal from the stack's ``(mean, log_std)`` output."""

    def forward(self, *input):
        mean, log_std = super().forward(*input)
        return TanhNormal(mean, log_std.exp())
class Policy(object, metaclass=abc.ABCMeta):
    """General policy interface: subclasses map an observation to an action."""

    @abc.abstractmethod
    def get_action(self, observation):
        """
        :param observation: a single observation
        :return: ``(action, debug_dictionary)``
        """

    def reset(self):
        """Hook for resetting internal state; a no-op by default."""
class ExplorationPolicy(Policy, metaclass=abc.ABCMeta):
    """Policy interface with a ``set_num_steps_total`` hook (ignored by default)."""

    def set_num_steps_total(self, t):
        """Receive a total step count ``t``; the default implementation does nothing."""
class TorchStochasticPolicy(
    DistributionGenerator,
    ExplorationPolicy, metaclass=abc.ABCMeta
):
    """Policy whose forward pass yields an action distribution that actions are sampled from."""

    def get_action(self, obs_np, ):
        """Sample an action for one unbatched numpy observation; returns ``(action, {})``."""
        batch_actions = self.get_actions(obs_np[None])
        return batch_actions[0, :], {}

    def get_actions(self, obs_np, ):
        """Sample actions for a batch of numpy observations, returned as numpy."""
        return elem_or_tuple_to_numpy(self._get_dist_from_np(obs_np).sample())

    def _get_dist_from_np(self, *args, **kwargs):
        """Convert numpy args/kwargs to tensors and call the module to get a distribution."""
        return self(
            *map(torch_ify, args),
            **{key: torch_ify(val) for key, val in kwargs.items()}
        )
class Delta(Distribution):
    """A deterministic (point-mass) distribution located at ``value``."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        """Return the point, detached from the autograd graph."""
        return self.value.detach()

    def rsample(self):
        """Return the point itself, so gradients flow through it."""
        return self.value

    @property
    def mean(self):
        return self.value

    @property
    def variance(self):
        return 0

    def entropy(self):
        """Entropy of the point mass, reported as 0.

        A method (not a property) to match ``torch.distributions.Distribution.entropy()``
        and ``TorchDistributionWrapper.entropy()``; as a property, ``dist.entropy()``
        would try to call an ``int``.
        """
        return 0
class TanhGaussianPolicy(Mlp, TorchStochasticPolicy):
    """
    Tanh-squashed Gaussian policy on top of an ``Mlp`` trunk.

    ``last_fc`` predicts the pre-tanh mean. If ``std`` is None, a second head
    ``last_fc_log_std`` predicts a log-std clamped to [LOG_SIG_MIN, LOG_SIG_MAX];
    otherwise the constant ``std`` is used. ``get_action`` / ``get_actions`` are
    inherited from ``TorchStochasticPolicy``.

    Usage:
    ```
    policy = TanhGaussianPolicy(...)
    ```
    """
    def __init__(
            self,
            hidden_sizes,
            obs_dim,
            action_dim,
            std=None,
            init_w=1e-3,
            **kwargs
    ):
        super().__init__(
            hidden_sizes,
            input_size=obs_dim,
            output_size=action_dim,
            init_w=init_w,
            **kwargs
        )
        self.log_std = None
        self.std = std
        if std is None:
            # The log-std head reads the same features as ``last_fc``: the last
            # hidden layer's output, or the raw observation with no hidden layers.
            last_hidden_size = obs_dim
            if len(hidden_sizes) > 0:
                last_hidden_size = hidden_sizes[-1]
            self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
            self.last_fc_log_std.weight.data.uniform_(-init_w, init_w)
            self.last_fc_log_std.bias.data.uniform_(-init_w, init_w)
        else:
            self.log_std = np.log(std)
            assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX

    def forward(self, obs):
        """Return the TanhNormal action distribution for a batch of observations."""
        h = obs
        # NOTE(review): unlike ``Mlp.forward``, LayerNorm modules (created when
        # ``layer_norm=True`` is passed via kwargs) are not applied here.
        for fc in self.fcs:
            h = self.hidden_activation(fc(h))
        mean = self.last_fc(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            # 1-element tensor on the mean's device/dtype; broadcasts against mean.
            std = mean.new_tensor([self.std])
        return TanhNormal(mean, std)

    def logprob(self, action, mean, std):
        """Log-likelihood of ``action`` under TanhNormal(mean, std), shaped [batch, 1]."""
        tanh_normal = TanhNormal(mean, std)
        log_prob = tanh_normal.log_prob(
            action,
        )
        # TanhNormal.log_prob already reduces over the action dimension, giving one
        # value per row; summing over dim=1 of that 1-D tensor would raise.
        if log_prob.dim() == 1:
            return log_prob.unsqueeze(-1)
        return log_prob.sum(dim=1, keepdim=True)
class MakeDeterministic(TorchStochasticPolicy):
    """Deterministic policy: a point mass at the wrapped generator's ``mle_estimate()``."""

    def __init__(
            self,
            action_distribution_generator: DistributionGenerator,
    ):
        super().__init__()
        self._action_distribution_generator = action_distribution_generator

    def forward(self, *args, **kwargs):
        # Invoke the module through __call__ (not .forward) so any registered
        # forward hooks on the wrapped generator still run.
        dist = self._action_distribution_generator(*args, **kwargs)
        return Delta(dist.mle_estimate())
def get_resnet(name: str, weights=None, **kwargs) -> nn.Module:
    """
    Build a torchvision ResNet with its classification head removed.

    name: torchvision constructor name, e.g. "resnet18", "resnet34", "resnet50".
    weights: forwarded to the constructor, e.g. "IMAGENET1K_V1" or None.
    Returns the backbone with ``fc`` replaced by ``nn.Identity``, so its output is
    the pooled feature vector (512-dim for resnet18).
    """
    constructor = getattr(torchvision.models, name)
    backbone = constructor(weights=weights, **kwargs)
    backbone.fc = torch.nn.Identity()
    return backbone
def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Swap every submodule matching ``predicate`` for ``func(submodule)``, in place.

    predicate: returns True for modules that should be replaced.
    func: returns the replacement for a given module.
    If ``root_module`` itself matches, ``func(root_module)`` is returned instead.
    Otherwise returns ``root_module`` after asserting no matches remain.
    """
    if predicate(root_module):
        return func(root_module)

    def _matching_paths():
        # Dotted names split into path components, e.g. "layer1.0.bn1" -> [...].
        return [name.split('.')
                for name, module in root_module.named_modules(remove_duplicate=True)
                if predicate(module)]

    for path in _matching_paths():
        *parent_path, leaf = path
        if parent_path:
            parent = root_module.get_submodule('.'.join(parent_path))
        else:
            parent = root_module
        # Sequential children are addressed by integer index, others by attribute.
        indexed = isinstance(parent, nn.Sequential)
        old = parent[int(leaf)] if indexed else getattr(parent, leaf)
        new = func(old)
        if indexed:
            parent[int(leaf)] = new
        else:
            setattr(parent, leaf, new)

    # Every matching module must have been replaced.
    assert not _matching_paths()
    return root_module
def replace_bn_with_gn(
        root_module: nn.Module,
        features_per_group: int = 16) -> nn.Module:
    """
    Replace all BatchNorm2d layers with GroupNorm, in place.

    A BatchNorm2d over C channels becomes GroupNorm(C // features_per_group, C).
    Layers with fewer than ``features_per_group`` channels get a single group
    rather than the invalid ``num_groups=0``.
    Returns the (possibly new) root module: if ``root_module`` is itself a
    BatchNorm2d, its GroupNorm replacement is returned.
    """
    return replace_submodules(
        root_module=root_module,
        predicate=lambda x: isinstance(x, nn.BatchNorm2d),
        func=lambda x: nn.GroupNorm(
            num_groups=max(1, x.num_features // features_per_group),
            num_channels=x.num_features)
    )
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions (e.g. diffusion step indices).

    For a 1-D input of length B, returns [B, 2 * (dim // 2)]: sines for
    ``dim // 2`` geometrically spaced frequencies (1 down to 1/10000), then the
    matching cosines. Requires ``dim // 2 >= 2`` (the frequency spacing divides
    by ``dim // 2 - 1``).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        n_freqs = self.dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class Downsample1d(nn.Module):
    """Shrink the time axis from L to ceil(L / 2) with a stride-2 Conv1d; channels unchanged."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        out = self.conv(x)
        return out
class Upsample1d(nn.Module):
    """Double the time axis (L -> 2L) with a stride-2 ConvTranspose1d; channels unchanged."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        out = self.conv(x)
        return out
class Conv1dBlock(nn.Module):
    '''
    Conv1d --> GroupNorm --> Mish

    Padding is ``kernel_size // 2``, which preserves the time length for odd kernels.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        conv = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        norm = nn.GroupNorm(n_groups, out_channels)
        self.block = nn.Sequential(conv, norm, nn.Mish())

    def forward(self, x):
        out = self.block(x)
        return out
class ConditionalResidualBlock1D(nn.Module):
    """Two Conv1dBlocks with FiLM conditioning between them, plus a residual path."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 cond_dim,
                 kernel_size=3,
                 n_groups=8):
        super().__init__()
        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])
        # FiLM modulation (https://arxiv.org/abs/1709.07871): the condition vector is
        # mapped to a per-channel scale and bias, i.e. 2 * out_channels values.
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, 2 * out_channels),
            nn.Unflatten(-1, (-1, 1))
        )
        # 1x1 conv on the skip path only when the channel counts differ.
        if in_channels != out_channels:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, cond):
        '''
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim]

        returns:
        out : [ batch_size x out_channels x horizon ]
        '''
        first_block, second_block = self.blocks
        h = first_block(x)
        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        gamma, beta = film[:, 0], film[:, 1]
        h = second_block(gamma * h + beta)
        return h + self.residual_conv(x)
class ConditionalUnet1D(nn.Module):
def __init__(self,
input_dim,
global_cond_dim,
diffusion_step_embed_dim=256,
down_dims=[256, 512, 1024],
kernel_size=5,
n_groups=8
):
"""
input_dim: Dim of actions.
global_cond_dim: Dim of global conditioning applied with FiLM
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
down_dims: Channel size for each UNet level.
The length of this array determines numebr of levels.
kernel_size: Conv kernel size
n_groups: Number of groups for GroupNorm
"""
super().__init__()
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed + global_cond_dim
in_out = list(zip(all_dims[:-1], all_dims[1:]))
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups
),
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups
),
])
down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
down_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups),
ConditionalResidualBlock1D(
dim_out, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
up_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
up_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_out * 2, dim_in, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups),
ConditionalResidualBlock1D(
dim_in, dim_in, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
)
self.diffusion_step_encoder = diffusion_step_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
print("number of parameters: {:e}".format(
sum(p.numel() for p in self.parameters()))
)
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
global_cond=None):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
global_cond: (B,global_cond_dim)
output: (B,T,input_dim)
"""
# (B,T,C)
sample = sample.moveaxis(-1, -2)
# (B,C,T)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
if global_cond is not None: