From 29dd7dadacc9fcfe30d3a2809acc36932f781652 Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Tue, 25 Apr 2023 13:01:48 +0800 Subject: [PATCH 01/14] merge dreamer ConvEncoder to model/encoder.py --- ding/model/common/encoder.py | 63 ++++++++++++++++++++++--- ding/model/common/tests/test_encoder.py | 11 +++++ 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/ding/model/common/encoder.py b/ding/model/common/encoder.py index 3c1cbd6738..165922e3f8 100644 --- a/ding/model/common/encoder.py +++ b/ding/model/common/encoder.py @@ -35,6 +35,7 @@ def __init__( kernel_size: SequenceType = [8, 4, 3], stride: SequenceType = [4, 2, 1], padding: Optional[SequenceType] = None, + layer_norm: Optional[bool] = False, norm_type: Optional[str] = None ) -> None: """ @@ -50,6 +51,7 @@ def __init__( - stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers. - padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \ See ``nn.Conv2d`` for more details. Default is ``None``. + - layer_norm (:obj:`bool`): Whether to use ``DreamerLayerNorm``. - norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResBlock`` \ for more details. Default is ``None``. """ @@ -63,17 +65,24 @@ def __init__( layers = [] input_size = obs_shape[0] # in_channel for i in range(len(kernel_size)): - layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i])) - layers.append(self.act) + if layer_norm: + layers.append(Conv2dSame(in_channels=input_size, out_channels=hidden_size_list[i], kernel_size=(kernel_size[i], kernel_size[i]), stride=(2, 2), bias=False,)) + layers.append(DreamerLayerNorm(hidden_size_list[i])) + layers.append(self.act) + else: + layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i])) + layers.append(self.act) input_size = hidden_size_list[i] - assert len(set(hidden_size_list[3:-1])) <= 1, "Please indicate the same hidden size for res block parts" - for i in range(3, len(self.hidden_size_list) - 1): - layers.append(ResBlock(self.hidden_size_list[i], activation=self.act, norm_type=norm_type)) + if len(self.hidden_size_list) >= len(kernel_size)+2: + assert self.hidden_size_list[len(kernel_size)-1] == self.hidden_size_list[len(kernel_size)], "Please indicate the same hidden size between conv and res block" + assert len(set(hidden_size_list[len(kernel_size):-1])) <= 1, "Please indicate the same hidden size for res block parts" + for i in range(len(kernel_size), len(self.hidden_size_list) - 1): + layers.append(ResBlock(self.hidden_size_list[i-1], activation=self.act, norm_type=norm_type)) layers.append(Flatten()) self.main = nn.Sequential(*layers) flatten_size = self._get_flatten_size() - self.output_size = hidden_size_list[-1] + self.output_size = hidden_size_list[-1] # outside to use self.mid = nn.Linear(flatten_size, hidden_size_list[-1]) def _get_flatten_size(self) -> int: @@ -306,3 +315,45 @@ def forward(self, x): if self.final_relu: x = torch.relu(x) return x + + +class Conv2dSame(torch.nn.Conv2d): + def calc_same_pad(self, i, k, s, d): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + def forward(self, x): + ih, iw = x.size()[-2:] + pad_h = self.calc_same_pad( + i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] + ) + pad_w = self.calc_same_pad( + i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] + ) + + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - 
pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + + ret = F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return ret + + +class DreamerLayerNorm(nn.Module): + def __init__(self, ch, eps=1e-03): + super(DreamerLayerNorm, self).__init__() + self.norm = torch.nn.LayerNorm(ch, eps=eps) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = x.permute(0, 3, 1, 2) + return x \ No newline at end of file diff --git a/ding/model/common/tests/test_encoder.py b/ding/model/common/tests/test_encoder.py index 26fae7b22c..52e68749e7 100644 --- a/ding/model/common/tests/test_encoder.py +++ b/ding/model/common/tests/test_encoder.py @@ -23,6 +23,14 @@ def test_conv_encoder(self): outputs = model(inputs) self.output_check(model, outputs) assert outputs.shape == (B, 128) + + def test_dreamer_conv_encoder(self): + inputs = torch.randn(B, C, H, W) + model = ConvEncoder((C, H, W), hidden_size_list=[32, 64, 128, 256, 128], activation=torch.nn.SiLU(), kernel_size=[4, 4, 4, 4], layer_norm=True) + print(model) + outputs = model(inputs) + self.output_check(model, outputs) + assert outputs.shape == (B, 128) def test_fc_encoder(self): inputs = torch.randn(B, 32) @@ -47,3 +55,6 @@ def test_impalaconv_encoder(self): outputs = model(inputs) self.output_check(model, outputs) assert outputs.shape == (B, 256) + +a = TestEncoder() +a.test_dreamer_conv_encoder() \ No newline at end of file From 6fd5970c80ee593fc3e10d48b053ce5dbad8d77f Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Tue, 25 Apr 2023 14:51:25 +0800 Subject: [PATCH 02/14] polish format --- ding/model/common/encoder.py | 37 +++++++++++++++---------- ding/model/common/tests/test_encoder.py | 13 +++++++-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/ding/model/common/encoder.py b/ding/model/common/encoder.py index 165922e3f8..fa6e1fed41 100644 --- a/ding/model/common/encoder.py +++ b/ding/model/common/encoder.py @@ -66,18 +66,29 @@ def __init__( input_size = obs_shape[0] # in_channel for i in range(len(kernel_size)): if layer_norm: - layers.append(Conv2dSame(in_channels=input_size, out_channels=hidden_size_list[i], kernel_size=(kernel_size[i], kernel_size[i]), stride=(2, 2), bias=False,)) + layers.append( + Conv2dSame( + in_channels=input_size, + out_channels=hidden_size_list[i], + kernel_size=(kernel_size[i], kernel_size[i]), + stride=(2, 2), + bias=False, + ) + ) layers.append(DreamerLayerNorm(hidden_size_list[i])) layers.append(self.act) else: layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i])) layers.append(self.act) input_size = hidden_size_list[i] - if len(self.hidden_size_list) >= len(kernel_size)+2: - assert self.hidden_size_list[len(kernel_size)-1] == self.hidden_size_list[len(kernel_size)], "Please indicate the same hidden size between conv and res block" - assert len(set(hidden_size_list[len(kernel_size):-1])) <= 1, "Please indicate the same hidden size for res block parts" + if len(self.hidden_size_list) >= len(kernel_size) + 2: + assert self.hidden_size_list[len(kernel_size) - 1] == self.hidden_size_list[ + len(kernel_size)], "Please indicate the same hidden size between conv and res block" + assert len( + set(hidden_size_list[len(kernel_size):-1]) + ) <= 1, "Please indicate the same hidden size for res block parts" for i in range(len(kernel_size), len(self.hidden_size_list) - 1): - layers.append(ResBlock(self.hidden_size_list[i-1], activation=self.act, norm_type=norm_type)) + 
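# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): the "same" padding rule used by the
# Conv2dSame helper introduced above. calc_same_pad reproduces TensorFlow-style
# SAME padding, so the output spatial size is ceil(input / stride) regardless of
# kernel size. Standalone check with hypothetical sizes, assuming only torch:
import math
import torch
import torch.nn.functional as F

def calc_same_pad(i: int, k: int, s: int, d: int) -> int:
    # same formula as Conv2dSame.calc_same_pad in the patch
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

pad = calc_same_pad(i=64, k=4, s=2, d=1)   # (32 - 1) * 2 + 3 + 1 - 64 = 2
x = torch.randn(1, 3, 64, 64)
x = F.pad(x, [pad // 2, pad - pad // 2, pad // 2, pad - pad // 2])  # 1 pixel per side
y = F.conv2d(x, torch.randn(16, 3, 4, 4), stride=2)
assert y.shape[-2:] == (32, 32)            # ceil(64 / 2) = 32, as intended
# ---------------------------------------------------------------------------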
layers.append(ResBlock(self.hidden_size_list[i - 1], activation=self.act, norm_type=norm_type)) layers.append(Flatten()) self.main = nn.Sequential(*layers) @@ -318,22 +329,17 @@ def forward(self, x): class Conv2dSame(torch.nn.Conv2d): + def calc_same_pad(self, i, k, s, d): return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) def forward(self, x): ih, iw = x.size()[-2:] - pad_h = self.calc_same_pad( - i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] - ) - pad_w = self.calc_same_pad( - i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] - ) + pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) + pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) if pad_h > 0 or pad_w > 0: - x = F.pad( - x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] - ) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) ret = F.conv2d( x, @@ -348,6 +354,7 @@ def forward(self, x): class DreamerLayerNorm(nn.Module): + def __init__(self, ch, eps=1e-03): super(DreamerLayerNorm, self).__init__() self.norm = torch.nn.LayerNorm(ch, eps=eps) @@ -356,4 +363,4 @@ def forward(self, x): x = x.permute(0, 2, 3, 1) x = self.norm(x) x = x.permute(0, 3, 1, 2) - return x \ No newline at end of file + return x diff --git a/ding/model/common/tests/test_encoder.py b/ding/model/common/tests/test_encoder.py index 52e68749e7..ef31550690 100644 --- a/ding/model/common/tests/test_encoder.py +++ b/ding/model/common/tests/test_encoder.py @@ -23,10 +23,16 @@ def test_conv_encoder(self): outputs = model(inputs) self.output_check(model, outputs) assert outputs.shape == (B, 128) - + def test_dreamer_conv_encoder(self): inputs = torch.randn(B, C, H, W) - model = ConvEncoder((C, H, W), hidden_size_list=[32, 64, 128, 256, 128], activation=torch.nn.SiLU(), kernel_size=[4, 4, 4, 4], layer_norm=True) + model = ConvEncoder( + (C, H, W), + hidden_size_list=[32, 64, 128, 256, 128], + activation=torch.nn.SiLU(), + kernel_size=[4, 4, 4, 4], + layer_norm=True + ) print(model) outputs = model(inputs) self.output_check(model, outputs) @@ -56,5 +62,6 @@ def test_impalaconv_encoder(self): self.output_check(model, outputs) assert outputs.shape == (B, 256) + a = TestEncoder() -a.test_dreamer_conv_encoder() \ No newline at end of file +a.test_dreamer_conv_encoder() From 0aad131c37150ef744a92c36f8cb23e025124ec7 Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Thu, 18 May 2023 17:39:39 +0800 Subject: [PATCH 03/14] add DenseHead of dreamerv3 to world_model/model --- ding/world_model/model/networks.py | 72 +++++ ding/world_model/model/tests/test_networks.py | 23 ++ ding/world_model/utils.py | 257 ++++++++++++++++++ 3 files changed, 352 insertions(+) create mode 100644 ding/world_model/model/networks.py create mode 100644 ding/world_model/model/tests/test_networks.py diff --git a/ding/world_model/model/networks.py b/ding/world_model/model/networks.py new file mode 100644 index 0000000000..dc4e096bdd --- /dev/null +++ b/ding/world_model/model/networks.py @@ -0,0 +1,72 @@ +import math +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torch import distributions as torchd + +from ding.world_model.utils import weight_init, uniform_weight_init, ContDist, Bernoulli, TwoHotDistSymlog, UnnormalizedHuber +from ding.torch_utils import MLP, fc_block + + +class DenseHead(nn.Module): + + def __init__( + self, + inp_dim, # config.dyn_stoch * 
config.dyn_discrete + config.dyn_deter + shape, # (255,) + layer_num, + units, # 512 + act='SiLU', + norm='LN', + dist='normal', + std=1.0, + outscale=1.0, + ): + super(DenseHead, self).__init__() + self._shape = (shape, ) if isinstance(shape, int) else shape + if len(self._shape) == 0: + self._shape = (1, ) + self._layer_num = layer_num + self._units = units + self._act = getattr(torch.nn, act)() + self._norm = norm + self._dist = dist + self._std = std + + self.mlp = MLP( + inp_dim, + self._units, + self._units, + self._layer_num, + layer_fn=nn.Linear, + activation=self._act, + norm_type=self._norm + ) + self.mlp.apply(weight_init) + + self.mean_layer = nn.Linear(self._units, np.prod(self._shape)) + self.mean_layer.apply(uniform_weight_init(outscale)) + + if self._std == "learned": + self.std_layer = nn.Linear(self._units, np.prod(self._shape)) + self.std_layer.apply(uniform_weight_init(outscale)) + + def forward(self, features, dtype=None): + x = features + out = self.mlp(x) # (batch, time, _units=512) + mean = self.mean_layer(out) # (batch, time, 255) + if self._std == "learned": + std = self.std_layer(out) + else: + std = self._std + if self._dist == "normal": + return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape))) + if self._dist == "huber": + return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape))) + if self._dist == "binary": + return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape))) + if self._dist == "twohot_symlog": + return TwoHotDistSymlog(logits=mean) + raise NotImplementedError(self._dist) diff --git a/ding/world_model/model/tests/test_networks.py b/ding/world_model/model/tests/test_networks.py new file mode 100644 index 0000000000..30173b20f1 --- /dev/null +++ b/ding/world_model/model/tests/test_networks.py @@ -0,0 +1,23 @@ +import pytest +import torch +from itertools import product +from ding.world_model.model.networks import DenseHead + +# arguments +shape = [255, (255, ), ()] +# to do +# dist = ['normal', 'huber', 'binary', 'twohot_symlog'] +dist = ['twohot_symlog'] +args = list(product(*[shape, dist])) + + +@pytest.mark.unittest +@pytest.mark.parametrize('shape, dist', args) +def test_DenseHead(shape, dist): + in_dim, layer_num, units, time, B = 1536, 2, 512, 16, 64 + head = DenseHead(in_dim, shape, layer_num, units, dist=dist) + x = torch.randn(time, B, in_dim) + a = torch.randn(time, B, 1) + y = head(x) + assert y.mode().shape == (time, B, 1) + assert y.log_prob(a).shape == (time, B) diff --git a/ding/world_model/utils.py b/ding/world_model/utils.py index 15172699f9..ff4e4ab904 100644 --- a/ding/world_model/utils.py +++ b/ding/world_model/utils.py @@ -1,5 +1,10 @@ from easydict import EasyDict from typing import Callable +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torch import distributions as torchd def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]: @@ -23,3 +28,255 @@ def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]: return lambda x: cfg.rollout_length else: raise KeyError("not implemented key: {}".format(cfg.type)) + + +def symlog(x): + return torch.sign(x) * torch.log(torch.abs(x) + 1.0) + + +def symexp(x): + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0) + + +class SampleDist: + + def __init__(self, dist, samples=100): + self._dist = dist + self._samples = samples + + @property + def name(self): + return 'SampleDist' + 
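# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): the symlog / symexp pair defined
# above. symlog compresses large-magnitude targets so a single reward/value head
# can fit signals of very different scales, and symexp is its exact inverse used
# to decode predictions back to the original scale. Round-trip check (torch only):
import torch

def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

def symexp(x):
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
y = symlog(x)                                    # ~[-6.91, -0.69, 0.00, 0.69, 6.91]
assert torch.allclose(symexp(y), x, rtol=1e-4, atol=1e-4)  # symexp inverts symlog
# ---------------------------------------------------------------------------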
+ def __getattr__(self, name): + return getattr(self._dist, name) + + def mean(self): + samples = self._dist.sample(self._samples) + return torch.mean(samples, 0) + + def mode(self): + sample = self._dist.sample(self._samples) + logprob = self._dist.log_prob(sample) + return sample[torch.argmax(logprob)][0] + + def entropy(self): + sample = self._dist.sample(self._samples) + logprob = self.log_prob(sample) + return -torch.mean(logprob, 0) + + +class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): + + def __init__(self, logits=None, probs=None, unimix_ratio=0.0): + if logits is not None and unimix_ratio > 0.0: + probs = F.softmax(logits, dim=-1) + probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1] + logits = torch.log(probs) + super().__init__(logits=logits, probs=None) + else: + super().__init__(logits=logits, probs=probs) + + def mode(self): + _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1]) + return _mode.detach() + super().logits - super().logits.detach() + + def sample(self, sample_shape=(), seed=None): + if seed is not None: + raise ValueError('need to check') + sample = super().sample(sample_shape) + probs = super().probs + while len(probs.shape) < len(sample.shape): + probs = probs[None] + sample += probs - probs.detach() + return sample + + +class TwoHotDistSymlog(): + + def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): + self.logits = logits + self.probs = torch.softmax(logits, -1) + self.buckets = torch.linspace(low, high, steps=255).to(device) + self.width = (self.buckets[-1] - self.buckets[0]) / 255 + + def mean(self): + print("mean called") + _mode = self.probs * self.buckets + return symexp(torch.sum(_mode, dim=-1, keepdim=True)) + + def mode(self): + _mode = self.probs * self.buckets + return symexp(torch.sum(_mode, dim=-1, keepdim=True)) + + # Inside OneHotCategorical, log_prob is calculated using only max element in targets + def log_prob(self, x): + x = symlog(x) + # x(time, batch, 1) + below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1 + above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1) + below = torch.clip(below, 0, len(self.buckets) - 1) + above = torch.clip(above, 0, len(self.buckets) - 1) + equal = (below == above) + + dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x)) + dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x)) + total = dist_to_below + dist_to_above + weight_below = dist_to_above / total + weight_above = dist_to_below / total + target = ( + F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] + + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None] + ) + log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True) + target = target.squeeze(-2) + + return (target * log_pred).sum(-1) + + def log_prob_target(self, target): + log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True) + return (target * log_pred).sum(-1) + + +class SymlogDist(): + + def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): + self._mode = mode + self._dist = dist + self._agg = agg + self._tol = tol + self._dim_to_reduce = dim_to_reduce + + def mode(self): + return symexp(self._mode) + + def mean(self): + return symexp(self._mode) + + def log_prob(self, value): + assert self._mode.shape == value.shape + if self._dist == 'mse': + distance = (self._mode - symlog(value)) ** 2.0 + distance 
= torch.where(distance < self._tol, 0, distance) + elif self._dist == 'abs': + distance = torch.abs(self._mode - symlog(value)) + distance = torch.where(distance < self._tol, 0, distance) + else: + raise NotImplementedError(self._dist) + if self._agg == 'mean': + loss = distance.mean(self._dim_to_reduce) + elif self._agg == 'sum': + loss = distance.sum(self._dim_to_reduce) + else: + raise NotImplementedError(self._agg) + return -loss + + +class ContDist: + + def __init__(self, dist=None): + super().__init__() + self._dist = dist + self.mean = dist.mean + + def __getattr__(self, name): + return getattr(self._dist, name) + + def entropy(self): + return self._dist.entropy() + + def mode(self): + return self._dist.mean + + def sample(self, sample_shape=()): + return self._dist.rsample(sample_shape) + + def log_prob(self, x): + return self._dist.log_prob(x) + + +class Bernoulli: + + def __init__(self, dist=None): + super().__init__() + self._dist = dist + self.mean = dist.mean + + def __getattr__(self, name): + return getattr(self._dist, name) + + def entropy(self): + return self._dist.entropy() + + def mode(self): + _mode = torch.round(self._dist.mean) + return _mode.detach() + self._dist.mean - self._dist.mean.detach() + + def sample(self, sample_shape=()): + return self._dist.rsample(sample_shape) + + def log_prob(self, x): + _logits = self._dist.base_dist.logits + log_probs0 = -F.softplus(_logits) + log_probs1 = -F.softplus(-_logits) + + return log_probs0 * (1 - x) + log_probs1 * x + + +class UnnormalizedHuber(torchd.normal.Normal): + + def __init__(self, loc, scale, threshold=1, **kwargs): + super().__init__(loc, scale, **kwargs) + self._threshold = threshold + + def log_prob(self, event): + return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold) + + def mode(self): + return self.mean + + +def weight_init(m): + if isinstance(m, nn.Linear): + in_num = m.in_features + out_num = m.out_features + denoms = (in_num + out_num) / 2.0 + scale = 1.0 / denoms + std = np.sqrt(scale) / 0.87962566103423978 + nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + space = m.kernel_size[0] * m.kernel_size[1] + in_num = space * m.in_channels + out_num = space * m.out_channels + denoms = (in_num + out_num) / 2.0 + scale = 1.0 / denoms + std = np.sqrt(scale) / 0.87962566103423978 + nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.LayerNorm): + m.weight.data.fill_(1.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + + +def uniform_weight_init(given_scale): + + def f(m): + if isinstance(m, nn.Linear): + in_num = m.in_features + out_num = m.out_features + denoms = (in_num + out_num) / 2.0 + scale = given_scale / denoms + limit = np.sqrt(3 * scale) + nn.init.uniform_(m.weight.data, a=-limit, b=limit) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.LayerNorm): + m.weight.data.fill_(1.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + + return f From 1b3f02ab0c4272f8c9bed9e8ae4391ffbbd161ec Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Fri, 19 May 2023 15:52:28 +0800 Subject: [PATCH 04/14] add different Dist unittest --- ding/world_model/model/tests/test_networks.py | 10 ++-- .../tests/test_world_model_utils.py | 52 ++++++++++++++++++- ding/world_model/utils.py | 4 +- 3 
files changed, 58 insertions(+), 8 deletions(-) diff --git a/ding/world_model/model/tests/test_networks.py b/ding/world_model/model/tests/test_networks.py index 30173b20f1..234f42fffc 100644 --- a/ding/world_model/model/tests/test_networks.py +++ b/ding/world_model/model/tests/test_networks.py @@ -14,10 +14,10 @@ @pytest.mark.unittest @pytest.mark.parametrize('shape, dist', args) def test_DenseHead(shape, dist): - in_dim, layer_num, units, time, B = 1536, 2, 512, 16, 64 + in_dim, layer_num, units, B, time = 1536, 2, 512, 16, 64 head = DenseHead(in_dim, shape, layer_num, units, dist=dist) - x = torch.randn(time, B, in_dim) - a = torch.randn(time, B, 1) + x = torch.randn(B, time, in_dim) + a = torch.randn(B, time, 1) y = head(x) - assert y.mode().shape == (time, B, 1) - assert y.log_prob(a).shape == (time, B) + assert y.mode().shape == (B, time, 1) + assert y.log_prob(a).shape == (B, time) diff --git a/ding/world_model/tests/test_world_model_utils.py b/ding/world_model/tests/test_world_model_utils.py index 26ba5e7f5e..d4477cd178 100644 --- a/ding/world_model/tests/test_world_model_utils.py +++ b/ding/world_model/tests/test_world_model_utils.py @@ -1,6 +1,9 @@ import pytest from easydict import EasyDict -from ding.world_model.utils import get_rollout_length_scheduler +import torch +from torch import distributions as torchd +from itertools import product +from ding.world_model.utils import get_rollout_length_scheduler, SampleDist, OneHotDist, TwoHotDistSymlog, SymlogDist, ContDist, Bernoulli, UnnormalizedHuber, weight_init, uniform_weight_init @pytest.mark.unittest @@ -17,3 +20,50 @@ def test_get_rollout_length_scheduler(): assert scheduler(19999) == 1 assert scheduler(150000) == 25 assert scheduler(1500000) == 25 + + +B, time = 16, 64 +mean = torch.randn(B, time, 255) +std = 1.0 +a = torch.randn(B, time, 1) # or torch.randn(B, time, 255) +sample_shape = torch.Size([]) + + +@pytest.mark.unittest +def test_ContDist(): + dist_origin = torchd.normal.Normal(mean, std) + dist = torchd.independent.Independent(dist_origin, 1) + dist_new = ContDist(dist) + assert dist_new.mode().shape == (B, time, 255) + assert dist_new.log_prob(a).shape == (B, time) + assert dist_origin.log_prob(a).shape == (B, time, 255) + assert dist_new.sample().shape == (B, time, 255) + + +@pytest.mark.unittest +def test_UnnormalizedHuber(): + dist_origin = UnnormalizedHuber(mean, std) + dist = torchd.independent.Independent(dist_origin, 1) + dist_new = ContDist(dist) + assert dist_new.mode().shape == (B, time, 255) + assert dist_new.log_prob(a).shape == (B, time) + assert dist_origin.log_prob(a).shape == (B, time, 255) + assert dist_new.sample().shape == (B, time, 255) + + +@pytest.mark.unittest +def test_Bernoulli(): + dist_origin = torchd.bernoulli.Bernoulli(logits=mean) + dist = torchd.independent.Independent(dist_origin, 1) + dist_new = Bernoulli(dist) + assert dist_new.mode().shape == (B, time, 255) + assert dist_new.log_prob(a).shape == (B, time, 255) + # to do + # assert dist_new.sample().shape == (B, time, 255) + + +@pytest.mark.unittest +def test_TwoHotDistSymlog(): + dist = TwoHotDistSymlog(logits=mean) + assert dist.mode().shape == (B, time, 1) + assert dist.log_prob(a).shape == (B, time) diff --git a/ding/world_model/utils.py b/ding/world_model/utils.py index ff4e4ab904..2f9f10ffdf 100644 --- a/ding/world_model/utils.py +++ b/ding/world_model/utils.py @@ -102,8 +102,8 @@ def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): def mean(self): print("mean called") - _mode = self.probs * self.buckets - 
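# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): the "two-hot" target construction
# that TwoHotDistSymlog.log_prob uses. A (symlog-transformed) scalar is spread
# over its two neighbouring buckets, weighted by the distance to each, so a
# 255-way categorical head can regress continuous values. Minimal sketch:
import torch
import torch.nn.functional as F

buckets = torch.linspace(-20.0, 20.0, steps=255)

def two_hot(x):
    below = torch.sum((buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
    above = len(buckets) - torch.sum((buckets > x[..., None]).to(torch.int32), dim=-1)
    below = below.clamp(0, len(buckets) - 1)
    above = above.clamp(0, len(buckets) - 1)
    equal = below == above
    d_below = torch.where(equal, 1, torch.abs(buckets[below] - x))
    d_above = torch.where(equal, 1, torch.abs(buckets[above] - x))
    total = d_below + d_above
    return (
        F.one_hot(below, num_classes=len(buckets)) * (d_above / total)[..., None]
        + F.one_hot(above, num_classes=len(buckets)) * (d_below / total)[..., None]
    )

t = two_hot(torch.tensor([0.55]))
assert torch.isclose(t.sum(), torch.tensor(1.0))  # bucket weights sum to one
assert (t > 0).sum() <= 2                         # at most two buckets are active
# ---------------------------------------------------------------------------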
return symexp(torch.sum(_mode, dim=-1, keepdim=True)) + _mean = self.probs * self.buckets + return symexp(torch.sum(_mean, dim=-1, keepdim=True)) def mode(self): _mode = self.probs * self.buckets From 5c453c643d52cfcf7e3e27350af7e68a0d9fc87c Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Wed, 14 Jun 2023 20:52:08 +0800 Subject: [PATCH 05/14] add dreamer pipeline and dreamer world model --- ding/entry/serial_entry_mbrl.py | 69 +++ ding/model/common/encoder.py | 40 +- ding/model/template/vac.py | 178 +++++++ ding/policy/mbpolicy/__init__.py | 1 + ding/policy/mbpolicy/dreamer.py | 216 ++++++++ ding/policy/mbpolicy/utils.py | 128 ++++- ding/torch_utils/network/__init__.py | 1 + ding/torch_utils/network/dreamer.py | 41 ++ ding/world_model/dreamer.py | 248 +++++++++ ding/world_model/model/networks.py | 491 +++++++++++++++++- ding/world_model/utils.py | 71 +++ .../dmc2gym/config/dmc2gym_dreamer_config.py | 92 ++++ .../mbrl/halfcheetah_mbsac_mbpo_config.py | 110 ---- 13 files changed, 1533 insertions(+), 153 deletions(-) create mode 100644 ding/policy/mbpolicy/dreamer.py create mode 100644 ding/torch_utils/network/dreamer.py create mode 100644 ding/world_model/dreamer.py create mode 100644 dizoo/dmc2gym/config/dmc2gym_dreamer_config.py delete mode 100644 dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py index 27e7be6c4d..e8dad68759 100644 --- a/ding/entry/serial_entry_mbrl.py +++ b/ding/entry/serial_entry_mbrl.py @@ -243,3 +243,72 @@ def serial_pipeline_dream( learner.call_hook('after_run') return policy + + +def serial_pipeline_dreamer( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + Serial pipeline entry for dreamerv3. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \ + mbrl_entry_setup(input_cfg, seed, env_setting, model) + + learner.call_hook('before_run') + + # prefill environment buffer + if cfg.policy.get('random_collect_size', 0) > 0: + random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer) + + while True: + collect_kwargs = commander.step() + # eval the policy + if evaluator.should_eval(collector.envstep): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + for i in range(eval_every): + # train world model and fill imagination buffer + steps = ( + cfg.world_model.pretrain + if world_model.should_pretrain() + else int(world_model.should_train(collector.envstep)) + ) + for _ in range(steps): + batch_size = learner.policy.get_attribute('batch_size') + post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size) + + start = post + + learner.train( + start, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep) + ) + + # fill environment buffer + data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + env_buffer.push(data, cur_collector_envstep=collector.envstep) + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + + return policy \ No newline at end of file diff --git a/ding/model/common/encoder.py b/ding/model/common/encoder.py index fa6e1fed41..d1e4f7d69b 100644 --- a/ding/model/common/encoder.py +++ b/ding/model/common/encoder.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch.nn import functional as F -from ding.torch_utils import ResFCBlock, ResBlock, Flatten, normed_linear, normed_conv2d +from ding.torch_utils import ResFCBlock, ResBlock, Flatten, normed_linear, normed_conv2d, Conv2dSame, DreamerLayerNorm from ding.utils import SequenceType @@ -326,41 +326,3 @@ def forward(self, x): if self.final_relu: x = torch.relu(x) return x - - -class Conv2dSame(torch.nn.Conv2d): - - def calc_same_pad(self, i, k, s, d): - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - def forward(self, x): - ih, iw = x.size()[-2:] - pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) - pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) - - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - - ret = F.conv2d( - x, - self.weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.groups, - ) - return ret - - -class DreamerLayerNorm(nn.Module): - - def __init__(self, ch, eps=1e-03): - super(DreamerLayerNorm, self).__init__() - self.norm = torch.nn.LayerNorm(ch, eps=eps) - - def forward(self, x): - x = x.permute(0, 2, 3, 1) - x = self.norm(x) - x = x.permute(0, 3, 1, 2) - return x diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index cf026e0192..9b99b70908 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -356,3 +356,181 @@ def compute_actor_critic(self, x: torch.Tensor) -> Dict: action_type = self.actor_head[0](actor_embedding) action_args = self.actor_head[1](actor_embedding) return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}, 'value': value} + + +@MODEL_REGISTRY.register('dreamervac') +class DREAMERVAC(nn.Module): + r""" + Overview: + 
The VAC model. + Interfaces: + ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` + """ + mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] + + def __init__( + self, + obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType, EasyDict], + action_space: str = 'discrete', + share_encoder: bool = True, + encoder_hidden_size_list: SequenceType = [128, 128, 64], + actor_head_hidden_size: int = 64, + actor_head_layer_num: int = 1, + critic_head_hidden_size: int = 64, + critic_head_layer_num: int = 1, + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + sigma_type: Optional[str] = 'independent', + fixed_sigma_value: Optional[int] = 0.3, + bound_type: Optional[str] = None, + encoder: Optional[torch.nn.Module] = None, + impala_cnn_encoder: bool = False, + ) -> None: + r""" + Overview: + Init the VAC Model according to arguments. + Arguments: + - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. + - action_shape (:obj:`Union[int, SequenceType]`): Action's space. + - action_space (:obj:`str`): Choose action head in ['discrete', 'continuous', 'hybrid'] + - share_encoder (:obj:`bool`): Whether share encoder. + - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder`` + - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``. + - actor_head_layer_num (:obj:`int`): + The num of layers used in the network to compute Q value output for actor's nn. + - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``. + - critic_head_layer_num (:obj:`int`): + The num of layers used in the network to compute Q value output for critic's nn. + - activation (:obj:`Optional[nn.Module]`): + The type of activation function to use in ``MLP`` the after ``layer_fn``, + if ``None`` then default set to ``nn.ReLU()`` + - norm_type (:obj:`Optional[str]`): + The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details` + """ + super(VAC, self).__init__() + obs_shape: int = squeeze(obs_shape) + action_shape = squeeze(action_shape) + self.obs_shape, self.action_shape = obs_shape, action_shape + self.impala_cnn_encoder = impala_cnn_encoder + self.share_encoder = share_encoder + + # Encoder Type + def new_encoder(outsize): + if impala_cnn_encoder: + return IMPALAConvEncoder(obs_shape=obs_shape, channels=encoder_hidden_size_list, outsize=outsize) + else: + if isinstance(obs_shape, int) or len(obs_shape) == 1: + return FCEncoder( + obs_shape=obs_shape, + hidden_size_list=encoder_hidden_size_list, + activation=activation, + norm_type=norm_type + ) + elif len(obs_shape) == 3: + return ConvEncoder( + obs_shape=obs_shape, + hidden_size_list=encoder_hidden_size_list, + activation=activation, + norm_type=norm_type + ) + else: + raise RuntimeError( + "not support obs_shape for pre-defined encoder: {}, please customize your own encoder". + format(obs_shape) + ) + + if self.share_encoder: + assert actor_head_hidden_size == critic_head_hidden_size, \ + "actor and critic network head should have same size." 
+ if encoder: + if isinstance(encoder, torch.nn.Module): + self.encoder = encoder + else: + raise ValueError("illegal encoder instance.") + else: + self.encoder = new_encoder(actor_head_hidden_size) + else: + if encoder: + if isinstance(encoder, torch.nn.Module): + self.actor_encoder = encoder + self.critic_encoder = deepcopy(encoder) + else: + raise ValueError("illegal encoder instance.") + else: + self.actor_encoder = new_encoder(actor_head_hidden_size) + self.critic_encoder = new_encoder(critic_head_hidden_size) + + # Head Type + self.critic_head = RegressionHead( + critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type + ) + self.action_space = action_space + assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space + if self.action_space == 'continuous': + self.multi_head = False + self.actor_head = ReparameterizationHead( + actor_head_hidden_size, + action_shape, + actor_head_layer_num, + sigma_type=sigma_type, + activation=activation, + norm_type=norm_type, + bound_type=bound_type + ) + elif self.action_space == 'discrete': + actor_head_cls = DiscreteHead + multi_head = not isinstance(action_shape, int) + self.multi_head = multi_head + if multi_head: + self.actor_head = MultiHead( + actor_head_cls, + actor_head_hidden_size, + action_shape, + layer_num=actor_head_layer_num, + activation=activation, + norm_type=norm_type + ) + else: + self.actor_head = actor_head_cls( + actor_head_hidden_size, + action_shape, + actor_head_layer_num, + activation=activation, + norm_type=norm_type + ) + elif self.action_space == 'hybrid': # HPPO + # hybrid action space: action_type(discrete) + action_args(continuous), + # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])} + action_shape.action_args_shape = squeeze(action_shape.action_args_shape) + action_shape.action_type_shape = squeeze(action_shape.action_type_shape) + actor_action_args = ReparameterizationHead( + actor_head_hidden_size, + action_shape.action_args_shape, + actor_head_layer_num, + sigma_type=sigma_type, + fixed_sigma_value=fixed_sigma_value, + activation=activation, + norm_type=norm_type, + bound_type=bound_type, + ) + actor_action_type = DiscreteHead( + actor_head_hidden_size, + action_shape.action_type_shape, + actor_head_layer_num, + activation=activation, + norm_type=norm_type, + ) + self.actor_head = nn.ModuleList([actor_action_type, actor_action_args]) + + # must use list, not nn.ModuleList + if self.share_encoder: + self.actor = [self.encoder, self.actor_head] + self.critic = [self.encoder, self.critic_head] + else: + self.actor = [self.actor_encoder, self.actor_head] + self.critic = [self.critic_encoder, self.critic_head] + # Convenient for calling some apis (e.g. 
self.critic.parameters()), + # but may cause misunderstanding when `print(self)` + self.actor = nn.ModuleList(self.actor) + self.critic = nn.ModuleList(self.critic) \ No newline at end of file diff --git a/ding/policy/mbpolicy/__init__.py b/ding/policy/mbpolicy/__init__.py index 7d528cd15b..e23c8d823d 100644 --- a/ding/policy/mbpolicy/__init__.py +++ b/ding/policy/mbpolicy/__init__.py @@ -1 +1,2 @@ from .mbsac import MBSACPolicy +from .dreamer import DREAMERPolicy diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py new file mode 100644 index 0000000000..4b31688f45 --- /dev/null +++ b/ding/policy/mbpolicy/dreamer.py @@ -0,0 +1,216 @@ +from typing import Dict, Any, List +import torch +from torch import nn +from copy import deepcopy +from ding.torch_utils import Adam, to_device +from ding.utils import POLICY_REGISTRY, deep_merge_dicts +from ding.policy import Policy +from ding.rl_utils import generalized_lambda_returns +from ding.model import model_wrap +from ding.policy.common_utils import default_preprocess_learn + +from .utils import imagine, compute_target, compute_actor_loss, RewardEMA, tensorstats + + +@POLICY_REGISTRY.register('dreamer') +class DREAMERPolicy(Policy): + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='dreamer', + # (bool) Whether to use cuda for network and loss computation. + cuda=False, + # (int) + imag_horizon=15, + learn=dict( + # (float) Lambda for TD-lambda return. + lambda_=0.95, + # (float) Max norm of gradients. + grad_clip=100, + learning_rate=0.001, + batch_size=256, + imag_sample=True, + slow_value_target=True, + discount=0.997, + reward_EMA=True, + actor_entropy=3e-4, + actor_state_entropy=0.0, + value_decay=0.0, + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + return 'dreamervac', ['ding.model.template.vac'] + + def _init_learn(self) -> None: + r""" + Overview: + Learn mode init method. Called by ``self.__init__``. + Init the optimizer, algorithm config, main and target models. 
+ """ + # Algorithm config + self._lambda = self._cfg.learn.lambda_ + self._grad_clip = self._cfg.learn.grad_clip + + self._critic = self._model.critic + self._actor = self._model.actor + + if self._cfg.learn.slow_value_target: + self._slow_value = deepcopy(self._critic) + self._updates = 0 + + # Optimizer + self._optimizer_value = Adam( + self._critic.parameters(), + lr=self._cfg.learn.learning_rate, + ) + self._optimizer_actor = Adam( + self._actor.parameters(), + lr=self._cfg.learn.learning_rate, + ) + + self._learn_model = model_wrap(self._model, wrapper_name='base') + self._learn_model.reset() + + self._forward_learn_cnt = 0 + + if self._cfg.learn.reward_EMA: + self.reward_ema = RewardEMA(device=self._device) + + def _forward_learn(self, start: dict, repeats=None, world_model=None, envstep) -> Dict[str, Any]: + # log dict + log_vars = {} + world_model.requires_grad_(requires_grad=False) + self._actor.requires_grad_(requires_grad=True) + # start is dict of {stoch, deter, logit} + if self._cuda: + start = to_device(start, self._device) + + self._learn_model.train() + self._target_model.train() + + # train self._actor + imag_feat, imag_state, imag_action = imagine( + self._cfg.learn, world_model, start, self._actor, self._cfg.imag_horizon, repeats + ) + reward = world_model.heads["reward"](world_model.dynamics.get_feat(imag_state)).mode() + actor_ent = self._actor(imag_feat).entropy() + state_ent = world_model.dynamics.get_dist(imag_state).entropy() + # this target is not scaled + # slow is flag to indicate whether slow_target is used for lambda-return + target, weights, base = compute_target( + self._cfg.learn, world_model, self._critic, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent + ) + actor_loss, mets = compute_actor_loss( + self._cfg.learn, + self._actor, + self.reward_ema, + imag_feat, + imag_state, + imag_action, + target, + actor_ent, + state_ent, + weights, + base, + ) + log_vars.update(mets) + value_input = imag_feat + self._actor.requires_grad_(requires_grad=False) + + self._critic.requires_grad_(requires_grad=True) + value = self._critic(value_input[:-1].detach()) + # to do + target = torch.stack(target, dim=1) + # (time, batch, 1), (time, batch, 1) -> (time, batch) + value_loss = -value.log_prob(target.detach()) + slow_target = self._slow_value(value_input[:-1].detach()) + if self._cfg.learn.slow_value_target: + value_loss = value_loss - value.log_prob( + slow_target.mode().detach() + ) + if self._cfg.learn.value_decay: + value_loss += self._cfg.learn.value_decay * value.mode() + # (time, batch, 1), (time, batch, 1) -> (1,) + value_loss = torch.mean(weights[:-1] * value_loss[:, :, None]) + self._critic.requires_grad_(requires_grad=False) + + log_vars.update(tensorstats(value.mode(), "value")) + log_vars.update(tensorstats(target, "target")) + log_vars.update(tensorstats(reward, "imag_reward")) + log_vars.update(tensorstats(imag_action, "imag_action")) + log_vars["actor_ent"] = torch.mean(actor_ent).detach().cpu().numpy() + # ==================== + # actor-critic update + # ==================== + self._model.requires_grad_(requires_grad=True) + + loss_dict = { + 'critic_loss': value_loss, + 'actor_loss': actor_loss, + } + + norm_dict = self._update(loss_dict) + + self._model.requires_grad_(requires_grad=False) + # ============= + # after update + # ============= + self._forward_learn_cnt += 1 + + return { + **log_vars, + **norm_dict, + **loss_dict, + } + + def _update(self, loss_dict): + # update actor + self._optimizer_actor.zero_grad() + 
loss_dict['actor_loss'].backward() + actor_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip) + self._optimizer_actor.step() + # update critic + self._optimizer_value.zero_grad() + loss_dict['critic_loss'].backward() + critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip) + self._optimizer_value.step() + return {'actor_grad_norm': actor_norm, 'critic_grad_norm': critic_norm} + + def _monitor_vars_learn(self) -> List[str]: + r""" + Overview: + Return variables' name if variables are to used in monitor. + Returns: + - vars (:obj:`List[str]`): Variables' name list. + """ + return [ + 'normed_target_mean', + 'normed_target_std', + 'normed_target_min', + 'normed_target_max', + 'EMA_005', + 'EMA_095', + 'actor_entropy', + 'actor_state_entropy', + 'value_mean', + 'value_std', + 'value_min', + 'value_max', + 'target_mean', + 'target_std', + 'target_min', + 'target_max', + 'imag_reward_mean', + 'imag_reward_std', + 'imag_reward_min', + 'imag_reward_max', + 'imag_action_mean', + 'imag_action_std', + 'imag_action_min', + 'imag_action_max', + 'actor_ent', + 'actor_loss', + 'critic_loss', + 'actor_grad_norm', + 'critic_grad_norm' + ] diff --git a/ding/policy/mbpolicy/utils.py b/ding/policy/mbpolicy/utils.py index bd1d77d3f9..985e7dfae2 100644 --- a/ding/policy/mbpolicy/utils.py +++ b/ding/policy/mbpolicy/utils.py @@ -1,8 +1,8 @@ from typing import Callable, Tuple, Union - +import torch from torch import Tensor from ding.torch_utils import fold_batch, unfold_batch - +from ding.world_model.utils import static_scan def q_evaluation(obss: Tensor, actions: Tensor, q_critic_fn: Callable[[Tensor, Tensor], Tensor]) -> Union[Tensor, Tuple[Tensor, Tensor]]: @@ -36,3 +36,127 @@ def q_evaluation(obss: Tensor, actions: Tensor, q_critic_fn: Callable[[Tensor, T if isinstance(q_values, list): return [unfold_batch(q_values[0], dim), unfold_batch(q_values[1], dim)] return unfold_batch(q_values, dim) + + +def imagine(cfg, world_model, start, actor, horizon, repeats=None): + dynamics = world_model.dynamics + if repeats: + raise NotImplemented("repeats is not implemented in this version") + flatten = lambda x: x.reshape([-1] + list(x.shape[2:])) + start = {k: flatten(v) for k, v in start.items()} + + def step(prev, _): + state, _, _ = prev + feat = dynamics.get_feat(state) + inp = feat.detach() if self._stop_grad_actor else feat + action = actor(inp).sample() + succ = dynamics.img_step(state, action, sample=cfg.imag_sample) + return succ, feat, action + + succ, feats, actions = static_scan( + step, [torch.arange(horizon)], (start, None, None) + ) + states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()} + if repeats: + raise NotImplemented("repeats is not implemented in this version") + + return feats, states, actions + + +def compute_target( + cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent +): + if "cont" in world_model.heads: + inp = self._world_model.dynamics.get_feat(imag_state) + discount = cfg.discount * self._world_model.heads["cont"](inp).mean + else: + discount = cfg.discount * torch.ones_like(reward) + + value = critic(imag_feat).mode() + # value(imag_horizon, 16*64, ch) + # action(imag_horizon, 16*64, ch) + # discount(imag_horizon, 16*64, ch) + target = tools.lambda_return( + reward[:-1], + value[:-1], + discount[:-1], + bootstrap=value[-1], + lambda_=cfg.lambda_, + axis=0, + ) + weights = torch.cumprod( + torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0 + ).detach() + return 
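# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): the lambda-return that compute_target
# above builds via a lambda_return helper. The recursion (standard DreamerV3 form,
# shown here as a hypothetical standalone function) is
#     R_t = r_t + gamma_t * ((1 - lambda) * v_{t+1} + lambda * R_{t+1}),
# bootstrapped with the value estimate after the last imagined step:
import torch

def lambda_return_sketch(reward, value, discount, bootstrap, lambda_=0.95):
    # reward / value / discount: (T, B); bootstrap: (B,), i.e. v after step T - 1
    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    returns, outs = bootstrap, []
    for t in reversed(range(reward.shape[0])):
        returns = reward[t] + discount[t] * ((1 - lambda_) * next_values[t] + lambda_ * returns)
        outs.append(returns)
    return torch.stack(outs[::-1], dim=0)         # (T, B) lambda-returns

T, B = 15, 4
target = lambda_return_sketch(torch.rand(T, B), torch.rand(T, B), 0.997 * torch.ones(T, B), torch.rand(B))
assert target.shape == (T, B)
# ---------------------------------------------------------------------------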
target, weights, value[:-1] + + +def compute_actor_loss( + cfg, + actor, + reward_ema + imag_feat, + imag_state, + imag_action, + target, + actor_ent, + state_ent, + weights, + base, + ): + metrics = {} + inp = imag_feat.detach() + policy = actor(inp) + actor_ent = policy.entropy() + # Q-val for actor is not transformed using symlog + target = torch.stack(target, dim=1) + if cfg.reward_EMA: + offset, scale = reward_ema(target) + normed_target = (target - offset) / scale + normed_base = (base - offset) / scale + adv = normed_target - normed_base + metrics.update(tools.tensorstats(normed_target, "normed_target")) + values = reward_ema.values + metrics["EMA_005"] = values[0].detach().cpu().numpy() + metrics["EMA_095"] = values[1].detach().cpu().numpy() + + actor_target = adv + if cfg.actor_entropy > 0: + actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None] + actor_target += actor_entropy + metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy() + if cfg.actor_state_entropy > 0: + state_entropy = cfg.actor_state_entropy * state_ent[:-1] + actor_target += state_entropy + metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy() + actor_loss = -torch.mean(weights[:-1] * actor_target) + return actor_loss, metrics + + +class RewardEMA(object): + """running mean and std""" + + def __init__(self, device, alpha=1e-2): + self.device = device + self.values = torch.zeros((2,)).to(device) + self.alpha = alpha + self.range = torch.tensor([0.05, 0.95]).to(device) + + def __call__(self, x): + flat_x = torch.flatten(x.detach()) + x_quantile = torch.quantile(input=flat_x, q=self.range) + self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values + scale = torch.clip(self.values[1] - self.values[0], min=1.0) + offset = self.values[0] + return offset.detach(), scale.detach() + + +def tensorstats(tensor, prefix=None): + metrics = { + 'mean': torch.mean(tensor).detach().cpu().numpy(), + 'std': torch.std(tensor).detach().cpu().numpy(), + 'min': torch.min(tensor).detach().cpu().numpy(), + 'max': torch.max(tensor).detach().cpu().numpy(), + } + if prefix: + metrics = {f'{prefix}_{k}': v for k, v in metrics.items()} + return metrics \ No newline at end of file diff --git a/ding/torch_utils/network/__init__.py b/ding/torch_utils/network/__init__.py index 252f48c71d..d1eb391c64 100644 --- a/ding/torch_utils/network/__init__.py +++ b/ding/torch_utils/network/__init__.py @@ -11,3 +11,4 @@ from .gumbel_softmax import GumbelSoftmax from .gtrxl import GTrXL, GRUGatingUnit from .popart import PopArt +from .dreamer import Conv2dSame, DreamerLayerNorm diff --git a/ding/torch_utils/network/dreamer.py b/ding/torch_utils/network/dreamer.py new file mode 100644 index 0000000000..8d45abc9cc --- /dev/null +++ b/ding/torch_utils/network/dreamer.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import math +from torch.nn import functional as F + +class Conv2dSame(torch.nn.Conv2d): + + def calc_same_pad(self, i, k, s, d): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + def forward(self, x): + ih, iw = x.size()[-2:] + pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) + pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) + + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + + ret = F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return 
ret + + +class DreamerLayerNorm(nn.Module): + + def __init__(self, ch, eps=1e-03): + super(DreamerLayerNorm, self).__init__() + self.norm = torch.nn.LayerNorm(ch, eps=eps) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = x.permute(0, 3, 1, 2) + return x \ No newline at end of file diff --git a/ding/world_model/dreamer.py b/ding/world_model/dreamer.py new file mode 100644 index 0000000000..fe1de29829 --- /dev/null +++ b/ding/world_model/dreamer.py @@ -0,0 +1,248 @@ +import numpy as np +import copy +import torch +from torch import nn + +from ding.utils import WORLD_MODEL_REGISTRY +from ding.utils.data import default_collate +from ding.model import ConvEncoder +from ding.world_model.base_world_model import WorldModel +from ding.world_model.model.networks import RSSM, ConvDecoder, DenseHead +from ding.torch_utils import fold_batch, unfold_batch, unsqueeze_repeat + + +@WORLD_MODEL_REGISTRY.register('dreamer') +class DREAMERWorldModel(WorldModel, nn.Module): + config = dict( + model=dict( + state_size=None, + action_size=None, + model_lr=1e-4, + reward_size=1, + hidden_size=200, + batch_size=256, + max_epochs_since_update=5, + dyn_stoch=32, + dyn_deter=512, + dyn_hidden=512, + dyn_input_layers=1, + dyn_output_layers=1, + dyn_rec_depth=1, + dyn_shared=False, + dyn_discrete=32, + act='SiLU', + norm='LayerNorm', + grad_heads=['image', 'reward', 'discount'], + units=512, + reward_layers=2, + discount_layers=2, + value_layers=2, + actor_layers=2, + act='SiLU', + norm='LayerNorm', + cnn_depth=32, + encoder_kernels=[4, 4, 4, 4], + decoder_kernels=[4, 4, 4, 4], + reward_head='twohot_symlog', + kl_lscale=0.1, + kl_rscale=0.5, + kl_free=1.0, + kl_forward=False, + pred_discount=True, + dyn_mean_act='none', + dyn_std_act='sigmoid2', + dyn_temp_post=True, + dyn_min_std=0.1, + dyn_cell=True, + unimix_ratio=0.01, + initial='learned', + device='cpu', + ), + ) + + def __init__(self, cfg, env, tb_logger): + WorldModel.__init__(self, cfg, env, tb_logger) + nn.Module.__init__(self) + + self._cfg = cfg.model + self.state_size = self._cfg.state_size + self.action_size = self._cfg.action_size + self.reward_size = self._cfg.reward_size + self.hidden_size = self._cfg.hidden_size + self.batch_size = self._cfg.batch_size + self.max_epochs_since_update = cfg.max_epochs_since_update + + if self._cuda: + self.cuda() + + self.encoder = ConvEncoder( + *self.state_size, + hidden_size_list=[32, 64, 128, 256, 128], # to last layer 128? 
+ activation=torch.nn.SiLU(), + kernel_size=self._cfg.encoder_kernels, + layer_norm=True + ) + self.embed_size = ( + (self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth * + 2 ** (len(self._cfg.encoder_kernels) - 1) + ) + self.dynamics = RSSM( + self._cfg.dyn_stoch, + self._cfg.dyn_deter, + self._cfg.dyn_hidden, + self._cfg.dyn_input_layers, + self._cfg.dyn_output_layers, + self._cfg.dyn_rec_depth, + self._cfg.dyn_shared, + self._cfg.dyn_discrete, + self._cfg.act, + self._cfg.norm, + self._cfg.dyn_mean_act, + self._cfg.dyn_std_act, + self._cfg.dyn_temp_post, + self._cfg.dyn_min_std, + self._cfg.dyn_cell, + self._cfg.unimix_ratio, + self._cfg.initial, + self._cfg.action_size, + self.embed_size, + self._cfg.device, + ) + self.heads = nn.ModuleDict() + if self._cfg.dyn_discrete: + feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter + else: + feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter + self.heads["image"] = ConvDecoder( + feat_size, # pytorch version + self._cfg.cnn_depth, + self._cfg.act, + self._cfg.norm, + self.state_size, + self._cfg.decoder_kernels, + ) + self.heads["reward"] = DenseHead( + feat_size, # dyn_stoch * dyn_discrete + dyn_deter + (255, ), + self._cfg.reward_layers, + self._cfg.units, + self._cfg.act, + self._cfg.norm, + dist=self._cfg.reward_head, + outscale=0.0, + ) + if self._cfg.pred_discount: + self.heads["discount"] = DenseHead( + feat_size, # pytorch version + [], + self._cfg.discount_layers, + self._cfg.units, + self._cfg.act, + self._cfg.norm, + dist="binary", + ) + # to do + # grad_clip, weight_decay + self.optimizer = torch.optim.Adam(self.parameters(), lr=self._cfg.model_lr) + + def step(self, obs, act): + pass + + def eval(self, env_buffer, envstep, train_iter): + pass + + def train(self, env_buffer, envstep, train_iter, batch_size): + self.last_train_step = envstep + data = env_buffer.sample(batch_size, train_iter) + data = default_collate(data) + data['done'] = data['done'].float() + data['weight'] = data.get('weight', None) + data = {k: torch.Tensor(v).to(self._cfg.device) for k, v in data.items()} + #image = data['obs'] + action = data['action'] + reward = data['reward'] + next_obs = data['next_obs'] + if len(reward.shape) == 2: + reward = reward.unsqueeze(-1) + if len(action.shape) == 2: + action = action.unsqueeze(-1) + + embed = self.encoder(data['obs']) + post, prior = self.dynamics.observe(embed, data["action"]) + kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss( + post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale + ) + losses = {} + likes = {} + for name, head in self.heads.items(): + grad_head = name in self._cfg.grad_heads + feat = self.dynamics.get_feat(post) + feat = feat if grad_head else feat.detach() + pred = head(feat) + like = pred.log_prob(data[name]) + likes[name] = like + losses[name] = -torch.mean(like) * self._scales.get(name, 1.0) + model_loss = sum(losses.values()) + kl_loss + + # ==================== + # world model update + # ==================== + self.optimizer.zero_grad() + model_loss.backward() + self.optimizer.step() + # log + if self.tb_logger is not None: + for name, loss in losses.items(): + self.tb_logger.add_scalar(name + '_loss', loss.detach().cpu().numpy(), envstep) + self.tb_logger.add_scalar('kl_free', self._cfg.kl_free, envstep) + self.tb_logger.add_scalar('kl_lscale', self._cfg.kl_lscale, envstep) + self.tb_logger.add_scalar('kl_rscale', self._cfg.kl_rscale, envstep) + 
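# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): the KL-balancing idea that kl_free /
# kl_lscale / kl_rscale configure. RSSM.kl_loss itself is defined elsewhere; this
# sketch assumes it follows the usual DreamerV3 recipe: one KL term trains the
# prior towards the stop-gradient posterior, the other regularises the posterior
# towards the stop-gradient prior, and both are clipped at `free` nats ("free
# bits") so gradients vanish once the KL is already small. The lscale/rscale
# pairing below is illustrative only.
import torch
import torch.nn.functional as F

def categorical_kl(logits_p, logits_q):
    # KL(p || q) for factorised categoricals, summed over (stoch, discrete) dims
    logp, logq = F.log_softmax(logits_p, dim=-1), F.log_softmax(logits_q, dim=-1)
    return (logp.exp() * (logp - logq)).sum(dim=(-2, -1))

post_logit = torch.randn(16, 64, 32, 32)    # (batch, time, dyn_stoch, dyn_discrete)
prior_logit = torch.randn(16, 64, 32, 32)
free, lscale, rscale = 1.0, 0.1, 0.5
loss_lhs = torch.clip(categorical_kl(post_logit.detach(), prior_logit), min=free).mean()
loss_rhs = torch.clip(categorical_kl(post_logit, prior_logit.detach()), min=free).mean()
kl_loss = lscale * loss_lhs + rscale * loss_rhs
# ---------------------------------------------------------------------------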
self.tb_logger.add_scalar('loss_lhs', loss_lhs.detach().cpu().numpy(), envstep) + self.tb_logger.add_scalar('loss_rhs', loss_rhs.detach().cpu().numpy(), envstep) + self.tb_logger.add_scalar('kl', torch.mean(kl_value).detach().cpu().numpy(), envstep) + + prior_ent = torch.mean(self.dynamics.get_dist(prior).entropy()).detach().cpu().numpy() + post_ent = torch.mean(self.dynamics.get_dist(post).entropy()).detach().cpu().numpy() + + self.tb_logger.add_scalar('prior_ent', prior_ent, envstep) + self.tb_logger.add_scalar('post_ent', post_ent, envstep) + + context = dict( + embed=embed, + feat=self.dynamics.get_feat(post), + kl=kl_value, + postent=self.dynamics.get_dist(post).entropy(), + ) + post = {k: v.detach() for k, v in post.items()} + return post, context + + def _save_states(self, ): + self._states = copy.deepcopy(self.state_dict()) + + def _save_state(self, id): + state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'weight' in k or 'bias' in k: + self._states[k].data[id] = copy.deepcopy(v.data[id]) + + def _load_states(self): + self.load_state_dict(self._states) + + def _save_best(self, epoch, holdout_losses): + updated = False + for i in range(len(holdout_losses)): + current = holdout_losses[i] + _, best = self._snapshots[i] + improvement = (best - current) / best + if improvement > 0.01: + self._snapshots[i] = (epoch, current) + self._save_state(i) + # self._save_state(i) + updated = True + # improvement = (best - current) / best + + if updated: + self._epochs_since_update = 0 + else: + self._epochs_since_update += 1 + return self._epochs_since_update > self.max_epochs_since_update diff --git a/ding/world_model/model/networks.py b/ding/world_model/model/networks.py index dc4e096bdd..066020dc02 100644 --- a/ding/world_model/model/networks.py +++ b/ding/world_model/model/networks.py @@ -6,8 +6,366 @@ import torch.nn.functional as F from torch import distributions as torchd -from ding.world_model.utils import weight_init, uniform_weight_init, ContDist, Bernoulli, TwoHotDistSymlog, UnnormalizedHuber -from ding.torch_utils import MLP, fc_block +from ding.world_model.utils import weight_init, uniform_weight_init, OneHotDist, ContDist, SymlogDist, SampleDist, \ + Bernoulli, TwoHotDistSymlog, UnnormalizedHuber, SafeTruncatedNormal, TanhBijector, static_scan +from ding.torch_utils import MLP, DreamerLayerNorm + + +class RSSM(nn.Module): + + def __init__( + self, + stoch=30, + deter=200, + hidden=200, + layers_input=1, + layers_output=1, + rec_depth=1, + shared=False, + discrete=False, + act=nn.ELU, + norm=nn.LayerNorm, + mean_act="none", + std_act="softplus", + temp_post=True, + min_std=0.1, + cell="gru", + unimix_ratio=0.01, + num_actions=None, + embed=None, + device=None, + ): + super(RSSM, self).__init__() + self._stoch = stoch + self._deter = deter + self._hidden = hidden + self._min_std = min_std + self._layers_input = layers_input + self._layers_output = layers_output + self._rec_depth = rec_depth + self._shared = shared + self._discrete = discrete + self._act = act + self._norm = norm + self._mean_act = mean_act + self._std_act = std_act + self._temp_post = temp_post + self._unimix_ratio = unimix_ratio + self._embed = embed + self._device = device + + inp_layers = [] + if self._discrete: + inp_dim = self._stoch * self._discrete + num_actions + else: + inp_dim = self._stoch + num_actions + if self._shared: + inp_dim += self._embed + for i in range(self._layers_input): + inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + 
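# ---------------------------------------------------------------------------
# Illustrative aside (not part of the patch): the size of the feature the RSSM
# exposes to the heads. With the defaults above (dyn_stoch=32, dyn_discrete=32,
# dyn_deter=512), feat_size = dyn_stoch * dyn_discrete + dyn_deter, i.e. the
# flattened categorical sample concatenated with the deterministic state:
import torch

stoch = torch.randn(16, 64, 32, 32)        # (batch, time, dyn_stoch, dyn_discrete)
deter = torch.randn(16, 64, 512)           # (batch, time, dyn_deter)
feat = torch.cat([stoch.flatten(start_dim=-2), deter], dim=-1)
assert feat.shape[-1] == 32 * 32 + 512 == 1536   # matches in_dim in test_DenseHead
# ---------------------------------------------------------------------------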
inp_layers.append(self._norm(self._hidden, eps=1e-03)) + inp_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._inp_layers = nn.Sequential(*inp_layers) + self._inp_layers.apply(weight_init) + + if cell == "gru": + self._cell = GRUCell(self._hidden, self._deter) + self._cell.apply(weight_init) + elif cell == "gru_layer_norm": + self._cell = GRUCell(self._hidden, self._deter, norm=True) + self._cell.apply(weight_init) + else: + raise NotImplementedError(cell) + + img_out_layers = [] + inp_dim = self._deter + for i in range(self._layers_output): + img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + img_out_layers.append(self._norm(self._hidden, eps=1e-03)) + img_out_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._img_out_layers = nn.Sequential(*img_out_layers) + self._img_out_layers.apply(weight_init) + + obs_out_layers = [] + if self._temp_post: + inp_dim = self._deter + self._embed + else: + inp_dim = self._embed + for i in range(self._layers_output): + obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + obs_out_layers.append(self._norm(self._hidden, eps=1e-03)) + obs_out_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._obs_out_layers = nn.Sequential(*obs_out_layers) + self._obs_out_layers.apply(weight_init) + + if self._discrete: + self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete) + self._ims_stat_layer.apply(weight_init) + self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete) + self._obs_stat_layer.apply(weight_init) + else: + self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) + self._ims_stat_layer.apply(weight_init) + self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) + self._obs_stat_layer.apply(weight_init) + + def initial(self, batch_size): + deter = torch.zeros(batch_size, self._deter).to(self._device) + if self._discrete: + state = dict( + logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device), + stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device), + deter=deter, + ) + else: + state = dict( + mean=torch.zeros([batch_size, self._stoch]).to(self._device), + std=torch.zeros([batch_size, self._stoch]).to(self._device), + stoch=torch.zeros([batch_size, self._stoch]).to(self._device), + deter=deter, + ) + return state + + def observe(self, embed, action, state=None): + swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) # 交换前两维 + if state is None: + state = self.initial(action.shape[0]) # {logit, stoch, deter} + # (batch, time, ch) -> (time, batch, ch) + embed, action = swap(embed), swap(action) + post, prior = static_scan( + lambda prev_state, prev_act, embed: self.obs_step(prev_state[0], prev_act, embed), + (action, embed), + (state, state), + ) + + # (time, batch, stoch, discrete_num) -> (batch, time, stoch, discrete_num) + post = {k: swap(v) for k, v in post.items()} + prior = {k: swap(v) for k, v in prior.items()} + return post, prior + + def imagine(self, action, state=None): + swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) + if state is None: + state = self.initial(action.shape[0]) + assert isinstance(state, dict), state + action = action + action = swap(action) + prior = static_scan(self.img_step, [action], state) + prior = prior[0] + prior = {k: swap(v) for k, v in prior.items()} + return prior + + def get_feat(self, state): + stoch = state["stoch"] + if self._discrete: + shape = list(stoch.shape[:-2]) + [self._stoch * 
self._discrete] + stoch = stoch.reshape(shape) + return torch.cat([stoch, state["deter"]], -1) + + def get_dist(self, state, dtype=None): + if self._discrete: + logit = state["logit"] + dist = torchd.independent.Independent(OneHotDist(logit, unimix_ratio=self._unimix_ratio), 1) + else: + mean, std = state["mean"], state["std"] + dist = ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), 1)) + return dist + + def obs_step(self, prev_state, prev_action, embed, sample=True): + # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer) + # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs + prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() + prior = self.img_step(prev_state, prev_action, None, sample) + if self._shared: + post = self.img_step(prev_state, prev_action, embed, sample) + else: + if self._temp_post: + x = torch.cat([prior["deter"], embed], -1) + else: + x = embed + # (batch_size, prior_deter + embed) -> (batch_size, hidden) + x = self._obs_out_layers(x) + # (batch_size, hidden) -> (batch_size, stoch, discrete_num) + stats = self._suff_stats_layer("obs", x) + if sample: + stoch = self.get_dist(stats).sample() + else: + stoch = self.get_dist(stats).mode() + post = {"stoch": stoch, "deter": prior["deter"], **stats} + return post, prior + + # this is used for making future image + def img_step(self, prev_state, prev_action, embed=None, sample=True): + # (batch, stoch, discrete_num) + prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() + prev_stoch = prev_state["stoch"] + if self._discrete: + shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete] + # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num) + prev_stoch = prev_stoch.reshape(shape) + if self._shared: + if embed is None: + shape = list(prev_action.shape[:-1]) + [self._embed] + embed = torch.zeros(shape) + # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed) + x = torch.cat([prev_stoch, prev_action, embed], -1) + else: + x = torch.cat([prev_stoch, prev_action], -1) + # (batch, stoch * discrete_num + action, embed) -> (batch, hidden) + x = self._inp_layers(x) + for _ in range(self._rec_depth): # rec depth is not correctly implemented + deter = prev_state["deter"] + # (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter) + x, deter = self._cell(x, [deter]) + deter = deter[0] # Keras wraps the state in a list. 
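# --- Illustrative sketch, not part of this patch -----------------------------
# Interface of the recurrent cell invoked above: it consumes the current input
# of shape (batch, hidden) and the deterministic state wrapped in a list, and
# returns (new_state, [new_state]). The sizes below (batch=16, hidden=512,
# deter=512) are hypothetical; GRUCell is the class defined later in this file.
import torch
cell = GRUCell(inp_size=512, size=512, norm=True)
x_t, deter_t = torch.zeros(16, 512), torch.zeros(16, 512)
out, state = cell(x_t, [deter_t])
assert out.shape == (16, 512) and state[0] is out
# --- end of sketch ------------------------------------------------------------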
+ # (batch, deter) -> (batch, hidden) + x = self._img_out_layers(x) + # (batch, hidden) -> (batch_size, stoch, discrete_num) + stats = self._suff_stats_layer("ims", x) + if sample: + stoch = self.get_dist(stats).sample() + else: + stoch = self.get_dist(stats).mode() + prior = {"stoch": stoch, "deter": deter, **stats} # {stoch, deter, logit} + return prior + + def _suff_stats_layer(self, name, x): + if self._discrete: + if name == "ims": + x = self._ims_stat_layer(x) + elif name == "obs": + x = self._obs_stat_layer(x) + else: + raise NotImplementedError + logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete]) + return {"logit": logit} + else: + if name == "ims": + x = self._ims_stat_layer(x) + elif name == "obs": + x = self._obs_stat_layer(x) + else: + raise NotImplementedError + mean, std = torch.split(x, [self._stoch] * 2, -1) + mean = { + "none": lambda: mean, + "tanh5": lambda: 5.0 * torch.tanh(mean / 5.0), + }[self._mean_act]() + std = { + "softplus": lambda: torch.softplus(std), + "abs": lambda: torch.abs(std + 1), + "sigmoid": lambda: torch.sigmoid(std), + "sigmoid2": lambda: 2 * torch.sigmoid(std / 2), + }[self._std_act]() + std = std + self._min_std + return {"mean": mean, "std": std} + + def kl_loss(self, post, prior, forward, free, lscale, rscale): + kld = torchd.kl.kl_divergence + dist = lambda x: self.get_dist(x) + sg = lambda x: {k: v.detach() for k, v in x.items()} + # forward == false -> (post, prior) + lhs, rhs = (prior, post) if forward else (post, prior) + + # forward == false -> Lrep + value_lhs = value = kld( + dist(lhs) if self._discrete else dist(lhs)._dist, + dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist, + ) + # forward == false -> Ldyn + value_rhs = kld( + dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist, + dist(rhs) if self._discrete else dist(rhs)._dist, + ) + loss_lhs = torch.clip(torch.mean(value_lhs), min=free) + loss_rhs = torch.clip(torch.mean(value_rhs), min=free) + loss = lscale * loss_lhs + rscale * loss_rhs + + return loss, value, loss_lhs, loss_rhs + + +class ConvDecoder(nn.Module): + + def __init__( + self, + inp_depth, # config.dyn_stoch * config.dyn_discrete + config.dyn_deter + depth=32, + act=nn.ELU, + norm=nn.LayerNorm, + shape=(3, 64, 64), + kernels=(3, 3, 3, 3), + outscale=1.0, + ): + super(ConvDecoder, self).__init__() + self._inp_depth = inp_depth + self._act = act + self._norm = norm + self._depth = depth + self._shape = shape + self._kernels = kernels + self._embed_size = ((64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1)) + + self._linear_layer = nn.Linear(inp_depth, self._embed_size) + inp_dim = self._embed_size // 16 # 除以最后的4*4 feature map来得到channel数 + + layers = [] + h, w = 4, 4 + for i, kernel in enumerate(self._kernels): + depth = self._embed_size // 16 // (2 ** (i + 1)) + act = self._act + bias = False + initializer = weight_init + if i == len(self._kernels) - 1: + depth = self._shape[0] + act = False + bias = True + norm = False + initializer = uniform_weight_init(outscale) + + if i != 0: + inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth + pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1) + pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1) + layers.append( + nn.ConvTranspose2d( + inp_dim, + depth, + kernel, + 2, + padding=(pad_h, pad_w), + output_padding=(outpad_h, outpad_w), + bias=bias, + ) + ) + if norm: + layers.append(DreamerLayerNorm(depth)) + if act: + layers.append(act()) + [m.apply(initializer) for m in layers[-3:]] + h, w = h * 2, w * 2 + + 
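# --- Illustrative sketch, not part of this patch -----------------------------
# The KL balancing performed by RSSM.kl_loss above, written out for plain
# Normal distributions: each direction of the KL stops gradients through the
# other side and is clipped at the free-nats threshold. `free`, `lscale` and
# `rscale` mirror kl_free / kl_lscale / kl_rscale in the world-model config;
# the numeric defaults used here are placeholders, not values from the patch.
import torch
from torch import distributions as torchd

def kl_balance(post, prior, free=1.0, lscale=0.1, rscale=0.5):
    sg = lambda d: torchd.normal.Normal(d.mean.detach(), d.stddev.detach())  # stop-gradient copy
    kld = torchd.kl.kl_divergence
    loss_lhs = torch.clip(torch.mean(kld(post, sg(prior))), min=free)  # representation term
    loss_rhs = torch.clip(torch.mean(kld(sg(post), prior)), min=free)  # dynamics term
    return lscale * loss_lhs + rscale * loss_rhs

post = torchd.normal.Normal(torch.zeros(8, 32), torch.ones(8, 32))
prior = torchd.normal.Normal(torch.ones(8, 32), torch.ones(8, 32))
balanced_kl = kl_balance(post, prior)
# --- end of sketch ------------------------------------------------------------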
self.layers = nn.Sequential(*layers) + + def calc_same_pad(self, k, s, d): + val = d * (k - 1) - s + 1 + pad = math.ceil(val / 2) + outpad = pad * 2 - val + return pad, outpad + + def __call__(self, features, dtype=None): + x = self._linear_layer(features) # feature:[batch, time, stoch*discrete + deter] + x = x.reshape([-1, 4, 4, self._embed_size // 16]) + x = x.permute(0, 3, 1, 2) + x = self.layers(x) + mean = x.reshape(features.shape[:-1] + self._shape) + mean = mean.permute(0, 1, 3, 4, 2) + return SymlogDist(mean) class DenseHead(nn.Module): @@ -70,3 +428,132 @@ def forward(self, features, dtype=None): if self._dist == "twohot_symlog": return TwoHotDistSymlog(logits=mean) raise NotImplementedError(self._dist) + + +class ActionHead(nn.Module): + + def __init__( + self, + inp_dim, + size, + layers, + units, + act=nn.ELU, + norm=nn.LayerNorm, + dist="trunc_normal", + init_std=0.0, + min_std=0.1, + max_std=1.0, + temp=0.1, + outscale=1.0, + unimix_ratio=0.01, + ): + super(ActionHead, self).__init__() + self._size = size + self._layers = layers + self._units = units + self._dist = dist + self._act = act + self._norm = norm + self._min_std = min_std + self._max_std = max_std + self._init_std = init_std + self._unimix_ratio = unimix_ratio + self._temp = temp() if callable(temp) else temp + + pre_layers = [] + for index in range(self._layers): + pre_layers.append(nn.Linear(inp_dim, self._units, bias=False)) + pre_layers.append(norm(self._units, eps=1e-03)) + pre_layers.append(act()) + if index == 0: + inp_dim = self._units + self._pre_layers = nn.Sequential(*pre_layers) + self._pre_layers.apply(weight_init) + + if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]: + self._dist_layer = nn.Linear(self._units, 2 * self._size) + self._dist_layer.apply(uniform_weight_init(outscale)) + + elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]: + self._dist_layer = nn.Linear(self._units, self._size) + self._dist_layer.apply(uniform_weight_init(outscale)) + + def __call__(self, features, dtype=None): + x = features + x = self._pre_layers(x) + if self._dist == "tanh_normal": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + mean = torch.tanh(mean) + std = F.softplus(std + self._init_std) + self._min_std + dist = torchd.normal.Normal(mean, std) + dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector()) + dist = torchd.independent.Independent(dist, 1) + dist = SampleDist(dist) + elif self._dist == "tanh_normal_5": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + mean = 5 * torch.tanh(mean / 5) + std = F.softplus(std + 5) + 5 + dist = torchd.normal.Normal(mean, std) + dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector()) + dist = torchd.independent.Independent(dist, 1) + dist = SampleDist(dist) + elif self._dist == "normal": + x = self._dist_layer(x) + mean, std = torch.split(x, [self._size] * 2, -1) + std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std + dist = torchd.normal.Normal(torch.tanh(mean), std) + dist = ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "normal_1": + x = self._dist_layer(x) + dist = torchd.normal.Normal(mean, 1) + dist = ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "trunc_normal": + x = self._dist_layer(x) + mean, std = torch.split(x, [self._size] * 2, -1) + mean = torch.tanh(mean) + std = 2 * torch.sigmoid(std / 2) + self._min_std + dist = SafeTruncatedNormal(mean, std, -1, 1) + dist = 
ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "onehot": + x = self._dist_layer(x) + dist = OneHotDist(x, unimix_ratio=self._unimix_ratio) + elif self._dist == "onehot_gumble": + x = self._dist_layer(x) + temp = self._temp + dist = ContDist(torchd.gumbel.Gumbel(x, 1 / temp)) + else: + raise NotImplementedError(self._dist) + return dist + + +class GRUCell(nn.Module): + + def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1): + super(GRUCell, self).__init__() + self._inp_size = inp_size # hidden + self._size = size # deter + self._act = act + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(inp_size + size, 3 * size, bias=False) + if norm: + self._norm = nn.LayerNorm(3 * size, eps=1e-03) + + @property + def state_size(self): + return self._size + + def forward(self, inputs, state): + state = state[0] # Keras wraps the state in a list. + parts = self._layer(torch.cat([inputs, state], -1)) + if self._norm: + parts = self._norm(parts) + reset, cand, update = torch.split(parts, [self._size] * 3, -1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, [output] diff --git a/ding/world_model/utils.py b/ding/world_model/utils.py index 2f9f10ffdf..06d660af25 100644 --- a/ding/world_model/utils.py +++ b/ding/world_model/utils.py @@ -236,6 +236,77 @@ def mode(self): return self.mean +class SafeTruncatedNormal(torchd.normal.Normal): + + def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): + super().__init__(loc, scale) + self._low = low + self._high = high + self._clip = clip + self._mult = mult + + def sample(self, sample_shape): + event = super().sample(sample_shape) + if self._clip: + clipped = torch.clip(event, self._low + self._clip, self._high - self._clip) + event = event - event.detach() + clipped.detach() + if self._mult: + event *= self._mult + return event + + +class TanhBijector(torchd.Transform): + + def __init__(self, validate_args=False, name='tanh'): + super().__init__() + + def _forward(self, x): + return torch.tanh(x) + + def _inverse(self, y): + y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y) + y = torch.atanh(y) + return y + + def _forward_log_det_jacobian(self, x): + log2 = torch.math.log(2.0) + return 2.0 * (log2 - x - torch.softplus(-2.0 * x)) + + +def static_scan(fn, inputs, start): + last = start # {logit:[batch_size, self._stoch, self._discrete], stoch:[batch_size, self._stoch, self._discrete], deter:[batch_size, self._deter]} + indices = range(inputs[0].shape[0]) + flag = True + for index in indices: + inp = lambda x: (_input[x] for _input in inputs) # inputs:(action:(time, batch, 6), embed:(time, batch, 4096)) + last = fn(last, *inp(index)) # post, prior + if flag: + if type(last) == type({}): + outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()} + else: + outputs = [] + for _last in last: + if type(_last) == type({}): + outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()}) + else: + outputs.append(_last.clone().unsqueeze(0)) + flag = False + else: + if type(last) == type({}): + for key in last.keys(): + outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0) + else: + for j in range(len(outputs)): + if type(last[j]) == type({}): + for key in last[j].keys(): + outputs[j][key] = torch.cat([outputs[j][key], last[j][key].unsqueeze(0)], dim=0) + else: + outputs[j] = 
torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0) + if type(last) == type({}): + outputs = [outputs] + return outputs + + def weight_init(m): if isinstance(m, nn.Linear): in_num = m.in_features diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py new file mode 100644 index 0000000000..aa22f0bb2b --- /dev/null +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -0,0 +1,92 @@ +from easydict import EasyDict + +from ding.entry import serial_pipeline_dreamer + +cuda = False + +cartpole_balance_dreamer_config = dict( + exp_name='dmc2gym_cartpole_balance_dreamer', + env=dict( + env_id='dmc2gym_cartpole_balance', + domain_name='cartpole', + task_name='balance', + frame_skip=4, + warp_frame=True, + scale=True, + clip_rewards=False, + frame_stack=3, + from_pixels=True, + collector_env_num=1, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=100000, + ), + policy=dict( + cuda=cuda, + # it is better to put random_collect_size in policy.other + random_collect_size=5000, + model=dict( + obs_shape=(3, 64, 64), + action_shape=1, + ), + learn=dict( + lambda_=0.95, + learning_rate=0.001, + batch_size=256, + imag_sample=True, + discount=0.997, + reward_EMA=True, + ), + collect=dict( + n_sample=1, + unroll_len=1, + ), + command=dict(), + eval=dict(evaluator=dict(eval_freq=10000, )), # w.r.t envstep + other=dict( + # environment buffer + replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60), + ), + ), + world_model=dict( + pretrain=100 + #eval_freq=250, # w.r.t envstep + train_freq=2, # w.r.t envstep + cuda=cuda, + model=dict( + #elite_size=5, + state_size=(3, 84, 84), # has to be specified + action_size=1, # has to be specified + reward_size=1, + #hidden_size=200, + #use_decay=True, + batch_size=256, + #holdout_ratio=0.1, + #max_epochs_since_update=5, + #deterministic_rollout=True, + ), + ), +) + +cartpole_balance_dreamer_config = EasyDict(cartpole_balance_dreamer_config) + +cartpole_balance_create_config = dict( + env=dict( + type='dmc2gym', + import_names=['dizoo.dmc2gym.envs.dmc2gym_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='dreamer', + import_names=['ding.policy.mbpolicy.dreamer'], + ), + replay_buffer=dict(type='naive', ), + world_model=dict( + type='dreamer', + import_names=['ding.world_model.dreamer'], + ), +) +cartpole_balance_create_config = EasyDict(cartpole_balance_create_config) + +if __name__ == '__main__': + serial_pipeline_dreamer((cartpole_balance_dreamer_config, cartpole_balance_create_config), seed=0, max_env_step=1000000) \ No newline at end of file diff --git a/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py b/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py deleted file mode 100644 index 1ee4ac165b..0000000000 --- a/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py +++ /dev/null @@ -1,110 +0,0 @@ -from easydict import EasyDict - -from ding.entry import serial_pipeline_dream - -# environment hypo -env_id = 'HalfCheetah-v3' -obs_shape = 17 -action_shape = 6 - -# gpu -cuda = True - -main_config = dict( - exp_name='halfcheetach_mbsac_mbpo_seed0', - env=dict( - env_id=env_id, - norm_obs=dict(use_norm=False, ), - norm_reward=dict(use_norm=False, ), - collector_env_num=1, - evaluator_env_num=8, - n_evaluator_episode=8, - stop_value=100000, - ), - policy=dict( - cuda=cuda, - # it is better to put random_collect_size in policy.other - random_collect_size=10000, - model=dict( - obs_shape=obs_shape, - action_shape=action_shape, - twin_critic=True, - 
action_space='reparameterization', - actor_head_hidden_size=256, - critic_head_hidden_size=256, - ), - learn=dict( - lambda_=0.8, - sample_state=False, - update_per_collect=40, - batch_size=256, - learning_rate_q=3e-4, - learning_rate_policy=3e-4, - learning_rate_alpha=3e-4, - ignore_done=False, - target_theta=0.005, - discount_factor=0.99, - alpha=0.2, - reparameterization=True, - auto_alpha=False, - ), - collect=dict( - n_sample=1, - unroll_len=1, - ), - command=dict(), - eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep - other=dict( - # environment buffer - replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60), - ), - ), - world_model=dict( - eval_freq=250, # w.r.t envstep - train_freq=250, # w.r.t envstep - cuda=cuda, - rollout_length_scheduler=dict( - type='linear', - rollout_start_step=20000, - rollout_end_step=40000, - rollout_length_min=1, - rollout_length_max=3, - ), - model=dict( - ensemble_size=7, - elite_size=5, - state_size=obs_shape, # has to be specified - action_size=action_shape, # has to be specified - reward_size=1, - hidden_size=200, - use_decay=True, - batch_size=256, - holdout_ratio=0.1, - max_epochs_since_update=5, - deterministic_rollout=True, - ), - ), -) - -main_config = EasyDict(main_config) - -create_config = dict( - env=dict( - type='mbmujoco', - import_names=['dizoo.mujoco.envs.mujoco_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='mbsac', - import_names=['ding.policy.mbpolicy.mbsac'], - ), - replay_buffer=dict(type='naive', ), - world_model=dict( - type='mbpo', - import_names=['ding.world_model.mbpo'], - ), -) -create_config = EasyDict(create_config) - -if __name__ == '__main__': - serial_pipeline_dream((main_config, create_config), seed=0, max_env_step=100000) From 9c0837ead3d79ad40a614fa238e9cfb6999d10b4 Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Thu, 15 Jun 2023 18:07:51 +0800 Subject: [PATCH 06/14] add dreamervac --- ding/entry/__init__.py | 2 +- ding/entry/serial_entry_mbrl.py | 38 ++-- ding/model/common/tests/test_encoder.py | 4 - ding/model/template/__init__.py | 2 +- ding/model/template/vac.py | 183 +++++------------- ding/policy/mbpolicy/dreamer.py | 63 ++---- ding/policy/mbpolicy/utils.py | 139 ++++++------- ding/world_model/dreamer.py | 4 +- .../dmc2gym/config/dmc2gym_dreamer_config.py | 6 +- .../mbrl/halfcheetah_mbsac_mbpo_config.py | 110 +++++++++++ 10 files changed, 262 insertions(+), 289 deletions(-) create mode 100644 dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index e0501b12db..c944209786 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -25,6 +25,6 @@ from .serial_entry_preference_based_irl_onpolicy \ import serial_pipeline_preference_based_irl_onpolicy from .application_entry_drex_collect_data import drex_collecting_data -from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream +from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer from .serial_entry_bco import serial_pipeline_bco from .serial_entry_pc import serial_pipeline_pc diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py index e8dad68759..7e27e54777 100644 --- a/ding/entry/serial_entry_mbrl.py +++ b/ding/entry/serial_entry_mbrl.py @@ -285,26 +285,26 @@ def serial_pipeline_dreamer( stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: break - for i in range(eval_every): - 
# train world model and fill imagination buffer - steps = ( - cfg.world_model.pretrain - if world_model.should_pretrain() - else int(world_model.should_train(collector.envstep)) - ) - for _ in range(steps): - batch_size = learner.policy.get_attribute('batch_size') - post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size) - - start = post - - learner.train( - start, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep) - ) + + # train world model and fill imagination buffer + steps = ( + cfg.world_model.pretrain + if world_model.should_pretrain() + else int(world_model.should_train(collector.envstep)) + ) + for _ in range(steps): + batch_size = learner.policy.get_attribute('batch_size') + post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size) - # fill environment buffer - data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) - env_buffer.push(data, cur_collector_envstep=collector.envstep) + start = post + + learner.train( + start, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep) + ) + + # fill environment buffer + data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + env_buffer.push(data, cur_collector_envstep=collector.envstep) if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: break diff --git a/ding/model/common/tests/test_encoder.py b/ding/model/common/tests/test_encoder.py index ef31550690..cd8a5bf752 100644 --- a/ding/model/common/tests/test_encoder.py +++ b/ding/model/common/tests/test_encoder.py @@ -61,7 +61,3 @@ def test_impalaconv_encoder(self): outputs = model(inputs) self.output_check(model, outputs) assert outputs.shape == (B, 256) - - -a = TestEncoder() -a.test_dreamer_conv_encoder() diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index 55bd468034..7ed29480bd 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -2,7 +2,7 @@ from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ from .qac import QAC, DiscreteQAC from .pdqn import PDQN -from .vac import VAC +from .vac import VAC, DREAMERVAC from .bc import DiscreteBC, ContinuousBC from .pg import PG # algorithm-specific diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 9b99b70908..84c612a7d2 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -6,6 +6,7 @@ from ding.utils import SequenceType, squeeze, MODEL_REGISTRY from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \ FCEncoder, ConvEncoder, IMPALAConvEncoder +from ding.world_model.model.networks import ActionHead, DenseHead @MODEL_REGISTRY.register('vac') @@ -369,23 +370,23 @@ class DREAMERVAC(nn.Module): mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] def __init__( - self, - obs_shape: Union[int, SequenceType], - action_shape: Union[int, SequenceType, EasyDict], - action_space: str = 'discrete', - share_encoder: bool = True, - encoder_hidden_size_list: SequenceType = [128, 128, 64], - actor_head_hidden_size: int = 64, - actor_head_layer_num: int = 1, - critic_head_hidden_size: int = 64, - critic_head_layer_num: int = 1, - activation: Optional[nn.Module] = nn.ReLU(), - norm_type: Optional[str] = None, - sigma_type: Optional[str] = 'independent', - fixed_sigma_value: Optional[int] = 0.3, - bound_type: Optional[str] = None, - encoder: 
Optional[torch.nn.Module] = None, - impala_cnn_encoder: bool = False, + self, + obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType, EasyDict], + dyn_stoch=32, + dyn_deter=512, + dyn_discrete=32, + actor_layers=2, + value_layers=2, + units=512, + act='SiLU', + norm='LayerNorm', + actor_dist='normal', + actor_init_std=1.0, + actor_min_std=0.1, + actor_max_std=1.0, + actor_temp=0.1, + action_unimix_ratio=0.01, ) -> None: r""" Overview: @@ -408,129 +409,37 @@ def __init__( - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details` """ - super(VAC, self).__init__() + super(DREAMERVAC, self).__init__() obs_shape: int = squeeze(obs_shape) action_shape = squeeze(action_shape) self.obs_shape, self.action_shape = obs_shape, action_shape - self.impala_cnn_encoder = impala_cnn_encoder - self.share_encoder = share_encoder - - # Encoder Type - def new_encoder(outsize): - if impala_cnn_encoder: - return IMPALAConvEncoder(obs_shape=obs_shape, channels=encoder_hidden_size_list, outsize=outsize) - else: - if isinstance(obs_shape, int) or len(obs_shape) == 1: - return FCEncoder( - obs_shape=obs_shape, - hidden_size_list=encoder_hidden_size_list, - activation=activation, - norm_type=norm_type - ) - elif len(obs_shape) == 3: - return ConvEncoder( - obs_shape=obs_shape, - hidden_size_list=encoder_hidden_size_list, - activation=activation, - norm_type=norm_type - ) - else: - raise RuntimeError( - "not support obs_shape for pre-defined encoder: {}, please customize your own encoder". - format(obs_shape) - ) - if self.share_encoder: - assert actor_head_hidden_size == critic_head_hidden_size, \ - "actor and critic network head should have same size." - if encoder: - if isinstance(encoder, torch.nn.Module): - self.encoder = encoder - else: - raise ValueError("illegal encoder instance.") - else: - self.encoder = new_encoder(actor_head_hidden_size) + if dyn_discrete: + feat_size = dyn_stoch * dyn_discrete + dyn_deter else: - if encoder: - if isinstance(encoder, torch.nn.Module): - self.actor_encoder = encoder - self.critic_encoder = deepcopy(encoder) - else: - raise ValueError("illegal encoder instance.") - else: - self.actor_encoder = new_encoder(actor_head_hidden_size) - self.critic_encoder = new_encoder(critic_head_hidden_size) - - # Head Type - self.critic_head = RegressionHead( - critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type + feat_size = dyn_stoch + dyn_deter + self.actor = ActionHead( + feat_size, # pytorch version + action_shape, + actor_layers, + units, + act, + norm, + actor_dist, + actor_init_std, + actor_min_std, + actor_max_std, + actor_temp, + outscale=1.0, + unimix_ratio=action_unimix_ratio, + ) # action_dist -> action_disc? 
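# --- Illustrative note, not part of this patch --------------------------------
# With the default sizes in the DREAMERVAC signature above (dyn_stoch=32,
# dyn_discrete=32, dyn_deter=512), the feature vector consumed by both heads has
#     feat_size = dyn_stoch * dyn_discrete + dyn_deter = 32 * 32 + 512 = 1536
# --- end of note ---------------------------------------------------------------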
+ self.critic = DenseHead( + feat_size, # pytorch version + (255, ), + value_layers, + units, + act, + norm, + 'twohot_symlog', + outscale=0.0, ) - self.action_space = action_space - assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space - if self.action_space == 'continuous': - self.multi_head = False - self.actor_head = ReparameterizationHead( - actor_head_hidden_size, - action_shape, - actor_head_layer_num, - sigma_type=sigma_type, - activation=activation, - norm_type=norm_type, - bound_type=bound_type - ) - elif self.action_space == 'discrete': - actor_head_cls = DiscreteHead - multi_head = not isinstance(action_shape, int) - self.multi_head = multi_head - if multi_head: - self.actor_head = MultiHead( - actor_head_cls, - actor_head_hidden_size, - action_shape, - layer_num=actor_head_layer_num, - activation=activation, - norm_type=norm_type - ) - else: - self.actor_head = actor_head_cls( - actor_head_hidden_size, - action_shape, - actor_head_layer_num, - activation=activation, - norm_type=norm_type - ) - elif self.action_space == 'hybrid': # HPPO - # hybrid action space: action_type(discrete) + action_args(continuous), - # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])} - action_shape.action_args_shape = squeeze(action_shape.action_args_shape) - action_shape.action_type_shape = squeeze(action_shape.action_type_shape) - actor_action_args = ReparameterizationHead( - actor_head_hidden_size, - action_shape.action_args_shape, - actor_head_layer_num, - sigma_type=sigma_type, - fixed_sigma_value=fixed_sigma_value, - activation=activation, - norm_type=norm_type, - bound_type=bound_type, - ) - actor_action_type = DiscreteHead( - actor_head_hidden_size, - action_shape.action_type_shape, - actor_head_layer_num, - activation=activation, - norm_type=norm_type, - ) - self.actor_head = nn.ModuleList([actor_action_type, actor_action_args]) - - # must use list, not nn.ModuleList - if self.share_encoder: - self.actor = [self.encoder, self.actor_head] - self.critic = [self.encoder, self.critic_head] - else: - self.actor = [self.actor_encoder, self.actor_head] - self.critic = [self.critic_encoder, self.critic_head] - # Convenient for calling some apis (e.g. 
self.critic.parameters()), - # but may cause misunderstanding when `print(self)` - self.actor = nn.ModuleList(self.actor) - self.critic = nn.ModuleList(self.critic) \ No newline at end of file diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index 4b31688f45..a765b30ea5 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -1,11 +1,10 @@ -from typing import Dict, Any, List +from typing import Dict, Any, List, Tuple import torch from torch import nn from copy import deepcopy from ding.torch_utils import Adam, to_device from ding.utils import POLICY_REGISTRY, deep_merge_dicts from ding.policy import Policy -from ding.rl_utils import generalized_lambda_returns from ding.model import model_wrap from ding.policy.common_utils import default_preprocess_learn @@ -40,7 +39,7 @@ class DREAMERPolicy(Policy): def default_model(self) -> Tuple[str, List[str]]: return 'dreamervac', ['ding.model.template.vac'] - + def _init_learn(self) -> None: r""" Overview: @@ -53,11 +52,11 @@ def _init_learn(self) -> None: self._critic = self._model.critic self._actor = self._model.actor - + if self._cfg.learn.slow_value_target: self._slow_value = deepcopy(self._critic) self._updates = 0 - + # Optimizer self._optimizer_value = Adam( self._critic.parameters(), @@ -67,7 +66,7 @@ def _init_learn(self) -> None: self._actor.parameters(), lr=self._cfg.learn.learning_rate, ) - + self._learn_model = model_wrap(self._model, wrapper_name='base') self._learn_model.reset() @@ -76,22 +75,20 @@ def _init_learn(self) -> None: if self._cfg.learn.reward_EMA: self.reward_ema = RewardEMA(device=self._device) - def _forward_learn(self, start: dict, repeats=None, world_model=None, envstep) -> Dict[str, Any]: + def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: # log dict log_vars = {} + self._learn_model.train() world_model.requires_grad_(requires_grad=False) self._actor.requires_grad_(requires_grad=True) # start is dict of {stoch, deter, logit} if self._cuda: start = to_device(start, self._device) - self._learn_model.train() - self._target_model.train() - # train self._actor imag_feat, imag_state, imag_action = imagine( - self._cfg.learn, world_model, start, self._actor, self._cfg.imag_horizon, repeats - ) + self._cfg.learn, world_model, start, self._actor, self._cfg.imag_horizon + ) reward = world_model.heads["reward"](world_model.dynamics.get_feat(imag_state)).mode() actor_ent = self._actor(imag_feat).entropy() state_ent = world_model.dynamics.get_dist(imag_state).entropy() @@ -125,9 +122,7 @@ def _forward_learn(self, start: dict, repeats=None, world_model=None, envstep) - value_loss = -value.log_prob(target.detach()) slow_target = self._slow_value(value_input[:-1].detach()) if self._cfg.learn.slow_value_target: - value_loss = value_loss - value.log_prob( - slow_target.mode().detach() - ) + value_loss = value_loss - value.log_prob(slow_target.mode().detach()) if self._cfg.learn.value_decay: value_loss += self._cfg.learn.value_decay * value.mode() # (time, batch, 1), (time, batch, 1) -> (1,) @@ -143,14 +138,14 @@ def _forward_learn(self, start: dict, repeats=None, world_model=None, envstep) - # actor-critic update # ==================== self._model.requires_grad_(requires_grad=True) - + loss_dict = { 'critic_loss': value_loss, 'actor_loss': actor_loss, } norm_dict = self._update(loss_dict) - + self._model.requires_grad_(requires_grad=False) # ============= # after update @@ -184,33 +179,9 @@ def _monitor_vars_learn(self) -> List[str]: - vars 
(:obj:`List[str]`): Variables' name list. """ return [ - 'normed_target_mean', - 'normed_target_std', - 'normed_target_min', - 'normed_target_max', - 'EMA_005', - 'EMA_095', - 'actor_entropy', - 'actor_state_entropy', - 'value_mean', - 'value_std', - 'value_min', - 'value_max', - 'target_mean', - 'target_std', - 'target_min', - 'target_max', - 'imag_reward_mean', - 'imag_reward_std', - 'imag_reward_min', - 'imag_reward_max', - 'imag_action_mean', - 'imag_action_std', - 'imag_action_min', - 'imag_action_max', - 'actor_ent', - 'actor_loss', - 'critic_loss', - 'actor_grad_norm', - 'critic_grad_norm' + 'normed_target_mean', 'normed_target_std', 'normed_target_min', 'normed_target_max', 'EMA_005', 'EMA_095', + 'actor_entropy', 'actor_state_entropy', 'value_mean', 'value_std', 'value_min', 'value_max', 'target_mean', + 'target_std', 'target_min', 'target_max', 'imag_reward_mean', 'imag_reward_std', 'imag_reward_min', + 'imag_reward_max', 'imag_action_mean', 'imag_action_std', 'imag_action_min', 'imag_action_max', 'actor_ent', + 'actor_loss', 'critic_loss', 'actor_grad_norm', 'critic_grad_norm' ] diff --git a/ding/policy/mbpolicy/utils.py b/ding/policy/mbpolicy/utils.py index 985e7dfae2..8fdea26956 100644 --- a/ding/policy/mbpolicy/utils.py +++ b/ding/policy/mbpolicy/utils.py @@ -2,8 +2,10 @@ import torch from torch import Tensor from ding.torch_utils import fold_batch, unfold_batch +from ding.rl_utils import generalized_lambda_returns from ding.world_model.utils import static_scan + def q_evaluation(obss: Tensor, actions: Tensor, q_critic_fn: Callable[[Tensor, Tensor], Tensor]) -> Union[Tensor, Tuple[Tensor, Tensor]]: """ @@ -40,96 +42,79 @@ def q_evaluation(obss: Tensor, actions: Tensor, q_critic_fn: Callable[[Tensor, T def imagine(cfg, world_model, start, actor, horizon, repeats=None): dynamics = world_model.dynamics - if repeats: - raise NotImplemented("repeats is not implemented in this version") flatten = lambda x: x.reshape([-1] + list(x.shape[2:])) start = {k: flatten(v) for k, v in start.items()} def step(prev, _): state, _, _ = prev feat = dynamics.get_feat(state) - inp = feat.detach() if self._stop_grad_actor else feat + inp = feat.detach() action = actor(inp).sample() succ = dynamics.img_step(state, action, sample=cfg.imag_sample) return succ, feat, action - succ, feats, actions = static_scan( - step, [torch.arange(horizon)], (start, None, None) - ) + succ, feats, actions = static_scan(step, [torch.arange(horizon)], (start, None, None)) states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()} - if repeats: - raise NotImplemented("repeats is not implemented in this version") return feats, states, actions -def compute_target( - cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent -): +def compute_target(cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent): if "cont" in world_model.heads: - inp = self._world_model.dynamics.get_feat(imag_state) - discount = cfg.discount * self._world_model.heads["cont"](inp).mean + inp = world_model.dynamics.get_feat(imag_state) + discount = cfg.discount * world_model.heads["cont"](inp).mean else: discount = cfg.discount * torch.ones_like(reward) - + value = critic(imag_feat).mode() - # value(imag_horizon, 16*64, ch) + # value(imag_horizon, 16*64, 1) # action(imag_horizon, 16*64, ch) - # discount(imag_horizon, 16*64, ch) - target = tools.lambda_return( - reward[:-1], - value[:-1], - discount[:-1], - bootstrap=value[-1], - lambda_=cfg.lambda_, - axis=0, - ) - weights = torch.cumprod( - 
torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0 - ).detach() + # discount(imag_horizon, 16*64, 1) + target = generalized_lambda_returns(value, reward[:-1], discount[:-1], cfg.lambda_) + weights = torch.cumprod(torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0).detach() return target, weights, value[:-1] def compute_actor_loss( - cfg, - actor, - reward_ema - imag_feat, - imag_state, - imag_action, - target, - actor_ent, - state_ent, - weights, - base, - ): - metrics = {} - inp = imag_feat.detach() - policy = actor(inp) - actor_ent = policy.entropy() - # Q-val for actor is not transformed using symlog - target = torch.stack(target, dim=1) - if cfg.reward_EMA: - offset, scale = reward_ema(target) - normed_target = (target - offset) / scale - normed_base = (base - offset) / scale - adv = normed_target - normed_base - metrics.update(tools.tensorstats(normed_target, "normed_target")) - values = reward_ema.values - metrics["EMA_005"] = values[0].detach().cpu().numpy() - metrics["EMA_095"] = values[1].detach().cpu().numpy() - - actor_target = adv - if cfg.actor_entropy > 0: - actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None] - actor_target += actor_entropy - metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy() - if cfg.actor_state_entropy > 0: - state_entropy = cfg.actor_state_entropy * state_ent[:-1] - actor_target += state_entropy - metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy() - actor_loss = -torch.mean(weights[:-1] * actor_target) - return actor_loss, metrics + cfg, + actor, + reward_ema, + imag_feat, + imag_state, + imag_action, + target, + actor_ent, + state_ent, + weights, + base, +): + metrics = {} + inp = imag_feat.detach() + policy = actor(inp) + actor_ent = policy.entropy() + # Q-val for actor is not transformed using symlog + # target = torch.stack(target, dim=1) + if cfg.reward_EMA: + offset, scale = reward_ema(target) + normed_target = (target - offset) / scale + normed_base = (base - offset) / scale + adv = normed_target - normed_base + metrics.update(tensorstats(normed_target, "normed_target")) + values = reward_ema.values + metrics["EMA_005"] = values[0].detach().cpu().numpy() + metrics["EMA_095"] = values[1].detach().cpu().numpy() + + actor_target = adv + if cfg.actor_entropy > 0: + actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None] + actor_target += actor_entropy + metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy() + if cfg.actor_state_entropy > 0: + state_entropy = cfg.actor_state_entropy * state_ent[:-1] + actor_target += state_entropy + metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy() + actor_loss = -torch.mean(weights[:-1] * actor_target) + return actor_loss, metrics class RewardEMA(object): @@ -137,7 +122,7 @@ class RewardEMA(object): def __init__(self, device, alpha=1e-2): self.device = device - self.values = torch.zeros((2,)).to(device) + self.values = torch.zeros((2, )).to(device) self.alpha = alpha self.range = torch.tensor([0.05, 0.95]).to(device) @@ -148,15 +133,15 @@ def __call__(self, x): scale = torch.clip(self.values[1] - self.values[0], min=1.0) offset = self.values[0] return offset.detach(), scale.detach() - + def tensorstats(tensor, prefix=None): - metrics = { - 'mean': torch.mean(tensor).detach().cpu().numpy(), - 'std': torch.std(tensor).detach().cpu().numpy(), - 'min': torch.min(tensor).detach().cpu().numpy(), - 'max': torch.max(tensor).detach().cpu().numpy(), - } - if prefix: - metrics = 
{f'{prefix}_{k}': v for k, v in metrics.items()} - return metrics \ No newline at end of file + metrics = { + 'mean': torch.mean(tensor).detach().cpu().numpy(), + 'std': torch.std(tensor).detach().cpu().numpy(), + 'min': torch.min(tensor).detach().cpu().numpy(), + 'max': torch.max(tensor).detach().cpu().numpy(), + } + if prefix: + metrics = {f'{prefix}_{k}': v for k, v in metrics.items()} + return metrics diff --git a/ding/world_model/dreamer.py b/ding/world_model/dreamer.py index fe1de29829..031e2f390c 100644 --- a/ding/world_model/dreamer.py +++ b/ding/world_model/dreamer.py @@ -14,6 +14,8 @@ @WORLD_MODEL_REGISTRY.register('dreamer') class DREAMERWorldModel(WorldModel, nn.Module): config = dict( + pretrain=100, + train_freq=2, model=dict( state_size=None, action_size=None, @@ -38,8 +40,6 @@ class DREAMERWorldModel(WorldModel, nn.Module): discount_layers=2, value_layers=2, actor_layers=2, - act='SiLU', - norm='LayerNorm', cnn_depth=32, encoder_kernels=[4, 4, 4, 4], decoder_kernels=[4, 4, 4, 4], diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py index aa22f0bb2b..fb2c5d97a7 100644 --- a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -28,6 +28,7 @@ model=dict( obs_shape=(3, 64, 64), action_shape=1, + actor_dist = 'normal', ), learn=dict( lambda_=0.95, @@ -49,13 +50,14 @@ ), ), world_model=dict( - pretrain=100 + #eval_every=10, + pretrain=100, #eval_freq=250, # w.r.t envstep train_freq=2, # w.r.t envstep cuda=cuda, model=dict( #elite_size=5, - state_size=(3, 84, 84), # has to be specified + state_size=(3, 64, 64), # has to be specified action_size=1, # has to be specified reward_size=1, #hidden_size=200, diff --git a/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py b/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py new file mode 100644 index 0000000000..1ee4ac165b --- /dev/null +++ b/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py @@ -0,0 +1,110 @@ +from easydict import EasyDict + +from ding.entry import serial_pipeline_dream + +# environment hypo +env_id = 'HalfCheetah-v3' +obs_shape = 17 +action_shape = 6 + +# gpu +cuda = True + +main_config = dict( + exp_name='halfcheetach_mbsac_mbpo_seed0', + env=dict( + env_id=env_id, + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=1, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=100000, + ), + policy=dict( + cuda=cuda, + # it is better to put random_collect_size in policy.other + random_collect_size=10000, + model=dict( + obs_shape=obs_shape, + action_shape=action_shape, + twin_critic=True, + action_space='reparameterization', + actor_head_hidden_size=256, + critic_head_hidden_size=256, + ), + learn=dict( + lambda_=0.8, + sample_state=False, + update_per_collect=40, + batch_size=256, + learning_rate_q=3e-4, + learning_rate_policy=3e-4, + learning_rate_alpha=3e-4, + ignore_done=False, + target_theta=0.005, + discount_factor=0.99, + alpha=0.2, + reparameterization=True, + auto_alpha=False, + ), + collect=dict( + n_sample=1, + unroll_len=1, + ), + command=dict(), + eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep + other=dict( + # environment buffer + replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60), + ), + ), + world_model=dict( + eval_freq=250, # w.r.t envstep + train_freq=250, # w.r.t envstep + cuda=cuda, + rollout_length_scheduler=dict( + type='linear', + rollout_start_step=20000, + rollout_end_step=40000, + 
rollout_length_min=1, + rollout_length_max=3, + ), + model=dict( + ensemble_size=7, + elite_size=5, + state_size=obs_shape, # has to be specified + action_size=action_shape, # has to be specified + reward_size=1, + hidden_size=200, + use_decay=True, + batch_size=256, + holdout_ratio=0.1, + max_epochs_since_update=5, + deterministic_rollout=True, + ), + ), +) + +main_config = EasyDict(main_config) + +create_config = dict( + env=dict( + type='mbmujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='mbsac', + import_names=['ding.policy.mbpolicy.mbsac'], + ), + replay_buffer=dict(type='naive', ), + world_model=dict( + type='mbpo', + import_names=['ding.world_model.mbpo'], + ), +) +create_config = EasyDict(create_config) + +if __name__ == '__main__': + serial_pipeline_dream((main_config, create_config), seed=0, max_env_step=100000) From b61ac33502339487751efd089181ea607f693b32 Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Mon, 19 Jun 2023 14:51:56 +0800 Subject: [PATCH 07/14] relocate world model network to torch_utils/dreamer --- ding/entry/serial_entry_mbrl.py | 4 +- ding/model/common/encoder.py | 3 +- ding/model/template/vac.py | 6 +- ding/policy/command_mode_policy_instance.py | 5 + ding/policy/mbpolicy/dreamer.py | 132 ++++- ding/policy/mbpolicy/utils.py | 2 +- ding/torch_utils/network/__init__.py | 2 +- ding/torch_utils/network/dreamer.py | 496 +++++++++++++++++- .../torch_utils/network/tests/test_dreamer.py | 73 +++ .../collector/interaction_serial_evaluator.py | 3 +- ding/world_model/dreamer.py | 41 +- ding/world_model/model/__init__.py | 0 ding/world_model/model/networks.py | 165 +----- ding/world_model/model/tests/test_networks.py | 20 - .../tests/test_world_model_utils.py | 52 +- ding/world_model/utils.py | 328 ------------ 16 files changed, 735 insertions(+), 597 deletions(-) create mode 100644 ding/torch_utils/network/tests/test_dreamer.py create mode 100644 ding/world_model/model/__init__.py diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py index 7e27e54777..19d24735b3 100644 --- a/ding/entry/serial_entry_mbrl.py +++ b/ding/entry/serial_entry_mbrl.py @@ -282,7 +282,7 @@ def serial_pipeline_dreamer( collect_kwargs = commander.step() # eval the policy if evaluator.should_eval(collector.envstep): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep, policy_kwargs=dict(world_model=world_model)) if stop: break @@ -303,7 +303,7 @@ def serial_pipeline_dreamer( ) # fill environment buffer - data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + data = collector.collect(train_iter=learner.train_iter, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep, **collect_kwargs)) env_buffer.push(data, cur_collector_envstep=collector.envstep) if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: diff --git a/ding/model/common/encoder.py b/ding/model/common/encoder.py index d1e4f7d69b..7662bd17dc 100644 --- a/ding/model/common/encoder.py +++ b/ding/model/common/encoder.py @@ -6,7 +6,8 @@ import torch.nn as nn from torch.nn import functional as F -from ding.torch_utils import ResFCBlock, ResBlock, Flatten, normed_linear, normed_conv2d, Conv2dSame, DreamerLayerNorm +from ding.torch_utils import ResFCBlock, ResBlock, Flatten, normed_linear, normed_conv2d +from 
ding.torch_utils.network.dreamer import Conv2dSame, DreamerLayerNorm from ding.utils import SequenceType diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 84c612a7d2..5aa4b83a8d 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -6,7 +6,7 @@ from ding.utils import SequenceType, squeeze, MODEL_REGISTRY from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \ FCEncoder, ConvEncoder, IMPALAConvEncoder -from ding.world_model.model.networks import ActionHead, DenseHead +from ding.torch_utils.network.dreamer import ActionHead, DenseHead @MODEL_REGISTRY.register('vac') @@ -438,8 +438,8 @@ def __init__( (255, ), value_layers, units, - act, - norm, + 'SiLU', # act + 'LN', # norm 'twohot_symlog', outscale=0.0, ) diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 8b6123c063..49a33b359d 100755 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -26,6 +26,7 @@ from .td3_bc import TD3BCPolicy from .sac import SACPolicy, SACDiscretePolicy from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy +from .mbpolicy.dreamer import DREAMERPolicy from .qmix import QMIXPolicy from .wqmix import WQMIXPolicy from .collaq import CollaQPolicy @@ -302,6 +303,10 @@ class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy): class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy): pass +@POLICY_REGISTRY.register('dreamer_command') +class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy): + pass + @POLICY_REGISTRY.register('cql_command') class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy): diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index a765b30ea5..54c1c62eed 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -1,9 +1,12 @@ -from typing import Dict, Any, List, Tuple +from typing import List, Dict, Any, Tuple, Union +from collections import namedtuple import torch from torch import nn from copy import deepcopy from ding.torch_utils import Adam, to_device +from ding.rl_utils import get_train_sample from ding.utils import POLICY_REGISTRY, deep_merge_dicts +from ding.utils.data import default_collate, default_decollate from ding.policy import Policy from ding.model import model_wrap from ding.policy.common_utils import default_preprocess_learn @@ -18,6 +21,10 @@ class DREAMERPolicy(Policy): type='dreamer', # (bool) Whether to use cuda for network and loss computation. cuda=False, + # (int) Number of training samples (randomly collected) in replay buffer when training starts. + random_collect_size=5000, + # (bool) Whether to need policy-specific data in preprocess transition. 
+ transition_with_policy_data=False, # (int) imag_horizon=15, learn=dict( @@ -171,6 +178,129 @@ def _update(self, loss_dict): self._optimizer_value.step() return {'actor_grad_norm': actor_norm, 'critic_grad_norm': critic_norm} + def _state_dict_learn(self) -> Dict[str, Any]: + ret = { + 'model': self._learn_model.state_dict(), + 'optimizer_value': self._optimizer_value.state_dict(), + 'optimizer_actor': self._optimizer_actor.state_dict(), + } + return ret + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + self._learn_model.load_state_dict(state_dict['model']) + self._optimizer_value.load_state_dict(state_dict['optimizer_value']) + self._optimizer_actor.load_state_dict(state_dict['optimizer_actor']) + + def _init_collect(self) -> None: + self._unroll_len = self._cfg.collect.unroll_len + self._collect_model = model_wrap(self._model, wrapper_name='base') + self._collect_model.reset() + + def _forward_collect(self, data: dict, world_model, envstep, state=None) -> dict: + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._collect_model.eval() + + if state is None: + batch_size = len(data_id) + latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter} + action = torch.zeros((batch_size, self._config.num_actions)).to( + self._config.device + ) + else: + latent, action = state + + data = data / 255.0 - 0.5 + embed = world_model.encoder(data) + latent, _ = world_model.dynamics.obs_step( + latent, action, embed, self._config.collect_dyn_sample + ) + feat = world_model.dynamics.get_feat(latent) + + actor = self._actor(feat) + action = actor.sample() + logprob = actor.log_prob(action) + latent = {k: v.detach() for k, v in latent.items()} + action = action.detach() + output = {"action": action, "logprob": logprob} + # to do + # should pass state to next collect + state = (latent, action) + + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + r""" + Overview: + Generate dict type transition data from inputs. + Arguments: + - obs (:obj:`Any`): Env observation + - model_output (:obj:`dict`): Output of collect model, including at least ['action'] + - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \ + (here 'obs' indicates obs after env step). + Returns: + - transition (:obj:`dict`): Dict type transition data. 
+ """ + transition = { + 'obs': obs, + 'next_obs': timestep.obs, + 'action': model_output['action'], + 'logprob': model_output['logprob'], + 'reward': timestep.reward, + 'done': timestep.done, + } + return transition + + def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + return get_train_sample(data, self._unroll_len) + + def _init_eval(self) -> None: + self._eval_model = model_wrap(self._model, wrapper_name='base') + self._eval_model.reset() + + def _forward_eval(self, data: dict, world_model, state=None) -> dict: + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._eval_model.eval() + + if state is None: + batch_size = len(data_id) + latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter} + action = torch.zeros((batch_size, self._config.num_actions)).to( + self._config.device + ) + else: + latent, action = state + + data = data / 255.0 - 0.5 + embed = world_model.encoder(data) + latent, _ = world_model.dynamics.obs_step( + latent, action, embed, self._config.collect_dyn_sample + ) + feat = world_model.dynamics.get_feat(latent) + + actor = self._actor(feat) + action = actor.mode() + logprob = actor.log_prob(action) + latent = {k: v.detach() for k, v in latent.items()} + action = action.detach() + output = {"action": action, "logprob": logprob} + # to do + # should pass state to next eval + state = (latent, action) + + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + def _monitor_vars_learn(self) -> List[str]: r""" Overview: diff --git a/ding/policy/mbpolicy/utils.py b/ding/policy/mbpolicy/utils.py index 8fdea26956..2ddd1ab0b1 100644 --- a/ding/policy/mbpolicy/utils.py +++ b/ding/policy/mbpolicy/utils.py @@ -3,7 +3,7 @@ from torch import Tensor from ding.torch_utils import fold_batch, unfold_batch from ding.rl_utils import generalized_lambda_returns -from ding.world_model.utils import static_scan +from ding.torch_utils.network.dreamer import static_scan def q_evaluation(obss: Tensor, actions: Tensor, q_critic_fn: Callable[[Tensor, Tensor], diff --git a/ding/torch_utils/network/__init__.py b/ding/torch_utils/network/__init__.py index 1236146db7..dda50c339e 100644 --- a/ding/torch_utils/network/__init__.py +++ b/ding/torch_utils/network/__init__.py @@ -11,5 +11,5 @@ from .gumbel_softmax import GumbelSoftmax from .gtrxl import GTrXL, GRUGatingUnit from .popart import PopArt -from .dreamer import Conv2dSame, DreamerLayerNorm +#from .dreamer import Conv2dSame, DreamerLayerNorm, ActionHead, DenseHead from .merge import GatingType, SumMerge, VectorMerge diff --git a/ding/torch_utils/network/dreamer.py b/ding/torch_utils/network/dreamer.py index 8d45abc9cc..0fa9e9e55f 100644 --- a/ding/torch_utils/network/dreamer.py +++ b/ding/torch_utils/network/dreamer.py @@ -1,7 +1,11 @@ -import torch -import torch.nn as nn import math -from torch.nn import functional as F +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torch import distributions as torchd +from ding.torch_utils import MLP class Conv2dSame(torch.nn.Conv2d): @@ -38,4 +42,488 @@ def forward(self, x): x = x.permute(0, 2, 3, 1) x = self.norm(x) x = x.permute(0, 3, 1, 2) - return x \ No newline at end of file + return x + + +class DenseHead(nn.Module): + + def __init__( + self, + inp_dim, # config.dyn_stoch * config.dyn_discrete + config.dyn_deter + shape, # (255,) + layer_num, + units, # 512 
+ act='SiLU', + norm='LN', + dist='normal', + std=1.0, + outscale=1.0, + ): + super(DenseHead, self).__init__() + self._shape = (shape, ) if isinstance(shape, int) else shape + if len(self._shape) == 0: + self._shape = (1, ) + self._layer_num = layer_num + self._units = units + self._act = getattr(torch.nn, act)() + self._norm = norm + self._dist = dist + self._std = std + + self.mlp = MLP( + inp_dim, + self._units, + self._units, + self._layer_num, + layer_fn=nn.Linear, + activation=self._act, + norm_type=self._norm + ) + self.mlp.apply(weight_init) + + self.mean_layer = nn.Linear(self._units, np.prod(self._shape)) + self.mean_layer.apply(uniform_weight_init(outscale)) + + if self._std == "learned": + self.std_layer = nn.Linear(self._units, np.prod(self._shape)) + self.std_layer.apply(uniform_weight_init(outscale)) + + def forward(self, features, dtype=None): + x = features + out = self.mlp(x) # (batch, time, _units=512) + mean = self.mean_layer(out) # (batch, time, 255) + if self._std == "learned": + std = self.std_layer(out) + else: + std = self._std + if self._dist == "normal": + return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape))) + if self._dist == "huber": + return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape))) + if self._dist == "binary": + return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape))) + if self._dist == "twohot_symlog": + return TwoHotDistSymlog(logits=mean) + raise NotImplementedError(self._dist) + + +class ActionHead(nn.Module): + + def __init__( + self, + inp_dim, + size, + layers, + units, + act=nn.ELU, + norm=nn.LayerNorm, + dist="trunc_normal", + init_std=0.0, + min_std=0.1, + max_std=1.0, + temp=0.1, + outscale=1.0, + unimix_ratio=0.01, + ): + super(ActionHead, self).__init__() + self._size = size + self._layers = layers + self._units = units + self._dist = dist + self._act = getattr(torch.nn, act) + self._norm = getattr(torch.nn, norm) + self._min_std = min_std + self._max_std = max_std + self._init_std = init_std + self._unimix_ratio = unimix_ratio + self._temp = temp() if callable(temp) else temp + + pre_layers = [] + for index in range(self._layers): + pre_layers.append(nn.Linear(inp_dim, self._units, bias=False)) + pre_layers.append(self._norm(self._units, eps=1e-03)) + pre_layers.append(self._act()) + if index == 0: + inp_dim = self._units + self._pre_layers = nn.Sequential(*pre_layers) + self._pre_layers.apply(weight_init) + + if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]: + self._dist_layer = nn.Linear(self._units, 2 * self._size) + self._dist_layer.apply(uniform_weight_init(outscale)) + + elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]: + self._dist_layer = nn.Linear(self._units, self._size) + self._dist_layer.apply(uniform_weight_init(outscale)) + + def __call__(self, features, dtype=None): + x = features + x = self._pre_layers(x) + if self._dist == "tanh_normal": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + mean = torch.tanh(mean) + std = F.softplus(std + self._init_std) + self._min_std + dist = torchd.normal.Normal(mean, std) + dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector()) + dist = torchd.independent.Independent(dist, 1) + dist = SampleDist(dist) + elif self._dist == "tanh_normal_5": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + mean = 5 * torch.tanh(mean / 5) + std = F.softplus(std + 5) + 5 + 
dist = torchd.normal.Normal(mean, std) + dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector()) + dist = torchd.independent.Independent(dist, 1) + dist = SampleDist(dist) + elif self._dist == "normal": + x = self._dist_layer(x) + mean, std = torch.split(x, [self._size] * 2, -1) + std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std + dist = torchd.normal.Normal(torch.tanh(mean), std) + dist = ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "normal_1": + x = self._dist_layer(x) + dist = torchd.normal.Normal(mean, 1) + dist = ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "trunc_normal": + x = self._dist_layer(x) + mean, std = torch.split(x, [self._size] * 2, -1) + mean = torch.tanh(mean) + std = 2 * torch.sigmoid(std / 2) + self._min_std + dist = SafeTruncatedNormal(mean, std, -1, 1) + dist = ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "onehot": + x = self._dist_layer(x) + dist = OneHotDist(x, unimix_ratio=self._unimix_ratio) + elif self._dist == "onehot_gumble": + x = self._dist_layer(x) + temp = self._temp + dist = ContDist(torchd.gumbel.Gumbel(x, 1 / temp)) + else: + raise NotImplementedError(self._dist) + return dist + + +def symlog(x): + return torch.sign(x) * torch.log(torch.abs(x) + 1.0) + + +def symexp(x): + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0) + + +class SampleDist: + + def __init__(self, dist, samples=100): + self._dist = dist + self._samples = samples + + @property + def name(self): + return 'SampleDist' + + def __getattr__(self, name): + return getattr(self._dist, name) + + def mean(self): + samples = self._dist.sample(self._samples) + return torch.mean(samples, 0) + + def mode(self): + sample = self._dist.sample(self._samples) + logprob = self._dist.log_prob(sample) + return sample[torch.argmax(logprob)][0] + + def entropy(self): + sample = self._dist.sample(self._samples) + logprob = self.log_prob(sample) + return -torch.mean(logprob, 0) + + +class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): + + def __init__(self, logits=None, probs=None, unimix_ratio=0.0): + if logits is not None and unimix_ratio > 0.0: + probs = F.softmax(logits, dim=-1) + probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1] + logits = torch.log(probs) + super().__init__(logits=logits, probs=None) + else: + super().__init__(logits=logits, probs=probs) + + def mode(self): + _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1]) + return _mode.detach() + super().logits - super().logits.detach() + + def sample(self, sample_shape=(), seed=None): + if seed is not None: + raise ValueError('need to check') + sample = super().sample(sample_shape) + probs = super().probs + while len(probs.shape) < len(sample.shape): + probs = probs[None] + sample += probs - probs.detach() + return sample + + +class TwoHotDistSymlog(): + + def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): + self.logits = logits + self.probs = torch.softmax(logits, -1) + self.buckets = torch.linspace(low, high, steps=255).to(device) + self.width = (self.buckets[-1] - self.buckets[0]) / 255 + + def mean(self): + print("mean called") + _mean = self.probs * self.buckets + return symexp(torch.sum(_mean, dim=-1, keepdim=True)) + + def mode(self): + _mode = self.probs * self.buckets + return symexp(torch.sum(_mode, dim=-1, keepdim=True)) + + # Inside OneHotCategorical, log_prob is calculated using only max element in targets + def 
log_prob(self, x): + x = symlog(x) + # x(time, batch, 1) + below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1 + above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1) + below = torch.clip(below, 0, len(self.buckets) - 1) + above = torch.clip(above, 0, len(self.buckets) - 1) + equal = (below == above) + + dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x)) + dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x)) + total = dist_to_below + dist_to_above + weight_below = dist_to_above / total + weight_above = dist_to_below / total + target = ( + F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] + + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None] + ) + log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True) + target = target.squeeze(-2) + + return (target * log_pred).sum(-1) + + def log_prob_target(self, target): + log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True) + return (target * log_pred).sum(-1) + + +class SymlogDist(): + + def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): + self._mode = mode + self._dist = dist + self._agg = agg + self._tol = tol + self._dim_to_reduce = dim_to_reduce + + def mode(self): + return symexp(self._mode) + + def mean(self): + return symexp(self._mode) + + def log_prob(self, value): + assert self._mode.shape == value.shape + if self._dist == 'mse': + distance = (self._mode - symlog(value)) ** 2.0 + distance = torch.where(distance < self._tol, 0, distance) + elif self._dist == 'abs': + distance = torch.abs(self._mode - symlog(value)) + distance = torch.where(distance < self._tol, 0, distance) + else: + raise NotImplementedError(self._dist) + if self._agg == 'mean': + loss = distance.mean(self._dim_to_reduce) + elif self._agg == 'sum': + loss = distance.sum(self._dim_to_reduce) + else: + raise NotImplementedError(self._agg) + return -loss + + +class ContDist: + + def __init__(self, dist=None): + super().__init__() + self._dist = dist + self.mean = dist.mean + + def __getattr__(self, name): + return getattr(self._dist, name) + + def entropy(self): + return self._dist.entropy() + + def mode(self): + return self._dist.mean + + def sample(self, sample_shape=()): + return self._dist.rsample(sample_shape) + + def log_prob(self, x): + return self._dist.log_prob(x) + + +class Bernoulli: + + def __init__(self, dist=None): + super().__init__() + self._dist = dist + self.mean = dist.mean + + def __getattr__(self, name): + return getattr(self._dist, name) + + def entropy(self): + return self._dist.entropy() + + def mode(self): + _mode = torch.round(self._dist.mean) + return _mode.detach() + self._dist.mean - self._dist.mean.detach() + + def sample(self, sample_shape=()): + return self._dist.rsample(sample_shape) + + def log_prob(self, x): + _logits = self._dist.base_dist.logits + log_probs0 = -F.softplus(_logits) + log_probs1 = -F.softplus(-_logits) + + return log_probs0 * (1 - x) + log_probs1 * x + + +class UnnormalizedHuber(torchd.normal.Normal): + + def __init__(self, loc, scale, threshold=1, **kwargs): + super().__init__(loc, scale, **kwargs) + self._threshold = threshold + + def log_prob(self, event): + return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold) + + def mode(self): + return self.mean + + +class SafeTruncatedNormal(torchd.normal.Normal): + + def __init__(self, loc, scale, low, high, clip=1e-6, 
mult=1): + super().__init__(loc, scale) + self._low = low + self._high = high + self._clip = clip + self._mult = mult + + def sample(self, sample_shape): + event = super().sample(sample_shape) + if self._clip: + clipped = torch.clip(event, self._low + self._clip, self._high - self._clip) + event = event - event.detach() + clipped.detach() + if self._mult: + event *= self._mult + return event + + +class TanhBijector(torchd.Transform): + + def __init__(self, validate_args=False, name='tanh'): + super().__init__() + + def _forward(self, x): + return torch.tanh(x) + + def _inverse(self, y): + y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y) + y = torch.atanh(y) + return y + + def _forward_log_det_jacobian(self, x): + log2 = torch.math.log(2.0) + return 2.0 * (log2 - x - torch.softplus(-2.0 * x)) + + +def static_scan(fn, inputs, start): + last = start # {logit:[batch_size, self._stoch, self._discrete], stoch:[batch_size, self._stoch, self._discrete], deter:[batch_size, self._deter]} + indices = range(inputs[0].shape[0]) + flag = True + for index in indices: + inp = lambda x: (_input[x] for _input in inputs) # inputs:(action:(time, batch, 6), embed:(time, batch, 4096)) + last = fn(last, *inp(index)) # post, prior + if flag: + if type(last) == type({}): + outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()} + else: + outputs = [] + for _last in last: + if type(_last) == type({}): + outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()}) + else: + outputs.append(_last.clone().unsqueeze(0)) + flag = False + else: + if type(last) == type({}): + for key in last.keys(): + outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0) + else: + for j in range(len(outputs)): + if type(last[j]) == type({}): + for key in last[j].keys(): + outputs[j][key] = torch.cat([outputs[j][key], last[j][key].unsqueeze(0)], dim=0) + else: + outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0) + if type(last) == type({}): + outputs = [outputs] + return outputs + + +def weight_init(m): + if isinstance(m, nn.Linear): + in_num = m.in_features + out_num = m.out_features + denoms = (in_num + out_num) / 2.0 + scale = 1.0 / denoms + std = np.sqrt(scale) / 0.87962566103423978 + nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + space = m.kernel_size[0] * m.kernel_size[1] + in_num = space * m.in_channels + out_num = space * m.out_channels + denoms = (in_num + out_num) / 2.0 + scale = 1.0 / denoms + std = np.sqrt(scale) / 0.87962566103423978 + nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.LayerNorm): + m.weight.data.fill_(1.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + + +def uniform_weight_init(given_scale): + + def f(m): + if isinstance(m, nn.Linear): + in_num = m.in_features + out_num = m.out_features + denoms = (in_num + out_num) / 2.0 + scale = given_scale / denoms + limit = np.sqrt(3 * scale) + nn.init.uniform_(m.weight.data, a=-limit, b=limit) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.LayerNorm): + m.weight.data.fill_(1.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + + return f diff --git a/ding/torch_utils/network/tests/test_dreamer.py b/ding/torch_utils/network/tests/test_dreamer.py new file mode 100644 index 
0000000000..77ca133cc0 --- /dev/null +++ b/ding/torch_utils/network/tests/test_dreamer.py @@ -0,0 +1,73 @@ +import pytest +from easydict import EasyDict +import torch +from torch import distributions as torchd +from itertools import product +from ding.torch_utils.network.dreamer import DenseHead, SampleDist, OneHotDist, TwoHotDistSymlog, SymlogDist, ContDist, Bernoulli, UnnormalizedHuber, weight_init, uniform_weight_init + + +# arguments +shape = [255, (255, ), ()] +# to do +# dist = ['normal', 'huber', 'binary', 'twohot_symlog'] +dist = ['twohot_symlog'] +args = list(product(*[shape, dist])) + + +@pytest.mark.unittest +@pytest.mark.parametrize('shape, dist', args) +def test_DenseHead(shape, dist): + in_dim, layer_num, units, B, time = 1536, 2, 512, 16, 64 + head = DenseHead(in_dim, shape, layer_num, units, dist=dist) + x = torch.randn(B, time, in_dim) + a = torch.randn(B, time, 1) + y = head(x) + assert y.mode().shape == (B, time, 1) + assert y.log_prob(a).shape == (B, time) + + +B, time = 16, 64 +mean = torch.randn(B, time, 255) +std = 1.0 +a = torch.randn(B, time, 1) # or torch.randn(B, time, 255) +sample_shape = torch.Size([]) + + +@pytest.mark.unittest +def test_ContDist(): + dist_origin = torchd.normal.Normal(mean, std) + dist = torchd.independent.Independent(dist_origin, 1) + dist_new = ContDist(dist) + assert dist_new.mode().shape == (B, time, 255) + assert dist_new.log_prob(a).shape == (B, time) + assert dist_origin.log_prob(a).shape == (B, time, 255) + assert dist_new.sample().shape == (B, time, 255) + + +@pytest.mark.unittest +def test_UnnormalizedHuber(): + dist_origin = UnnormalizedHuber(mean, std) + dist = torchd.independent.Independent(dist_origin, 1) + dist_new = ContDist(dist) + assert dist_new.mode().shape == (B, time, 255) + assert dist_new.log_prob(a).shape == (B, time) + assert dist_origin.log_prob(a).shape == (B, time, 255) + assert dist_new.sample().shape == (B, time, 255) + + +@pytest.mark.unittest +def test_Bernoulli(): + dist_origin = torchd.bernoulli.Bernoulli(logits=mean) + dist = torchd.independent.Independent(dist_origin, 1) + dist_new = Bernoulli(dist) + assert dist_new.mode().shape == (B, time, 255) + assert dist_new.log_prob(a).shape == (B, time, 255) + # to do + # assert dist_new.sample().shape == (B, time, 255) + + +@pytest.mark.unittest +def test_TwoHotDistSymlog(): + dist = TwoHotDistSymlog(logits=mean) + assert dist.mode().shape == (B, time, 1) + assert dist.log_prob(a).shape == (B, time) diff --git a/ding/worker/collector/interaction_serial_evaluator.py b/ding/worker/collector/interaction_serial_evaluator.py index 3c5857c869..235802c894 100644 --- a/ding/worker/collector/interaction_serial_evaluator.py +++ b/ding/worker/collector/interaction_serial_evaluator.py @@ -186,6 +186,7 @@ def eval( envstep: int = -1, n_episode: Optional[int] = None, force_render: bool = False, + policy_kwargs: Optional[dict] = None, ) -> Tuple[bool, dict]: ''' Overview: @@ -229,7 +230,7 @@ def eval( if render: eval_monitor.update_video(self._env.ready_imgs) - policy_output = self._policy.forward(obs) + policy_output = self._policy.forward(obs, **policy_kwargs) actions = {i: a['action'] for i, a in policy_output.items()} actions = to_ndarray(actions) timesteps = self._env.step(actions) diff --git a/ding/world_model/dreamer.py b/ding/world_model/dreamer.py index 031e2f390c..c9ef048630 100644 --- a/ding/world_model/dreamer.py +++ b/ding/world_model/dreamer.py @@ -7,8 +7,9 @@ from ding.utils.data import default_collate from ding.model import ConvEncoder from 
ding.world_model.base_world_model import WorldModel -from ding.world_model.model.networks import RSSM, ConvDecoder, DenseHead -from ding.torch_utils import fold_batch, unfold_batch, unsqueeze_repeat +from ding.world_model.model.networks import RSSM, ConvDecoder +from ding.torch_utils import to_device +from ding.torch_utils.network.dreamer import DenseHead @WORLD_MODEL_REGISTRY.register('dreamer') @@ -53,9 +54,8 @@ class DREAMERWorldModel(WorldModel, nn.Module): dyn_std_act='sigmoid2', dyn_temp_post=True, dyn_min_std=0.1, - dyn_cell=True, + dyn_cell='gru_layer_norm', unimix_ratio=0.01, - initial='learned', device='cpu', ), ) @@ -65,18 +65,21 @@ def __init__(self, cfg, env, tb_logger): nn.Module.__init__(self) self._cfg = cfg.model + #self._cfg.act = getattr(torch.nn, self._cfg.act), + #self._cfg.norm = getattr(torch.nn, self._cfg.norm), + self._cfg.act = nn.modules.activation.SiLU # nn.SiLU + self._cfg.norm = nn.modules.normalization.LayerNorm # nn.LayerNorm self.state_size = self._cfg.state_size self.action_size = self._cfg.action_size self.reward_size = self._cfg.reward_size self.hidden_size = self._cfg.hidden_size self.batch_size = self._cfg.batch_size - self.max_epochs_since_update = cfg.max_epochs_since_update if self._cuda: self.cuda() self.encoder = ConvEncoder( - *self.state_size, + self.state_size, hidden_size_list=[32, 64, 128, 256, 128], # to last layer 128? activation=torch.nn.SiLU(), kernel_size=self._cfg.encoder_kernels, @@ -103,7 +106,6 @@ def __init__(self, cfg, env, tb_logger): self._cfg.dyn_min_std, self._cfg.dyn_cell, self._cfg.unimix_ratio, - self._cfg.initial, self._cfg.action_size, self.embed_size, self._cfg.device, @@ -126,8 +128,8 @@ def __init__(self, cfg, env, tb_logger): (255, ), self._cfg.reward_layers, self._cfg.units, - self._cfg.act, - self._cfg.norm, + 'SiLU', # self._cfg.act + 'LN', # self._cfg.norm dist=self._cfg.reward_head, outscale=0.0, ) @@ -137,8 +139,8 @@ def __init__(self, cfg, env, tb_logger): [], self._cfg.discount_layers, self._cfg.units, - self._cfg.act, - self._cfg.norm, + 'SiLU', # self._cfg.act + 'LN', # self._cfg.norm dist="binary", ) # to do @@ -157,15 +159,14 @@ def train(self, env_buffer, envstep, train_iter, batch_size): data = default_collate(data) data['done'] = data['done'].float() data['weight'] = data.get('weight', None) - data = {k: torch.Tensor(v).to(self._cfg.device) for k, v in data.items()} - #image = data['obs'] - action = data['action'] - reward = data['reward'] - next_obs = data['next_obs'] - if len(reward.shape) == 2: - reward = reward.unsqueeze(-1) - if len(action.shape) == 2: - action = action.unsqueeze(-1) + data['obs'] = data['obs'] / 255.0 - 0.5 + next_obs = data['next_obs'] / 255.0 - 0.5 + #data = {k: v.to(self._cfg.device) for k, v in data.items()} + data = to_device(data, self._cfg.device) + if len(data['reward'].shape) == 2: + data['reward'] = data['reward'].unsqueeze(-1) + if len(data['action'].shape) == 2: + data['action'] = data['action'].unsqueeze(-1) embed = self.encoder(data['obs']) post, prior = self.dynamics.observe(embed, data["action"]) diff --git a/ding/world_model/model/__init__.py b/ding/world_model/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ding/world_model/model/networks.py b/ding/world_model/model/networks.py index 066020dc02..5c29125f44 100644 --- a/ding/world_model/model/networks.py +++ b/ding/world_model/model/networks.py @@ -6,9 +6,7 @@ import torch.nn.functional as F from torch import distributions as torchd -from ding.world_model.utils import weight_init, 
uniform_weight_init, OneHotDist, ContDist, SymlogDist, SampleDist, \ - Bernoulli, TwoHotDistSymlog, UnnormalizedHuber, SafeTruncatedNormal, TanhBijector, static_scan -from ding.torch_utils import MLP, DreamerLayerNorm +from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, OneHotDist, ContDist, SymlogDist, static_scan, DreamerLayerNorm class RSSM(nn.Module): @@ -368,167 +366,6 @@ def __call__(self, features, dtype=None): return SymlogDist(mean) -class DenseHead(nn.Module): - - def __init__( - self, - inp_dim, # config.dyn_stoch * config.dyn_discrete + config.dyn_deter - shape, # (255,) - layer_num, - units, # 512 - act='SiLU', - norm='LN', - dist='normal', - std=1.0, - outscale=1.0, - ): - super(DenseHead, self).__init__() - self._shape = (shape, ) if isinstance(shape, int) else shape - if len(self._shape) == 0: - self._shape = (1, ) - self._layer_num = layer_num - self._units = units - self._act = getattr(torch.nn, act)() - self._norm = norm - self._dist = dist - self._std = std - - self.mlp = MLP( - inp_dim, - self._units, - self._units, - self._layer_num, - layer_fn=nn.Linear, - activation=self._act, - norm_type=self._norm - ) - self.mlp.apply(weight_init) - - self.mean_layer = nn.Linear(self._units, np.prod(self._shape)) - self.mean_layer.apply(uniform_weight_init(outscale)) - - if self._std == "learned": - self.std_layer = nn.Linear(self._units, np.prod(self._shape)) - self.std_layer.apply(uniform_weight_init(outscale)) - - def forward(self, features, dtype=None): - x = features - out = self.mlp(x) # (batch, time, _units=512) - mean = self.mean_layer(out) # (batch, time, 255) - if self._std == "learned": - std = self.std_layer(out) - else: - std = self._std - if self._dist == "normal": - return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape))) - if self._dist == "huber": - return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape))) - if self._dist == "binary": - return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape))) - if self._dist == "twohot_symlog": - return TwoHotDistSymlog(logits=mean) - raise NotImplementedError(self._dist) - - -class ActionHead(nn.Module): - - def __init__( - self, - inp_dim, - size, - layers, - units, - act=nn.ELU, - norm=nn.LayerNorm, - dist="trunc_normal", - init_std=0.0, - min_std=0.1, - max_std=1.0, - temp=0.1, - outscale=1.0, - unimix_ratio=0.01, - ): - super(ActionHead, self).__init__() - self._size = size - self._layers = layers - self._units = units - self._dist = dist - self._act = act - self._norm = norm - self._min_std = min_std - self._max_std = max_std - self._init_std = init_std - self._unimix_ratio = unimix_ratio - self._temp = temp() if callable(temp) else temp - - pre_layers = [] - for index in range(self._layers): - pre_layers.append(nn.Linear(inp_dim, self._units, bias=False)) - pre_layers.append(norm(self._units, eps=1e-03)) - pre_layers.append(act()) - if index == 0: - inp_dim = self._units - self._pre_layers = nn.Sequential(*pre_layers) - self._pre_layers.apply(weight_init) - - if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]: - self._dist_layer = nn.Linear(self._units, 2 * self._size) - self._dist_layer.apply(uniform_weight_init(outscale)) - - elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]: - self._dist_layer = nn.Linear(self._units, self._size) - self._dist_layer.apply(uniform_weight_init(outscale)) - - def __call__(self, features, 
dtype=None): - x = features - x = self._pre_layers(x) - if self._dist == "tanh_normal": - x = self._dist_layer(x) - mean, std = torch.split(x, 2, -1) - mean = torch.tanh(mean) - std = F.softplus(std + self._init_std) + self._min_std - dist = torchd.normal.Normal(mean, std) - dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector()) - dist = torchd.independent.Independent(dist, 1) - dist = SampleDist(dist) - elif self._dist == "tanh_normal_5": - x = self._dist_layer(x) - mean, std = torch.split(x, 2, -1) - mean = 5 * torch.tanh(mean / 5) - std = F.softplus(std + 5) + 5 - dist = torchd.normal.Normal(mean, std) - dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector()) - dist = torchd.independent.Independent(dist, 1) - dist = SampleDist(dist) - elif self._dist == "normal": - x = self._dist_layer(x) - mean, std = torch.split(x, [self._size] * 2, -1) - std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std - dist = torchd.normal.Normal(torch.tanh(mean), std) - dist = ContDist(torchd.independent.Independent(dist, 1)) - elif self._dist == "normal_1": - x = self._dist_layer(x) - dist = torchd.normal.Normal(mean, 1) - dist = ContDist(torchd.independent.Independent(dist, 1)) - elif self._dist == "trunc_normal": - x = self._dist_layer(x) - mean, std = torch.split(x, [self._size] * 2, -1) - mean = torch.tanh(mean) - std = 2 * torch.sigmoid(std / 2) + self._min_std - dist = SafeTruncatedNormal(mean, std, -1, 1) - dist = ContDist(torchd.independent.Independent(dist, 1)) - elif self._dist == "onehot": - x = self._dist_layer(x) - dist = OneHotDist(x, unimix_ratio=self._unimix_ratio) - elif self._dist == "onehot_gumble": - x = self._dist_layer(x) - temp = self._temp - dist = ContDist(torchd.gumbel.Gumbel(x, 1 / temp)) - else: - raise NotImplementedError(self._dist) - return dist - - class GRUCell(nn.Module): def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1): diff --git a/ding/world_model/model/tests/test_networks.py b/ding/world_model/model/tests/test_networks.py index 234f42fffc..c23c94cd3d 100644 --- a/ding/world_model/model/tests/test_networks.py +++ b/ding/world_model/model/tests/test_networks.py @@ -1,23 +1,3 @@ import pytest import torch from itertools import product -from ding.world_model.model.networks import DenseHead - -# arguments -shape = [255, (255, ), ()] -# to do -# dist = ['normal', 'huber', 'binary', 'twohot_symlog'] -dist = ['twohot_symlog'] -args = list(product(*[shape, dist])) - - -@pytest.mark.unittest -@pytest.mark.parametrize('shape, dist', args) -def test_DenseHead(shape, dist): - in_dim, layer_num, units, B, time = 1536, 2, 512, 16, 64 - head = DenseHead(in_dim, shape, layer_num, units, dist=dist) - x = torch.randn(B, time, in_dim) - a = torch.randn(B, time, 1) - y = head(x) - assert y.mode().shape == (B, time, 1) - assert y.log_prob(a).shape == (B, time) diff --git a/ding/world_model/tests/test_world_model_utils.py b/ding/world_model/tests/test_world_model_utils.py index d4477cd178..26ba5e7f5e 100644 --- a/ding/world_model/tests/test_world_model_utils.py +++ b/ding/world_model/tests/test_world_model_utils.py @@ -1,9 +1,6 @@ import pytest from easydict import EasyDict -import torch -from torch import distributions as torchd -from itertools import product -from ding.world_model.utils import get_rollout_length_scheduler, SampleDist, OneHotDist, TwoHotDistSymlog, SymlogDist, ContDist, Bernoulli, UnnormalizedHuber, weight_init, uniform_weight_init +from ding.world_model.utils 
import get_rollout_length_scheduler @pytest.mark.unittest @@ -20,50 +17,3 @@ def test_get_rollout_length_scheduler(): assert scheduler(19999) == 1 assert scheduler(150000) == 25 assert scheduler(1500000) == 25 - - -B, time = 16, 64 -mean = torch.randn(B, time, 255) -std = 1.0 -a = torch.randn(B, time, 1) # or torch.randn(B, time, 255) -sample_shape = torch.Size([]) - - -@pytest.mark.unittest -def test_ContDist(): - dist_origin = torchd.normal.Normal(mean, std) - dist = torchd.independent.Independent(dist_origin, 1) - dist_new = ContDist(dist) - assert dist_new.mode().shape == (B, time, 255) - assert dist_new.log_prob(a).shape == (B, time) - assert dist_origin.log_prob(a).shape == (B, time, 255) - assert dist_new.sample().shape == (B, time, 255) - - -@pytest.mark.unittest -def test_UnnormalizedHuber(): - dist_origin = UnnormalizedHuber(mean, std) - dist = torchd.independent.Independent(dist_origin, 1) - dist_new = ContDist(dist) - assert dist_new.mode().shape == (B, time, 255) - assert dist_new.log_prob(a).shape == (B, time) - assert dist_origin.log_prob(a).shape == (B, time, 255) - assert dist_new.sample().shape == (B, time, 255) - - -@pytest.mark.unittest -def test_Bernoulli(): - dist_origin = torchd.bernoulli.Bernoulli(logits=mean) - dist = torchd.independent.Independent(dist_origin, 1) - dist_new = Bernoulli(dist) - assert dist_new.mode().shape == (B, time, 255) - assert dist_new.log_prob(a).shape == (B, time, 255) - # to do - # assert dist_new.sample().shape == (B, time, 255) - - -@pytest.mark.unittest -def test_TwoHotDistSymlog(): - dist = TwoHotDistSymlog(logits=mean) - assert dist.mode().shape == (B, time, 1) - assert dist.log_prob(a).shape == (B, time) diff --git a/ding/world_model/utils.py b/ding/world_model/utils.py index 06d660af25..15172699f9 100644 --- a/ding/world_model/utils.py +++ b/ding/world_model/utils.py @@ -1,10 +1,5 @@ from easydict import EasyDict from typing import Callable -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F -from torch import distributions as torchd def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]: @@ -28,326 +23,3 @@ def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]: return lambda x: cfg.rollout_length else: raise KeyError("not implemented key: {}".format(cfg.type)) - - -def symlog(x): - return torch.sign(x) * torch.log(torch.abs(x) + 1.0) - - -def symexp(x): - return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0) - - -class SampleDist: - - def __init__(self, dist, samples=100): - self._dist = dist - self._samples = samples - - @property - def name(self): - return 'SampleDist' - - def __getattr__(self, name): - return getattr(self._dist, name) - - def mean(self): - samples = self._dist.sample(self._samples) - return torch.mean(samples, 0) - - def mode(self): - sample = self._dist.sample(self._samples) - logprob = self._dist.log_prob(sample) - return sample[torch.argmax(logprob)][0] - - def entropy(self): - sample = self._dist.sample(self._samples) - logprob = self.log_prob(sample) - return -torch.mean(logprob, 0) - - -class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): - - def __init__(self, logits=None, probs=None, unimix_ratio=0.0): - if logits is not None and unimix_ratio > 0.0: - probs = F.softmax(logits, dim=-1) - probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1] - logits = torch.log(probs) - super().__init__(logits=logits, probs=None) - else: - super().__init__(logits=logits, probs=probs) - - def mode(self): - _mode = 
F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1]) - return _mode.detach() + super().logits - super().logits.detach() - - def sample(self, sample_shape=(), seed=None): - if seed is not None: - raise ValueError('need to check') - sample = super().sample(sample_shape) - probs = super().probs - while len(probs.shape) < len(sample.shape): - probs = probs[None] - sample += probs - probs.detach() - return sample - - -class TwoHotDistSymlog(): - - def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): - self.logits = logits - self.probs = torch.softmax(logits, -1) - self.buckets = torch.linspace(low, high, steps=255).to(device) - self.width = (self.buckets[-1] - self.buckets[0]) / 255 - - def mean(self): - print("mean called") - _mean = self.probs * self.buckets - return symexp(torch.sum(_mean, dim=-1, keepdim=True)) - - def mode(self): - _mode = self.probs * self.buckets - return symexp(torch.sum(_mode, dim=-1, keepdim=True)) - - # Inside OneHotCategorical, log_prob is calculated using only max element in targets - def log_prob(self, x): - x = symlog(x) - # x(time, batch, 1) - below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1 - above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1) - below = torch.clip(below, 0, len(self.buckets) - 1) - above = torch.clip(above, 0, len(self.buckets) - 1) - equal = (below == above) - - dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x)) - dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x)) - total = dist_to_below + dist_to_above - weight_below = dist_to_above / total - weight_above = dist_to_below / total - target = ( - F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] + - F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None] - ) - log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True) - target = target.squeeze(-2) - - return (target * log_pred).sum(-1) - - def log_prob_target(self, target): - log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True) - return (target * log_pred).sum(-1) - - -class SymlogDist(): - - def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): - self._mode = mode - self._dist = dist - self._agg = agg - self._tol = tol - self._dim_to_reduce = dim_to_reduce - - def mode(self): - return symexp(self._mode) - - def mean(self): - return symexp(self._mode) - - def log_prob(self, value): - assert self._mode.shape == value.shape - if self._dist == 'mse': - distance = (self._mode - symlog(value)) ** 2.0 - distance = torch.where(distance < self._tol, 0, distance) - elif self._dist == 'abs': - distance = torch.abs(self._mode - symlog(value)) - distance = torch.where(distance < self._tol, 0, distance) - else: - raise NotImplementedError(self._dist) - if self._agg == 'mean': - loss = distance.mean(self._dim_to_reduce) - elif self._agg == 'sum': - loss = distance.sum(self._dim_to_reduce) - else: - raise NotImplementedError(self._agg) - return -loss - - -class ContDist: - - def __init__(self, dist=None): - super().__init__() - self._dist = dist - self.mean = dist.mean - - def __getattr__(self, name): - return getattr(self._dist, name) - - def entropy(self): - return self._dist.entropy() - - def mode(self): - return self._dist.mean - - def sample(self, sample_shape=()): - return self._dist.rsample(sample_shape) - - def log_prob(self, x): - return self._dist.log_prob(x) - - -class Bernoulli: - - def 
__init__(self, dist=None): - super().__init__() - self._dist = dist - self.mean = dist.mean - - def __getattr__(self, name): - return getattr(self._dist, name) - - def entropy(self): - return self._dist.entropy() - - def mode(self): - _mode = torch.round(self._dist.mean) - return _mode.detach() + self._dist.mean - self._dist.mean.detach() - - def sample(self, sample_shape=()): - return self._dist.rsample(sample_shape) - - def log_prob(self, x): - _logits = self._dist.base_dist.logits - log_probs0 = -F.softplus(_logits) - log_probs1 = -F.softplus(-_logits) - - return log_probs0 * (1 - x) + log_probs1 * x - - -class UnnormalizedHuber(torchd.normal.Normal): - - def __init__(self, loc, scale, threshold=1, **kwargs): - super().__init__(loc, scale, **kwargs) - self._threshold = threshold - - def log_prob(self, event): - return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold) - - def mode(self): - return self.mean - - -class SafeTruncatedNormal(torchd.normal.Normal): - - def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): - super().__init__(loc, scale) - self._low = low - self._high = high - self._clip = clip - self._mult = mult - - def sample(self, sample_shape): - event = super().sample(sample_shape) - if self._clip: - clipped = torch.clip(event, self._low + self._clip, self._high - self._clip) - event = event - event.detach() + clipped.detach() - if self._mult: - event *= self._mult - return event - - -class TanhBijector(torchd.Transform): - - def __init__(self, validate_args=False, name='tanh'): - super().__init__() - - def _forward(self, x): - return torch.tanh(x) - - def _inverse(self, y): - y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y) - y = torch.atanh(y) - return y - - def _forward_log_det_jacobian(self, x): - log2 = torch.math.log(2.0) - return 2.0 * (log2 - x - torch.softplus(-2.0 * x)) - - -def static_scan(fn, inputs, start): - last = start # {logit:[batch_size, self._stoch, self._discrete], stoch:[batch_size, self._stoch, self._discrete], deter:[batch_size, self._deter]} - indices = range(inputs[0].shape[0]) - flag = True - for index in indices: - inp = lambda x: (_input[x] for _input in inputs) # inputs:(action:(time, batch, 6), embed:(time, batch, 4096)) - last = fn(last, *inp(index)) # post, prior - if flag: - if type(last) == type({}): - outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()} - else: - outputs = [] - for _last in last: - if type(_last) == type({}): - outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()}) - else: - outputs.append(_last.clone().unsqueeze(0)) - flag = False - else: - if type(last) == type({}): - for key in last.keys(): - outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0) - else: - for j in range(len(outputs)): - if type(last[j]) == type({}): - for key in last[j].keys(): - outputs[j][key] = torch.cat([outputs[j][key], last[j][key].unsqueeze(0)], dim=0) - else: - outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0) - if type(last) == type({}): - outputs = [outputs] - return outputs - - -def weight_init(m): - if isinstance(m, nn.Linear): - in_num = m.in_features - out_num = m.out_features - denoms = (in_num + out_num) / 2.0 - scale = 1.0 / denoms - std = np.sqrt(scale) / 0.87962566103423978 - nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) - if hasattr(m.bias, 'data'): - m.bias.data.fill_(0.0) - elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - space = 
m.kernel_size[0] * m.kernel_size[1] - in_num = space * m.in_channels - out_num = space * m.out_channels - denoms = (in_num + out_num) / 2.0 - scale = 1.0 / denoms - std = np.sqrt(scale) / 0.87962566103423978 - nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0) - if hasattr(m.bias, 'data'): - m.bias.data.fill_(0.0) - elif isinstance(m, nn.LayerNorm): - m.weight.data.fill_(1.0) - if hasattr(m.bias, 'data'): - m.bias.data.fill_(0.0) - - -def uniform_weight_init(given_scale): - - def f(m): - if isinstance(m, nn.Linear): - in_num = m.in_features - out_num = m.out_features - denoms = (in_num + out_num) / 2.0 - scale = given_scale / denoms - limit = np.sqrt(3 * scale) - nn.init.uniform_(m.weight.data, a=-limit, b=limit) - if hasattr(m.bias, 'data'): - m.bias.data.fill_(0.0) - elif isinstance(m, nn.LayerNorm): - m.weight.data.fill_(1.0) - if hasattr(m.bias, 'data'): - m.bias.data.fill_(0.0) - - return f From b4abf72ae6fdbbc65021f502211d1091db90e65d Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Wed, 21 Jun 2023 12:29:04 +0800 Subject: [PATCH 08/14] first version pipeline required polish --- ding/entry/serial_entry_mbrl.py | 1 + ding/envs/env_wrappers/env_wrappers.py | 13 +++++- ding/policy/mbpolicy/dreamer.py | 13 +++--- ding/policy/mbpolicy/utils.py | 10 ++-- .../collector/interaction_serial_evaluator.py | 3 ++ ding/world_model/dreamer.py | 46 +++++++++++++------ ding/world_model/model/networks.py | 4 +- ding/world_model/tests/test_dreamer.py | 33 +++++++++++++ .../dmc2gym/config/dmc2gym_dreamer_config.py | 11 +++-- .../config/dmc2gym_sac_pixel_config.py | 5 ++ dizoo/dmc2gym/envs/dmc2gym_env.py | 3 +- 11 files changed, 108 insertions(+), 34 deletions(-) create mode 100644 ding/world_model/tests/test_dreamer.py diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py index 19d24735b3..af6fef2611 100644 --- a/ding/entry/serial_entry_mbrl.py +++ b/ding/entry/serial_entry_mbrl.py @@ -276,6 +276,7 @@ def serial_pipeline_dreamer( # prefill environment buffer if cfg.policy.get('random_collect_size', 0) > 0: + cfg.policy.random_collect_size = cfg.policy.random_collect_size // cfg.policy.collect.unroll_len random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer) while True: diff --git a/ding/envs/env_wrappers/env_wrappers.py b/ding/envs/env_wrappers/env_wrappers.py index 1a75b88179..f158ff208d 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ b/ding/envs/env_wrappers/env_wrappers.py @@ -182,8 +182,17 @@ def observation(self, frame): import sys logging.warning("Please install opencv-python first.") sys.exit(1) - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) + # to do + # channel_first + if frame.shape[0]<10: + frame = frame.transpose(1, 2, 0) + frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) + frame = frame.transpose(2, 0, 1) + else: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) + + return frame @ENV_WRAPPER_REGISTRY.register('scaled_float_frame') diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index 54c1c62eed..2fdf2183e6 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -86,7 +86,7 @@ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: # log dict log_vars = {} self._learn_model.train() - 
world_model.requires_grad_(requires_grad=False) + self._actor.requires_grad_(requires_grad=True) # start is dict of {stoch, deter, logit} if self._cuda: @@ -102,7 +102,7 @@ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: # this target is not scaled # slow is flag to indicate whether slow_target is used for lambda-return target, weights, base = compute_target( - self._cfg.learn, world_model, self._critic, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent + self._cfg.learn, world_model, self._critic, imag_feat, imag_state, reward, actor_ent, state_ent ) actor_loss, mets = compute_actor_loss( self._cfg.learn, @@ -124,7 +124,7 @@ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: self._critic.requires_grad_(requires_grad=True) value = self._critic(value_input[:-1].detach()) # to do - target = torch.stack(target, dim=1) + # target = torch.stack(target, dim=1) # (time, batch, 1), (time, batch, 1) -> (time, batch) value_loss = -value.log_prob(target.detach()) slow_target = self._slow_value(value_input[:-1].detach()) @@ -140,7 +140,7 @@ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: log_vars.update(tensorstats(target, "target")) log_vars.update(tensorstats(reward, "imag_reward")) log_vars.update(tensorstats(imag_action, "imag_action")) - log_vars["actor_ent"] = torch.mean(actor_ent).detach().cpu().numpy() + log_vars["actor_ent"] = torch.mean(actor_ent).detach().cpu().numpy().item() # ==================== # actor-critic update # ==================== @@ -248,9 +248,10 @@ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple """ transition = { 'obs': obs, - 'next_obs': timestep.obs, + #'next_obs': timestep.obs, 'action': model_output['action'], - 'logprob': model_output['logprob'], + # TODO(zp) random_collect just have action + #'logprob': model_output['logprob'], 'reward': timestep.reward, 'done': timestep.done, } diff --git a/ding/policy/mbpolicy/utils.py b/ding/policy/mbpolicy/utils.py index 2ddd1ab0b1..64385b27d7 100644 --- a/ding/policy/mbpolicy/utils.py +++ b/ding/policy/mbpolicy/utils.py @@ -101,18 +101,18 @@ def compute_actor_loss( adv = normed_target - normed_base metrics.update(tensorstats(normed_target, "normed_target")) values = reward_ema.values - metrics["EMA_005"] = values[0].detach().cpu().numpy() - metrics["EMA_095"] = values[1].detach().cpu().numpy() + metrics["EMA_005"] = values[0].detach().cpu().numpy().item() + metrics["EMA_095"] = values[1].detach().cpu().numpy().item() actor_target = adv if cfg.actor_entropy > 0: actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None] actor_target += actor_entropy - metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy() + metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy().item() if cfg.actor_state_entropy > 0: state_entropy = cfg.actor_state_entropy * state_ent[:-1] actor_target += state_entropy - metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy() + metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy().item() actor_loss = -torch.mean(weights[:-1] * actor_target) return actor_loss, metrics @@ -143,5 +143,5 @@ def tensorstats(tensor, prefix=None): 'max': torch.max(tensor).detach().cpu().numpy(), } if prefix: - metrics = {f'{prefix}_{k}': v for k, v in metrics.items()} + metrics = {f'{prefix}_{k}': v.item() for k, v in metrics.items()} return metrics diff --git 
a/ding/worker/collector/interaction_serial_evaluator.py b/ding/worker/collector/interaction_serial_evaluator.py index 235802c894..24da981c1f 100644 --- a/ding/worker/collector/interaction_serial_evaluator.py +++ b/ding/worker/collector/interaction_serial_evaluator.py @@ -206,6 +206,9 @@ def eval( dist.reduce(envstep_tensor, dst=0) envstep = envstep_tensor.item() + if policy_kwargs is None: + policy_kwargs = {} + # evaluator only work on rank0 stop_flag, return_info = False, [] if get_rank() == 0: diff --git a/ding/world_model/dreamer.py b/ding/world_model/dreamer.py index c9ef048630..6743c3281a 100644 --- a/ding/world_model/dreamer.py +++ b/ding/world_model/dreamer.py @@ -64,6 +64,7 @@ def __init__(self, cfg, env, tb_logger): WorldModel.__init__(self, cfg, env, tb_logger) nn.Module.__init__(self) + self.pretrain_flag = True self._cfg = cfg.model #self._cfg.act = getattr(torch.nn, self._cfg.act), #self._cfg.norm = getattr(torch.nn, self._cfg.norm), @@ -80,7 +81,7 @@ def __init__(self, cfg, env, tb_logger): self.encoder = ConvEncoder( self.state_size, - hidden_size_list=[32, 64, 128, 256, 128], # to last layer 128? + hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128? activation=torch.nn.SiLU(), kernel_size=self._cfg.encoder_kernels, layer_norm=True @@ -153,22 +154,39 @@ def step(self, obs, act): def eval(self, env_buffer, envstep, train_iter): pass + def should_pretrain(self): + if self.pretrain_flag: + self.pretrain_flag = False + return True + return False + def train(self, env_buffer, envstep, train_iter, batch_size): self.last_train_step = envstep + # [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] data = env_buffer.sample(batch_size, train_iter) - data = default_collate(data) - data['done'] = data['done'].float() + data = default_collate(data) # -> {some_key: T lists}, each list is [B, some_dim] + data = {k: torch.stack(data[k], dim=1) for k in data} # -> {dict_key: Tensor([B, T, any_dims])} + + data['discount'] = 1.0 - data['done'].float() data['weight'] = data.get('weight', None) - data['obs'] = data['obs'] / 255.0 - 0.5 - next_obs = data['next_obs'] / 255.0 - 0.5 + data['image'] = data['obs'] + #data['obs'] = data['obs'] / 255.0 - 0.5 + #next_obs = data['next_obs'] / 255.0 - 0.5 #data = {k: v.to(self._cfg.device) for k, v in data.items()} data = to_device(data, self._cfg.device) if len(data['reward'].shape) == 2: data['reward'] = data['reward'].unsqueeze(-1) if len(data['action'].shape) == 2: data['action'] = data['action'].unsqueeze(-1) + if len(data['discount'].shape) == 2: + data['discount'] = data['discount'].unsqueeze(-1) + + self.requires_grad_(requires_grad=True) - embed = self.encoder(data['obs']) + image = data['obs'].reshape([-1] + list(data['obs'].shape[-3:])) + embed = self.encoder(image) + embed = embed.reshape(list(data['obs'].shape[:-3]) + [embed.shape[-1]]) + post, prior = self.dynamics.observe(embed, data["action"]) kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss( post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale @@ -182,7 +200,7 @@ def train(self, env_buffer, envstep, train_iter, batch_size): pred = head(feat) like = pred.log_prob(data[name]) likes[name] = like - losses[name] = -torch.mean(like) * self._scales.get(name, 1.0) + losses[name] = -torch.mean(like) model_loss = sum(losses.values()) + kl_loss # ==================== @@ -191,22 +209,24 @@ def train(self, env_buffer, envstep, train_iter, batch_size): self.optimizer.zero_grad() model_loss.backward() self.optimizer.step() + + 
self.requires_grad_(requires_grad=False) # log if self.tb_logger is not None: for name, loss in losses.items(): - self.tb_logger.add_scalar(name + '_loss', loss.detach().cpu().numpy(), envstep) + self.tb_logger.add_scalar(name + '_loss', loss.detach().cpu().numpy().item(), envstep) self.tb_logger.add_scalar('kl_free', self._cfg.kl_free, envstep) self.tb_logger.add_scalar('kl_lscale', self._cfg.kl_lscale, envstep) self.tb_logger.add_scalar('kl_rscale', self._cfg.kl_rscale, envstep) - self.tb_logger.add_scalar('loss_lhs', loss_lhs.detach().cpu().numpy(), envstep) - self.tb_logger.add_scalar('loss_rhs', loss_rhs.detach().cpu().numpy(), envstep) - self.tb_logger.add_scalar('kl', torch.mean(kl_value).detach().cpu().numpy(), envstep) + self.tb_logger.add_scalar('loss_lhs', loss_lhs.detach().cpu().numpy().item(), envstep) + self.tb_logger.add_scalar('loss_rhs', loss_rhs.detach().cpu().numpy().item(), envstep) + self.tb_logger.add_scalar('kl', torch.mean(kl_value).detach().cpu().numpy().item(), envstep) prior_ent = torch.mean(self.dynamics.get_dist(prior).entropy()).detach().cpu().numpy() post_ent = torch.mean(self.dynamics.get_dist(post).entropy()).detach().cpu().numpy() - self.tb_logger.add_scalar('prior_ent', prior_ent, envstep) - self.tb_logger.add_scalar('post_ent', post_ent, envstep) + self.tb_logger.add_scalar('prior_ent', prior_ent.item(), envstep) + self.tb_logger.add_scalar('post_ent', post_ent.item(), envstep) context = dict( embed=embed, diff --git a/ding/world_model/model/networks.py b/ding/world_model/model/networks.py index 5c29125f44..c753b16351 100644 --- a/ding/world_model/model/networks.py +++ b/ding/world_model/model/networks.py @@ -361,8 +361,8 @@ def __call__(self, features, dtype=None): x = x.reshape([-1, 4, 4, self._embed_size // 16]) x = x.permute(0, 3, 1, 2) x = self.layers(x) - mean = x.reshape(features.shape[:-1] + self._shape) - mean = mean.permute(0, 1, 3, 4, 2) + mean = x.reshape(list(features.shape[:-1]) + self._shape) + #mean = mean.permute(0, 1, 3, 4, 2) return SymlogDist(mean) diff --git a/ding/world_model/tests/test_dreamer.py b/ding/world_model/tests/test_dreamer.py new file mode 100644 index 0000000000..4260be4801 --- /dev/null +++ b/ding/world_model/tests/test_dreamer.py @@ -0,0 +1,33 @@ +import pytest +import torch + +from itertools import product +from easydict import EasyDict +from ding.world_model.dreamer import DREAMERWorldModel +from ding.utils import deep_merge_dicts + +# arguments +state_size = [3,64,64] +action_size = [6, 1] +args = list(product(*[state_size, action_size])) + + +@pytest.mark.unittest +class TestDREAMER: + + def get_world_model(self, state_size, action_size): + cfg = DREAMERWorldModel.default_config() + cfg.model.max_epochs_since_update = 0 + cfg = deep_merge_dicts( + cfg, dict(cuda=False, model=dict(state_size=state_size, action_size=action_size, reward_size=1)) + ) + fake_env = EasyDict(termination_fn=lambda obs: torch.zeros_like(obs.sum(-1)).bool()) + return DREAMERWorldModel(cfg, fake_env, None) + + @pytest.mark.parametrize('state_size, action_size', args) + def test_train(self, state_size, action_size): + states = torch.rand(1280, *state_size) + actions = torch.rand(1280, action_size) + + model = self.get_world_model(state_size, action_size) + diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py index fb2c5d97a7..c1cffc8acd 100644 --- a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -14,12 +14,13 @@ warp_frame=True, 
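# Dreamer consumes single 64x64 pixel frames, so the env wrapper config below drops
# frame stacking (frame_stack=1) and resizes observations to 64x64 (resize=64).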
scale=True, clip_rewards=False, - frame_stack=3, + frame_stack=1, from_pixels=True, + resize=64, collector_env_num=1, evaluator_env_num=8, n_evaluator_episode=8, - stop_value=100000, + stop_value=50, # 800 ), policy=dict( cuda=cuda, @@ -33,14 +34,14 @@ learn=dict( lambda_=0.95, learning_rate=0.001, - batch_size=256, + batch_size=16, imag_sample=True, discount=0.997, reward_EMA=True, ), collect=dict( n_sample=1, - unroll_len=1, + unroll_len=64, ), command=dict(), eval=dict(evaluator=dict(eval_freq=10000, )), # w.r.t envstep @@ -62,7 +63,7 @@ reward_size=1, #hidden_size=200, #use_decay=True, - batch_size=256, + batch_size=16, #holdout_ratio=0.1, #max_epochs_since_update=5, #deterministic_rollout=True, diff --git a/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py b/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py index 76669b057e..c0155b1ebd 100644 --- a/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py @@ -72,3 +72,8 @@ ) dmc2gym_sac_create_config = EasyDict(dmc2gym_sac_create_config) create_config = dmc2gym_sac_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial -c ant_sac_config.py -s 0 --env-step 1e7` + from ding.entry import serial_pipeline + serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file diff --git a/dizoo/dmc2gym/envs/dmc2gym_env.py b/dizoo/dmc2gym/envs/dmc2gym_env.py index 9e97629897..4aad02eb43 100644 --- a/dizoo/dmc2gym/envs/dmc2gym_env.py +++ b/dizoo/dmc2gym/envs/dmc2gym_env.py @@ -120,6 +120,7 @@ def __init__(self, cfg: dict = {}) -> None: "height": 84, "width": 84, "channels_first": True, + "resize": 84, } self._cfg.update(cfg) @@ -154,7 +155,7 @@ def reset(self) -> np.ndarray: # optional env wrapper if self._cfg['warp_frame']: - self._env = WarpFrameWrapper(self._env) + self._env = WarpFrameWrapper(self._env, size=self._cfg['resize']) if self._cfg['scale']: self._env = ScaledFloatFrameWrapper(self._env) if self._cfg['clip_rewards']: From c1447890cfbe524322657692d8988a3af3d2cce7 Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Fri, 23 Jun 2023 15:28:07 +0800 Subject: [PATCH 09/14] fix multi-env collect latent-state and add action affine_transform in dmc2gym_env --- ding/entry/serial_entry_mbrl.py | 2 +- ding/entry/utils.py | 2 +- ding/policy/mbpolicy/dreamer.py | 31 +++++++++++++------ .../collector/sample_serial_collector.py | 16 ++++++++-- .../dmc2gym/config/dmc2gym_dreamer_config.py | 2 ++ dizoo/dmc2gym/envs/dmc2gym_env.py | 2 ++ 6 files changed, 41 insertions(+), 14 deletions(-) diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py index af6fef2611..2f2cc45856 100644 --- a/ding/entry/serial_entry_mbrl.py +++ b/ding/entry/serial_entry_mbrl.py @@ -293,7 +293,7 @@ def serial_pipeline_dreamer( if world_model.should_pretrain() else int(world_model.should_train(collector.envstep)) ) - for _ in range(steps): + for _ in range(1): batch_size = learner.policy.get_attribute('batch_size') post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size) diff --git a/ding/entry/utils.py b/ding/entry/utils.py index 8b8d1d6626..38c85889df 100644 --- a/ding/entry/utils.py +++ b/ding/entry/utils.py @@ -60,7 +60,7 @@ def random_collect( new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs) else: new_data = collector.collect( - n_sample=policy_cfg.random_collect_size, record_random_collect=False, policy_kwargs=collect_kwargs + 
n_sample=policy_cfg.random_collect_size, random_collect=True, record_random_collect=False, policy_kwargs=collect_kwargs ) # 'record_random_collect=False' means random collect without output log if postprocess_data_fn is not None: new_data = postprocess_data_fn(new_data) diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index 2fdf2183e6..e24667d6ac 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -196,7 +196,7 @@ def _init_collect(self) -> None: self._collect_model = model_wrap(self._model, wrapper_name='base') self._collect_model.reset() - def _forward_collect(self, data: dict, world_model, envstep, state=None) -> dict: + def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=None) -> dict: data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -206,16 +206,27 @@ def _forward_collect(self, data: dict, world_model, envstep, state=None) -> dict if state is None: batch_size = len(data_id) latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter} - action = torch.zeros((batch_size, self._config.num_actions)).to( - self._config.device + action = torch.zeros((batch_size, self._cfg.collect.action_size)).to( + self._device ) else: - latent, action = state + #state = default_collate(list(state.values())) + latent = default_collate(list(zip(*state))[0]) + action = default_collate(list(zip(*state))[1]) + if len(action.shape)==1: + action = action.unsqueeze(-1) + if any(reset): + mask = 1 - reset + for key in latent.keys(): + for i in range(latent[key].shape[0]): + latent[key][i] *= mask[i] + for i in range(len(action)): + action[i] *= mask[i] - data = data / 255.0 - 0.5 + #data = data / 255.0 - 0.5 embed = world_model.encoder(data) latent, _ = world_model.dynamics.obs_step( - latent, action, embed, self._config.collect_dyn_sample + latent, action, embed, self._cfg.collect.collect_dyn_sample ) feat = world_model.dynamics.get_feat(latent) @@ -224,10 +235,10 @@ def _forward_collect(self, data: dict, world_model, envstep, state=None) -> dict logprob = actor.log_prob(action) latent = {k: v.detach() for k, v in latent.items()} action = action.detach() - output = {"action": action, "logprob": logprob} # to do # should pass state to next collect state = (latent, action) + output = {"action": action, "logprob": logprob, "state": state} if self._cuda: output = to_device(output, 'cpu') @@ -274,8 +285,8 @@ def _forward_eval(self, data: dict, world_model, state=None) -> dict: if state is None: batch_size = len(data_id) latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter} - action = torch.zeros((batch_size, self._config.num_actions)).to( - self._config.device + action = torch.zeros((batch_size, self._cfg.collect.action_size)).to( + self._device ) else: latent, action = state @@ -283,7 +294,7 @@ def _forward_eval(self, data: dict, world_model, state=None) -> dict: data = data / 255.0 - 0.5 embed = world_model.encoder(data) latent, _ = world_model.dynamics.obs_step( - latent, action, embed, self._config.collect_dyn_sample + latent, action, embed, self._cfg.collect.collect_dyn_sample ) feat = world_model.dynamics.get_feat(latent) diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py index 5ab425ffb8..541992ff25 100644 --- a/ding/worker/collector/sample_serial_collector.py +++ b/ding/worker/collector/sample_serial_collector.py @@ -132,6 +132,9 @@ def reset(self, _policy: Optional[namedtuple] = None, 
_env: Optional[BaseEnvMana if _policy is not None: self.reset_policy(_policy) + if self._policy_cfg.type == 'dreamer_command': + self._states = None + self._resets = [False for i in range(self._env_num)] self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs) self._policy_output_pool = CachePool('policy_output', self._env_num) # _traj_buffer is {env_id: TrajBuffer}, is used to store traj_len pieces of transitions @@ -200,6 +203,7 @@ def collect( n_sample: Optional[int] = None, train_iter: int = 0, drop_extra: bool = True, + random_collect: bool = False, record_random_collect: bool = True, policy_kwargs: Optional[dict] = None, level_seeds: Optional[List] = None, @@ -241,8 +245,13 @@ def collect( self._obs_pool.update(obs) if self._transform_obs: obs = to_tensor(obs, dtype=torch.float32) - policy_output = self._policy.forward(obs, **policy_kwargs) - self._policy_output_pool.update(policy_output) + if self._policy_cfg.type == 'dreamer_command' and not random_collect: + policy_output = self._policy.forward(obs, **policy_kwargs, reset=self._resets, state=self._states) + #self._states = {env_id: output['state'] for env_id, output in policy_output.items()} + self._states = [output['state'] for output in policy_output.values()] + else: + policy_output = self._policy.forward(obs, **policy_kwargs) + self._policy_output_pool.update(policy_output) # Interact with env. actions = {env_id: output['action'] for env_id, output in policy_output.items()} actions = to_ndarray(actions) @@ -315,6 +324,9 @@ def collect( # Env reset is done by env_manager automatically self._policy.reset([env_id]) self._reset_stat(env_id) + if self._policy_cfg.type == 'dreamer_command' and not random_collect: + self._resets[env_id] = True + # log if record_random_collect: # default is true, but when random collect, record_random_collect is False self._output_log(train_iter) diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py index c1cffc8acd..bd0d2bcf3a 100644 --- a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -42,6 +42,8 @@ collect=dict( n_sample=1, unroll_len=64, + action_size=1, # has to be specified + collect_dyn_sample=True, ), command=dict(), eval=dict(evaluator=dict(eval_freq=10000, )), # w.r.t envstep diff --git a/dizoo/dmc2gym/envs/dmc2gym_env.py b/dizoo/dmc2gym/envs/dmc2gym_env.py index 4aad02eb43..245eed612d 100644 --- a/dizoo/dmc2gym/envs/dmc2gym_env.py +++ b/dizoo/dmc2gym/envs/dmc2gym_env.py @@ -3,6 +3,7 @@ from gym.spaces import Box import numpy as np from ding.envs import BaseEnv, BaseEnvTimestep +from ding.envs.common.common_function import affine_transform from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY import dmc2gym @@ -207,6 +208,7 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None: def step(self, action: np.ndarray) -> BaseEnvTimestep: action = action.astype('float32') + action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high) obs, rew, done, info = self._env.step(action) self._eval_episode_return += rew if done: From 9055cd228d70bf749083f617337462d973c46e56 Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Sun, 25 Jun 2023 12:11:51 +0800 Subject: [PATCH 10/14] fix reset bugs in collector and evaluator --- ding/entry/serial_entry_mbrl.py | 2 +- ding/policy/mbpolicy/dreamer.py | 28 +++++++++++++------ .../collector/interaction_serial_evaluator.py | 13 ++++++++- 
.../dmc2gym/config/dmc2gym_dreamer_config.py | 10 ++----- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py index 2f2cc45856..af6fef2611 100644 --- a/ding/entry/serial_entry_mbrl.py +++ b/ding/entry/serial_entry_mbrl.py @@ -293,7 +293,7 @@ def serial_pipeline_dreamer( if world_model.should_pretrain() else int(world_model.should_train(collector.envstep)) ) - for _ in range(1): + for _ in range(steps): batch_size = learner.policy.get_attribute('batch_size') post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size) diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index e24667d6ac..6177032103 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -1,5 +1,6 @@ from typing import List, Dict, Any, Tuple, Union from collections import namedtuple +import numpy as np import torch from torch import nn from copy import deepcopy @@ -216,7 +217,7 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N if len(action.shape)==1: action = action.unsqueeze(-1) if any(reset): - mask = 1 - reset + mask = 1 - np.array(reset) for key in latent.keys(): for i in range(latent[key].shape[0]): latent[key][i] *= mask[i] @@ -235,8 +236,7 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N logprob = actor.log_prob(action) latent = {k: v.detach() for k, v in latent.items()} action = action.detach() - # to do - # should pass state to next collect + state = (latent, action) output = {"action": action, "logprob": logprob, "state": state} @@ -275,7 +275,7 @@ def _init_eval(self) -> None: self._eval_model = model_wrap(self._model, wrapper_name='base') self._eval_model.reset() - def _forward_eval(self, data: dict, world_model, state=None) -> dict: + def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict: data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -289,9 +289,20 @@ def _forward_eval(self, data: dict, world_model, state=None) -> dict: self._device ) else: - latent, action = state + #state = default_collate(list(state.values())) + latent = default_collate(list(zip(*state))[0]) + action = default_collate(list(zip(*state))[1]) + if len(action.shape)==1: + action = action.unsqueeze(-1) + if any(reset): + mask = 1 - np.array(reset) + for key in latent.keys(): + for i in range(latent[key].shape[0]): + latent[key][i] *= mask[i] + for i in range(len(action)): + action[i] *= mask[i] - data = data / 255.0 - 0.5 + #data = data / 255.0 - 0.5 embed = world_model.encoder(data) latent, _ = world_model.dynamics.obs_step( latent, action, embed, self._cfg.collect.collect_dyn_sample @@ -303,10 +314,9 @@ def _forward_eval(self, data: dict, world_model, state=None) -> dict: logprob = actor.log_prob(action) latent = {k: v.detach() for k, v in latent.items()} action = action.detach() - output = {"action": action, "logprob": logprob} - # to do - # should pass state to next eval + state = (latent, action) + output = {"action": action, "logprob": logprob, "state": state} if self._cuda: output = to_device(output, 'cpu') diff --git a/ding/worker/collector/interaction_serial_evaluator.py b/ding/worker/collector/interaction_serial_evaluator.py index 24da981c1f..d655283221 100644 --- a/ding/worker/collector/interaction_serial_evaluator.py +++ b/ding/worker/collector/interaction_serial_evaluator.py @@ -110,6 +110,7 @@ def reset_policy(self, _policy: 
Optional[namedtuple] = None) -> None: assert hasattr(self, '_env'), "please set env first" if _policy is not None: self._policy = _policy + self._policy_cfg = self._policy.get_attribute('cfg') self._policy.reset() def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: @@ -130,6 +131,9 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self.reset_env(_env) if _policy is not None: self.reset_policy(_policy) + if self._policy_cfg.type == 'dreamer_command': + self._states = None + self._resets = [False for i in range(self._env_num)] self._max_episode_return = float("-inf") self._last_eval_iter = -1 self._end_flag = False @@ -233,7 +237,12 @@ def eval( if render: eval_monitor.update_video(self._env.ready_imgs) - policy_output = self._policy.forward(obs, **policy_kwargs) + if self._policy_cfg.type == 'dreamer_command': + policy_output = self._policy.forward(obs, **policy_kwargs, reset=self._resets, state=self._states) + #self._states = {env_id: output['state'] for env_id, output in policy_output.items()} + self._states = [output['state'] for output in policy_output.values()] + else: + policy_output = self._policy.forward(obs, **policy_kwargs) actions = {i: a['action'] for i, a in policy_output.items()} actions = to_ndarray(actions) timesteps = self._env.step(actions) @@ -259,6 +268,8 @@ def eval( env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() ) ) + if self._policy_cfg.type == 'dreamer_command': + self._resets[env_id] = True envstep_count += 1 duration = self._timer.value episode_return = eval_monitor.get_episode_return() diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py index bd0d2bcf3a..79f7f6752a 100644 --- a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -10,7 +10,7 @@ env_id='dmc2gym_cartpole_balance', domain_name='cartpole', task_name='balance', - frame_skip=4, + frame_skip=1, warp_frame=True, scale=True, clip_rewards=False, @@ -46,7 +46,7 @@ collect_dyn_sample=True, ), command=dict(), - eval=dict(evaluator=dict(eval_freq=10000, )), # w.r.t envstep + eval=dict(evaluator=dict(eval_freq=1, )), # 10000 other=dict( # environment buffer replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60), @@ -59,16 +59,10 @@ train_freq=2, # w.r.t envstep cuda=cuda, model=dict( - #elite_size=5, state_size=(3, 64, 64), # has to be specified action_size=1, # has to be specified reward_size=1, - #hidden_size=200, - #use_decay=True, batch_size=16, - #holdout_ratio=0.1, - #max_epochs_since_update=5, - #deterministic_rollout=True, ), ), ) From 0d606f7bdcff9b48ee83e2fe7696cdfb249b1fcd Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Tue, 27 Jun 2023 16:13:01 +0800 Subject: [PATCH 11/14] fix cuda/cpu bugs --- ding/model/template/vac.py | 1 + ding/policy/mbpolicy/dreamer.py | 17 ++++++++--------- ding/torch_utils/network/dreamer.py | 4 +++- .../collector/interaction_serial_evaluator.py | 6 +++--- .../worker/collector/sample_serial_collector.py | 6 +++--- ding/world_model/dreamer.py | 10 ++++++---- dizoo/dmc2gym/config/dmc2gym_dreamer_config.py | 4 ++-- 7 files changed, 26 insertions(+), 22 deletions(-) diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 5aa4b83a8d..98b86614cf 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -442,4 +442,5 @@ def __init__( 'LN', # norm 'twohot_symlog', outscale=0.0, + 
device='cuda' if torch.cuda.is_available() else 'cpu', ) diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index 6177032103..4d362dd5bb 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -1,6 +1,5 @@ from typing import List, Dict, Any, Tuple, Union from collections import namedtuple -import numpy as np import torch from torch import nn from copy import deepcopy @@ -212,12 +211,12 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N ) else: #state = default_collate(list(state.values())) - latent = default_collate(list(zip(*state))[0]) - action = default_collate(list(zip(*state))[1]) + latent = to_device(default_collate(list(zip(*state))[0]), self._device) + action = to_device(default_collate(list(zip(*state))[1]), self._device) if len(action.shape)==1: action = action.unsqueeze(-1) - if any(reset): - mask = 1 - np.array(reset) + if reset.any(): + mask = 1 - reset for key in latent.keys(): for i in range(latent[key].shape[0]): latent[key][i] *= mask[i] @@ -290,12 +289,12 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict ) else: #state = default_collate(list(state.values())) - latent = default_collate(list(zip(*state))[0]) - action = default_collate(list(zip(*state))[1]) + latent = to_device(default_collate(list(zip(*state))[0]), self._device) + action = to_device(default_collate(list(zip(*state))[1]), self._device) if len(action.shape)==1: action = action.unsqueeze(-1) - if any(reset): - mask = 1 - np.array(reset) + if reset.any(): + mask = 1 - reset for key in latent.keys(): for i in range(latent[key].shape[0]): latent[key][i] *= mask[i] diff --git a/ding/torch_utils/network/dreamer.py b/ding/torch_utils/network/dreamer.py index 0fa9e9e55f..876f35f811 100644 --- a/ding/torch_utils/network/dreamer.py +++ b/ding/torch_utils/network/dreamer.py @@ -58,6 +58,7 @@ def __init__( dist='normal', std=1.0, outscale=1.0, + device='cpu', ): super(DenseHead, self).__init__() self._shape = (shape, ) if isinstance(shape, int) else shape @@ -69,6 +70,7 @@ def __init__( self._norm = norm self._dist = dist self._std = std + self._device = device self.mlp = MLP( inp_dim, @@ -103,7 +105,7 @@ def forward(self, features, dtype=None): if self._dist == "binary": return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape))) if self._dist == "twohot_symlog": - return TwoHotDistSymlog(logits=mean) + return TwoHotDistSymlog(logits=mean, device=self._device) raise NotImplementedError(self._dist) diff --git a/ding/worker/collector/interaction_serial_evaluator.py b/ding/worker/collector/interaction_serial_evaluator.py index d655283221..9f7cb115d7 100644 --- a/ding/worker/collector/interaction_serial_evaluator.py +++ b/ding/worker/collector/interaction_serial_evaluator.py @@ -133,7 +133,7 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self.reset_policy(_policy) if self._policy_cfg.type == 'dreamer_command': self._states = None - self._resets = [False for i in range(self._env_num)] + self._resets = np.array([False for i in range(self._env_num)]) self._max_episode_return = float("-inf") self._last_eval_iter = -1 self._end_flag = False @@ -252,6 +252,8 @@ def eval( # If there is an abnormal timestep, reset all the related variables(including this env). 
self._policy.reset([env_id]) continue + if self._policy_cfg.type == 'dreamer_command': + self._resets[env_id] = t.done if t.done: # Env reset is done by env_manager automatically. if 'figure_path' in self._cfg: @@ -268,8 +270,6 @@ def eval( env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() ) ) - if self._policy_cfg.type == 'dreamer_command': - self._resets[env_id] = True envstep_count += 1 duration = self._timer.value episode_return = eval_monitor.get_episode_return() diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py index 541992ff25..fc7b09298e 100644 --- a/ding/worker/collector/sample_serial_collector.py +++ b/ding/worker/collector/sample_serial_collector.py @@ -134,7 +134,7 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana if self._policy_cfg.type == 'dreamer_command': self._states = None - self._resets = [False for i in range(self._env_num)] + self._resets = np.array([False for i in range(self._env_num)]) self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs) self._policy_output_pool = CachePool('policy_output', self._env_num) # _traj_buffer is {env_id: TrajBuffer}, is used to store traj_len pieces of transitions @@ -271,6 +271,8 @@ def collect( self._reset_stat(env_id) self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, timestep.info)) continue + if self._policy_cfg.type == 'dreamer_command' and not random_collect: + self._resets[env_id] = timestep.done if self._policy_cfg.type == 'ngu_command': # for NGU policy transition = self._policy.process_transition( self._obs_pool[env_id], self._policy_output_pool[env_id], timestep, env_id @@ -324,8 +326,6 @@ def collect( # Env reset is done by env_manager automatically self._policy.reset([env_id]) self._reset_stat(env_id) - if self._policy_cfg.type == 'dreamer_command' and not random_collect: - self._resets[env_id] = True # log if record_random_collect: # default is true, but when random collect, record_random_collect is False diff --git a/ding/world_model/dreamer.py b/ding/world_model/dreamer.py index 6743c3281a..7287c6ccd7 100644 --- a/ding/world_model/dreamer.py +++ b/ding/world_model/dreamer.py @@ -56,7 +56,7 @@ class DREAMERWorldModel(WorldModel, nn.Module): dyn_min_std=0.1, dyn_cell='gru_layer_norm', unimix_ratio=0.01, - device='cpu', + device='cuda' if torch.cuda.is_available() else 'cpu', ), ) @@ -76,9 +76,6 @@ def __init__(self, cfg, env, tb_logger): self.hidden_size = self._cfg.hidden_size self.batch_size = self._cfg.batch_size - if self._cuda: - self.cuda() - self.encoder = ConvEncoder( self.state_size, hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128? 
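Note on the reset handling added above: the collector and evaluator only record a per-env `_resets` flag (taken from `timestep.done`) and hand it, together with the stored `(latent, action)` pair, back to the policy, which zeroes the finished envs' recurrent state element-wise before the next `obs_step`. A minimal vectorized sketch of that masking idea; `mask_recurrent_state` is a hypothetical helper (not a function in this codebase), and the latent is assumed to be the usual RSSM dict of batch-first tensors:

import torch

def mask_recurrent_state(latent: dict, action: torch.Tensor, reset) -> tuple:
    # reset: per-env done flags from the previous timestep
    mask = 1.0 - torch.as_tensor(reset, dtype=action.dtype, device=action.device)
    # zero the latent entries and the last action of every env that just finished,
    # broadcasting the (B,) mask over the trailing dimensions
    latent = {k: v * mask.view(-1, *([1] * (v.dim() - 1))) for k, v in latent.items()}
    action = action * mask.view(-1, *([1] * (action.dim() - 1)))
    return latent, action

# toy usage: two envs, env 0 just terminated, so only its state is wiped
latent = {'deter': torch.ones(2, 4), 'stoch': torch.ones(2, 3)}
action = torch.ones(2, 1)
latent, action = mask_recurrent_state(latent, action, [True, False])

This is the vectorized equivalent of the per-index loops in `_forward_collect` / `_forward_eval` above.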
@@ -133,6 +130,7 @@ def __init__(self, cfg, env, tb_logger): 'LN', # self._cfg.norm dist=self._cfg.reward_head, outscale=0.0, + device=self._cfg.device, ) if self._cfg.pred_discount: self.heads["discount"] = DenseHead( @@ -143,7 +141,11 @@ def __init__(self, cfg, env, tb_logger): 'SiLU', # self._cfg.act 'LN', # self._cfg.norm dist="binary", + device=self._cfg.device, ) + + if self._cuda: + self.cuda() # to do # grad_clip, weight_decay self.optimizer = torch.optim.Adam(self.parameters(), lr=self._cfg.model_lr) diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py index 79f7f6752a..c74089ead2 100644 --- a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -20,7 +20,7 @@ collector_env_num=1, evaluator_env_num=8, n_evaluator_episode=8, - stop_value=50, # 800 + stop_value=1000, # 1000 ), policy=dict( cuda=cuda, @@ -46,7 +46,7 @@ collect_dyn_sample=True, ), command=dict(), - eval=dict(evaluator=dict(eval_freq=1, )), # 10000 + eval=dict(evaluator=dict(eval_freq=10000, )), # 10000 other=dict( # environment buffer replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60), From ab89952885d767198828a5baf86ee04f47785ab4 Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Thu, 29 Jun 2023 17:59:53 +0800 Subject: [PATCH 12/14] add action repeat wrapper and fix discount lambda_return --- ding/envs/env_wrappers/env_wrappers.py | 32 +++++++++++++++++++ ding/policy/mbpolicy/dreamer.py | 2 ++ ding/policy/mbpolicy/utils.py | 6 ++-- .../dmc2gym/config/dmc2gym_dreamer_config.py | 5 +-- dizoo/dmc2gym/envs/dmc2gym_env.py | 5 ++- 5 files changed, 45 insertions(+), 5 deletions(-) diff --git a/ding/envs/env_wrappers/env_wrappers.py b/ding/envs/env_wrappers/env_wrappers.py index f158ff208d..4f0fc0eb89 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ b/ding/envs/env_wrappers/env_wrappers.py @@ -265,6 +265,38 @@ def reward(self, reward): """ return np.sign(reward) +@ENV_WRAPPER_REGISTRY.register('action_repeat') +class ActionRepeatWrapper(gym.Wrapper): + """ + Overview: + Repeat the action to step with env. + Interface: + ``__init__``, ``step`` + Properties: + - env (:obj:`gym.Env`): the environment to wrap. + - ``action_repeat`` + + """ + + def __init__(self, env, action_repeat=1): + """ + Overview: + Initialize ``self.`` See ``help(type(self))`` for accurate signature; setup the properties. + Arguments: + - env (:obj:`gym.Env`): the environment to wrap. 
+ """ + super().__init__(env) + self.action_repeat = action_repeat + + def step(self, action): + reward = 0 + for _ in range(self.action_repeat): + obs, rew, done, info = self.env.step(action) + reward += rew or 0 + if done: + break + return obs, reward, done, info + @ENV_WRAPPER_REGISTRY.register('delay_reward') class DelayRewardWrapper(gym.Wrapper): diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index 4d362dd5bb..439eb087ab 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -145,6 +145,7 @@ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: # actor-critic update # ==================== self._model.requires_grad_(requires_grad=True) + world_model.requires_grad_(requires_grad=True) loss_dict = { 'critic_loss': value_loss, @@ -154,6 +155,7 @@ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: norm_dict = self._update(loss_dict) self._model.requires_grad_(requires_grad=False) + world_model.requires_grad_(requires_grad=False) # ============= # after update # ============= diff --git a/ding/policy/mbpolicy/utils.py b/ding/policy/mbpolicy/utils.py index 64385b27d7..38c289811a 100644 --- a/ding/policy/mbpolicy/utils.py +++ b/ding/policy/mbpolicy/utils.py @@ -60,9 +60,11 @@ def step(prev, _): def compute_target(cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent): - if "cont" in world_model.heads: + if "discount" in world_model.heads: inp = world_model.dynamics.get_feat(imag_state) - discount = cfg.discount * world_model.heads["cont"](inp).mean + discount = cfg.discount * world_model.heads["discount"](inp).mean + # TODO whether to detach + discount = discount.detach() else: discount = cfg.discount * torch.ones_like(reward) diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py index c74089ead2..219d6d5720 100644 --- a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -14,6 +14,7 @@ warp_frame=True, scale=True, clip_rewards=False, + action_repeat=2, frame_stack=1, from_pixels=True, resize=64, @@ -25,7 +26,7 @@ policy=dict( cuda=cuda, # it is better to put random_collect_size in policy.other - random_collect_size=5000, + random_collect_size=2500, # 5000 model=dict( obs_shape=(3, 64, 64), action_shape=1, @@ -46,7 +47,7 @@ collect_dyn_sample=True, ), command=dict(), - eval=dict(evaluator=dict(eval_freq=10000, )), # 10000 + eval=dict(evaluator=dict(eval_freq=5000, )), # 10000 other=dict( # environment buffer replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60), diff --git a/dizoo/dmc2gym/envs/dmc2gym_env.py b/dizoo/dmc2gym/envs/dmc2gym_env.py index 245eed612d..57f3aff187 100644 --- a/dizoo/dmc2gym/envs/dmc2gym_env.py +++ b/dizoo/dmc2gym/envs/dmc2gym_env.py @@ -7,7 +7,7 @@ from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY import dmc2gym -from ding.envs import WarpFrameWrapper, ScaledFloatFrameWrapper, ClipRewardWrapper, FrameStackWrapper +from ding.envs import WarpFrameWrapper, ScaledFloatFrameWrapper, ClipRewardWrapper, ActionRepeatWrapper, FrameStackWrapper def dmc2gym_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable: @@ -115,6 +115,7 @@ def __init__(self, cfg: dict = {}) -> None: 'warp_frame': False, 'scale': False, 'clip_rewards': False, + 'action_repeat': 1, "frame_stack": 3, "from_pixels": True, "visualize_reward": False, @@ -161,6 +162,8 @@ def reset(self) -> 
np.ndarray: self._env = ScaledFloatFrameWrapper(self._env) if self._cfg['clip_rewards']: self._env = ClipRewardWrapper(self._env) + if self._cfg['action_repeat']: + self._env = ActionRepeatWrapper(self._env, self._cfg['action_repeat']) if self._cfg['frame_stack'] > 1: self._env = FrameStackWrapper(self._env, self._cfg['frame_stack']) From d044da68605c9c75dbc3d8b55930213e0c7c8cfb Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Wed, 26 Jul 2023 17:14:40 +0800 Subject: [PATCH 13/14] change replay buffer and add slow target update --- ding/entry/serial_entry_mbrl.py | 3 +- ding/policy/mbpolicy/dreamer.py | 22 ++++- ding/worker/replay_buffer/__init__.py | 2 +- ding/worker/replay_buffer/naive_buffer.py | 96 +++++++++++++++++++ ding/world_model/dreamer.py | 24 +++-- .../cartpole_balance_dreamer_config.py | 91 ++++++++++++++++++ .../cheetah_run/cheetah_run_dreamer_config.py | 91 ++++++++++++++++++ .../dmc2gym/config/dmc2gym_dreamer_config.py | 21 ++-- .../walker_walk/walker_walk_dreamer_config.py | 91 ++++++++++++++++++ 9 files changed, 410 insertions(+), 31 deletions(-) create mode 100644 dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py create mode 100644 dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py create mode 100644 dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py index af6fef2611..7ca409f8cf 100644 --- a/ding/entry/serial_entry_mbrl.py +++ b/ding/entry/serial_entry_mbrl.py @@ -295,7 +295,8 @@ def serial_pipeline_dreamer( ) for _ in range(steps): batch_size = learner.policy.get_attribute('batch_size') - post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size) + batch_length = cfg.policy.learn.batch_length + post, context = world_model.train(env_buffer, collector.envstep, learner.train_iter, batch_size, batch_length) start = post diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index 439eb087ab..91f826caa7 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -32,10 +32,13 @@ class DREAMERPolicy(Policy): lambda_=0.95, # (float) Max norm of gradients. 
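For context on the wrapper applied above: `ActionRepeatWrapper` replays one action for `action_repeat` underlying env steps, summing the rewards and stopping early on `done`, so with `frame_skip=1` in the dreamer config the effective control interval is now set by this wrapper. A small usage sketch, assuming the classic 4-tuple gym `step` API this repo targets; the Pendulum env id is only an illustration:

import gym
from ding.envs import ActionRepeatWrapper  # same import path the dmc2gym env uses above

env = ActionRepeatWrapper(gym.make('Pendulum-v1'), action_repeat=2)
obs = env.reset()
# one call here advances the wrapped env twice (or fewer steps if it terminates early)
obs, reward, done, info = env.step(env.action_space.sample())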
grad_clip=100, - learning_rate=0.001, - batch_size=256, + learning_rate=3e-5, + batch_size=16, + batch_length=64, imag_sample=True, slow_value_target=True, + slow_target_update=1, + slow_target_fraction=0.02, discount=0.997, reward_EMA=True, actor_entropy=3e-4, @@ -86,6 +89,7 @@ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: # log dict log_vars = {} self._learn_model.train() + self._update_slow_target() self._actor.requires_grad_(requires_grad=True) # start is dict of {stoch, deter, logit} @@ -180,6 +184,14 @@ def _update(self, loss_dict): self._optimizer_value.step() return {'actor_grad_norm': actor_norm, 'critic_grad_norm': critic_norm} + def _update_slow_target(self): + if self._cfg.learn.slow_value_target: + if self._updates % self._cfg.learn.slow_target_update == 0: + mix = self._cfg.learn.slow_target_fraction + for s, d in zip(self._critic.parameters(), self._slow_value.parameters()): + d.data = mix * s.data + (1 - mix) * d.data + self._updates += 1 + def _state_dict_learn(self) -> Dict[str, Any]: ret = { 'model': self._learn_model.state_dict(), @@ -225,7 +237,7 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N for i in range(len(action)): action[i] *= mask[i] - #data = data / 255.0 - 0.5 + data = data - 0.5 embed = world_model.encoder(data) latent, _ = world_model.dynamics.obs_step( latent, action, embed, self._cfg.collect.collect_dyn_sample @@ -260,11 +272,11 @@ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple """ transition = { 'obs': obs, - #'next_obs': timestep.obs, 'action': model_output['action'], # TODO(zp) random_collect just have action #'logprob': model_output['logprob'], 'reward': timestep.reward, + 'discount': timestep.info['discount'], 'done': timestep.done, } return transition @@ -303,7 +315,7 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict for i in range(len(action)): action[i] *= mask[i] - #data = data / 255.0 - 0.5 + data = data - 0.5 embed = world_model.encoder(data) latent, _ = world_model.dynamics.obs_step( latent, action, embed, self._cfg.collect.collect_dyn_sample diff --git a/ding/worker/replay_buffer/__init__.py b/ding/worker/replay_buffer/__init__.py index 8b7009450f..c4f1bf3e87 100644 --- a/ding/worker/replay_buffer/__init__.py +++ b/ding/worker/replay_buffer/__init__.py @@ -1,4 +1,4 @@ from .base_buffer import IBuffer, create_buffer, get_buffer_cls -from .naive_buffer import NaiveReplayBuffer +from .naive_buffer import NaiveReplayBuffer, SequenceReplayBuffer from .advanced_buffer import AdvancedReplayBuffer from .episode_buffer import EpisodeReplayBuffer diff --git a/ding/worker/replay_buffer/naive_buffer.py b/ding/worker/replay_buffer/naive_buffer.py index 78c5bf0f81..a40a456e8b 100644 --- a/ding/worker/replay_buffer/naive_buffer.py +++ b/ding/worker/replay_buffer/naive_buffer.py @@ -2,6 +2,7 @@ import copy from typing import Union, Any, Optional, List import numpy as np +import math import hickle from easydict import EasyDict @@ -465,3 +466,98 @@ def _get_indices(self, size: int, sample_range: slice = None, replace: bool = Fa def update(self, envstep): self._current_buffer_size = self._set_buffer_size(envstep) + + +@BUFFER_REGISTRY.register('sequence') +class SequenceReplayBuffer(NaiveReplayBuffer): + r""" + Overview: + Interface: + start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config + Property: + replay_buffer_size, push_count + """ + + def sample(self, + batch: int, + sequence: 
int, + cur_learner_iter: int, + sample_range: slice = None, + replace: bool = False) -> Optional[list]: + """ + Overview: + Sample data with length ``size``. + Arguments: + - size (:obj:`int`): The number of the data that will be sampled. + - sequence (:obj:`int`): The length of the sequence of a data that will be sampled. + - cur_learner_iter (:obj:`int`): Learner's current iteration. \ + Not used in naive buffer, but preserved for compatibility. + - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \ + means only sample among the last 10 data + - replace (:obj:`bool`): Whether sample with replacement + Returns: + - sample_data (:obj:`list`): A list of data with length ``size``. + """ + if batch == 0: + return [] + can_sample = self._sample_check(batch * sequence, replace) + if not can_sample: + return None + with self._lock: + indices = self._get_indices(batch, sequence, sample_range, replace) + sample_data = self._sample_with_indices(indices, sequence, cur_learner_iter) + self._periodic_thruput_monitor.sample_data_count += len(sample_data) + return sample_data + + def _get_indices(self, size: int, sequence: int, sample_range: slice = None, replace: bool = False) -> list: + r""" + Overview: + Get the sample index list. + Arguments: + - size (:obj:`int`): The number of the data that will be sampled + - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \ + means only sample among the last 10 data + Returns: + - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``. + """ + assert self._valid_count <= self._replay_buffer_size + if self._valid_count == self._replay_buffer_size: + tail = self._replay_buffer_size + else: + tail = self._tail + episodes = math.ceil(self._valid_count / 500) + batch = 0 + indices = [] + if sample_range is None: + while batch < size: + episode = np.random.choice(episodes) + length = tail - episode*500 if tail - episode*500 < 500 else 500 + available = length - sequence + if available < 1: + continue + list(range(episode*500, episode*500 + available)) + indices.append(np.random.randint(episode*500, episode*500 + available + 1)) + batch += 1 + else: + raise NotImplemented("sample_range is not implemented in this version") + return indices + + def _sample_with_indices(self, indices: List[int], sequence: int, cur_learner_iter: int) -> list: + r""" + Overview: + Sample data with ``indices``. + Arguments: + - indices (:obj:`List[int]`): A list including all the sample indices. + - cur_learner_iter (:obj:`int`): Not used in this method, but preserved for compatibility. + Returns: + - data (:obj:`list`) Sampled data. 
+ """ + data = [] + for idx in indices: + assert self._data[idx] is not None, idx + if self._deepcopy: + copy_data = copy.deepcopy(self._data[idx:idx+sequence]) + else: + copy_data = self._data[idx:idx+sequence] + data.append(copy_data) + return data \ No newline at end of file diff --git a/ding/world_model/dreamer.py b/ding/world_model/dreamer.py index 7287c6ccd7..85b129a5e4 100644 --- a/ding/world_model/dreamer.py +++ b/ding/world_model/dreamer.py @@ -3,7 +3,7 @@ import torch from torch import nn -from ding.utils import WORLD_MODEL_REGISTRY +from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts from ding.utils.data import default_collate from ding.model import ConvEncoder from ding.world_model.base_world_model import WorldModel @@ -162,19 +162,17 @@ def should_pretrain(self): return True return False - def train(self, env_buffer, envstep, train_iter, batch_size): + def train(self, env_buffer, envstep, train_iter, batch_size, batch_length): self.last_train_step = envstep - # [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] - data = env_buffer.sample(batch_size, train_iter) - data = default_collate(data) # -> {some_key: T lists}, each list is [B, some_dim] + data = env_buffer.sample(batch_size, batch_length, train_iter) # [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]] + data = default_collate(data) # -> [len=T, ele={dict_key: Tensor(B, any_dims)}] + data = lists_to_dicts(data, recursive=True) # -> {some_key: T lists}, each list is [B, some_dim] data = {k: torch.stack(data[k], dim=1) for k in data} # -> {dict_key: Tensor([B, T, any_dims])} - - data['discount'] = 1.0 - data['done'].float() + + data['discount'] = data.get('discount', 1.0 - data['done'].float()) + data['discount'] *= 0.997 data['weight'] = data.get('weight', None) - data['image'] = data['obs'] - #data['obs'] = data['obs'] / 255.0 - 0.5 - #next_obs = data['next_obs'] / 255.0 - 0.5 - #data = {k: v.to(self._cfg.device) for k, v in data.items()} + data['image'] = data['obs'] - 0.5 data = to_device(data, self._cfg.device) if len(data['reward'].shape) == 2: data['reward'] = data['reward'].unsqueeze(-1) @@ -185,9 +183,9 @@ def train(self, env_buffer, envstep, train_iter, batch_size): self.requires_grad_(requires_grad=True) - image = data['obs'].reshape([-1] + list(data['obs'].shape[-3:])) + image = data['image'].reshape([-1] + list(data['image'].shape[-3:])) embed = self.encoder(image) - embed = embed.reshape(list(data['obs'].shape[:-3]) + [embed.shape[-1]]) + embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]]) post, prior = self.dynamics.observe(embed, data["action"]) kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss( diff --git a/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py b/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py new file mode 100644 index 0000000000..1304c62352 --- /dev/null +++ b/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py @@ -0,0 +1,91 @@ +from easydict import EasyDict + +from ding.entry import serial_pipeline_dreamer + +cuda = False + +cartpole_balance_dreamer_config = dict( + exp_name='dmc2gym_cartpole_balance_dreamer', + env=dict( + env_id='dmc2gym_cartpole_balance', + domain_name='cartpole', + task_name='balance', + frame_skip=1, + warp_frame=True, + scale=True, + clip_rewards=False, + action_repeat=2, + frame_stack=1, + from_pixels=True, + resize=64, + collector_env_num=1, + evaluator_env_num=1, + n_evaluator_episode=1, + stop_value=1000, # 1000 + ), + policy=dict( + cuda=cuda, + 
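A note on the sampling path above: `SequenceReplayBuffer.sample(batch, sequence, ...)` returns B contiguous windows of T transition dicts (its `_get_indices` treats the buffer as fixed 500-step chunks, so a window never crosses such a boundary), and `DREAMERWorldModel.train` collates them into `{key: Tensor(B, T, ...)}`. A toy sketch of that reshaping which produces the same final layout as the `default_collate` / `lists_to_dicts` / `torch.stack` pipeline in `train` (shapes here are illustrative only):

import torch

B, T = 2, 3
# batch of B sequences, each a list of T transition dicts, as returned by the buffer
batch = [[{'obs': torch.rand(3, 64, 64), 'reward': torch.tensor(1.0)} for _ in range(T)]
         for _ in range(B)]

data = {
    k: torch.stack([torch.stack([step[k] for step in seq]) for seq in batch])
    for k in batch[0][0]
}
assert data['obs'].shape == (B, T, 3, 64, 64)
assert data['reward'].shape == (B, T)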
# it is better to put random_collect_size in policy.other + random_collect_size=2500, + model=dict( + obs_shape=(3, 64, 64), + action_shape=1, + actor_dist = 'normal', + ), + learn=dict( + lambda_=0.95, + learning_rate=3e-5, + batch_size=16, + batch_length=64, + imag_sample=True, + discount=0.997, + reward_EMA=True, + ), + collect=dict( + n_sample=1, + unroll_len=1, + action_size=1, # has to be specified + collect_dyn_sample=True, + ), + command=dict(), + eval=dict(evaluator=dict(eval_freq=5000, )), + other=dict( + # environment buffer + replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60), + ), + ), + world_model=dict( + pretrain=100, + train_freq=2, + cuda=cuda, + model=dict( + state_size=(3, 64, 64), # has to be specified + action_size=1, # has to be specified + reward_size=1, + batch_size=16, + ), + ), +) + +cartpole_balance_dreamer_config = EasyDict(cartpole_balance_dreamer_config) + +cartpole_balance_create_config = dict( + env=dict( + type='dmc2gym', + import_names=['dizoo.dmc2gym.envs.dmc2gym_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='dreamer', + import_names=['ding.policy.mbpolicy.dreamer'], + ), + replay_buffer=dict(type='sequence', ), + world_model=dict( + type='dreamer', + import_names=['ding.world_model.dreamer'], + ), +) +cartpole_balance_create_config = EasyDict(cartpole_balance_create_config) + +if __name__ == '__main__': + serial_pipeline_dreamer((cartpole_balance_dreamer_config, cartpole_balance_create_config), seed=0, max_env_step=500000) \ No newline at end of file diff --git a/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py b/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py new file mode 100644 index 0000000000..39c28bed93 --- /dev/null +++ b/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py @@ -0,0 +1,91 @@ +from easydict import EasyDict + +from ding.entry import serial_pipeline_dreamer + +cuda = False + +cheetah_run_dreamer_config = dict( + exp_name='dmc2gym_cheetah_run_dreamer', + env=dict( + env_id='dmc2gym_cheetah_run', + domain_name='cheetah', + task_name='run', + frame_skip=1, + warp_frame=True, + scale=True, + clip_rewards=False, + action_repeat=2, + frame_stack=1, + from_pixels=True, + resize=64, + collector_env_num=1, + evaluator_env_num=1, + n_evaluator_episode=1, + stop_value=1000, # 1000 + ), + policy=dict( + cuda=cuda, + # it is better to put random_collect_size in policy.other + random_collect_size=2500, + model=dict( + obs_shape=(3, 64, 64), + action_shape=6, + actor_dist = 'normal', + ), + learn=dict( + lambda_=0.95, + learning_rate=3e-5, + batch_size=16, + batch_length=64, + imag_sample=True, + discount=0.997, + reward_EMA=True, + ), + collect=dict( + n_sample=1, + unroll_len=1, + action_size=6, # has to be specified + collect_dyn_sample=True, + ), + command=dict(), + eval=dict(evaluator=dict(eval_freq=5000, )), + other=dict( + # environment buffer + replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60), + ), + ), + world_model=dict( + pretrain=100, + train_freq=2, + cuda=cuda, + model=dict( + state_size=(3, 64, 64), # has to be specified + action_size=6, # has to be specified + reward_size=1, + batch_size=16, + ), + ), +) + +cheetah_run_dreamer_config = EasyDict(cheetah_run_dreamer_config) + +cheetah_run_create_config = dict( + env=dict( + type='dmc2gym', + import_names=['dizoo.dmc2gym.envs.dmc2gym_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='dreamer', + import_names=['ding.policy.mbpolicy.dreamer'], + ), + 
replay_buffer=dict(type='sequence', ), + world_model=dict( + type='dreamer', + import_names=['ding.world_model.dreamer'], + ), +) +cheetah_run_create_config = EasyDict(cheetah_run_create_config) + +if __name__ == '__main__': + serial_pipeline_dreamer((cheetah_run_dreamer_config, cheetah_run_create_config), seed=0, max_env_step=500000) \ No newline at end of file diff --git a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py index 219d6d5720..b548ebfdd2 100644 --- a/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py @@ -19,14 +19,14 @@ from_pixels=True, resize=64, collector_env_num=1, - evaluator_env_num=8, - n_evaluator_episode=8, + evaluator_env_num=1, + n_evaluator_episode=1, stop_value=1000, # 1000 ), policy=dict( cuda=cuda, # it is better to put random_collect_size in policy.other - random_collect_size=2500, # 5000 + random_collect_size=2500, model=dict( obs_shape=(3, 64, 64), action_shape=1, @@ -34,30 +34,29 @@ ), learn=dict( lambda_=0.95, - learning_rate=0.001, + learning_rate=3e-5, batch_size=16, + batch_length=64, imag_sample=True, discount=0.997, reward_EMA=True, ), collect=dict( n_sample=1, - unroll_len=64, + unroll_len=1, action_size=1, # has to be specified collect_dyn_sample=True, ), command=dict(), - eval=dict(evaluator=dict(eval_freq=5000, )), # 10000 + eval=dict(evaluator=dict(eval_freq=5000, )), other=dict( # environment buffer - replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60), + replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60), ), ), world_model=dict( - #eval_every=10, pretrain=100, - #eval_freq=250, # w.r.t envstep - train_freq=2, # w.r.t envstep + train_freq=2, cuda=cuda, model=dict( state_size=(3, 64, 64), # has to be specified @@ -80,7 +79,7 @@ type='dreamer', import_names=['ding.policy.mbpolicy.dreamer'], ), - replay_buffer=dict(type='naive', ), + replay_buffer=dict(type='sequence', ), world_model=dict( type='dreamer', import_names=['ding.world_model.dreamer'], diff --git a/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py b/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py new file mode 100644 index 0000000000..ee8f350b51 --- /dev/null +++ b/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py @@ -0,0 +1,91 @@ +from easydict import EasyDict + +from ding.entry import serial_pipeline_dreamer + +cuda = False + +walker_walk_dreamer_config = dict( + exp_name='dmc2gym_walker_walk_dreamer', + env=dict( + env_id='dmc2gym_walker_walk', + domain_name='walker', + task_name='walk', + frame_skip=1, + warp_frame=True, + scale=True, + clip_rewards=False, + action_repeat=2, + frame_stack=1, + from_pixels=True, + resize=64, + collector_env_num=1, + evaluator_env_num=1, + n_evaluator_episode=1, + stop_value=1000, # 1000 + ), + policy=dict( + cuda=cuda, + # it is better to put random_collect_size in policy.other + random_collect_size=2500, + model=dict( + obs_shape=(3, 64, 64), + action_shape=6, + actor_dist = 'normal', + ), + learn=dict( + lambda_=0.95, + learning_rate=3e-5, + batch_size=16, + batch_length=64, + imag_sample=True, + discount=0.997, + reward_EMA=True, + ), + collect=dict( + n_sample=1, + unroll_len=1, + action_size=6, # has to be specified + collect_dyn_sample=True, + ), + command=dict(), + eval=dict(evaluator=dict(eval_freq=5000, )), + other=dict( + # environment buffer + replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60), + ), + ), + world_model=dict( + 
pretrain=100, + train_freq=2, + cuda=cuda, + model=dict( + state_size=(3, 64, 64), # has to be specified + action_size=6, # has to be specified + reward_size=1, + batch_size=16, + ), + ), +) + +walker_walk_dreamer_config = EasyDict(walker_walk_dreamer_config) + +walker_walk_create_config = dict( + env=dict( + type='dmc2gym', + import_names=['dizoo.dmc2gym.envs.dmc2gym_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='dreamer', + import_names=['ding.policy.mbpolicy.dreamer'], + ), + replay_buffer=dict(type='sequence', ), + world_model=dict( + type='dreamer', + import_names=['ding.world_model.dreamer'], + ), +) +walker_walk_create_config = EasyDict(walker_walk_create_config) + +if __name__ == '__main__': + serial_pipeline_dreamer((walker_walk_dreamer_config, walker_walk_create_config), seed=0, max_env_step=500000) \ No newline at end of file From cb919a4d2dc4245f1d33128c76e3bd32713461c3 Mon Sep 17 00:00:00 2001 From: zhangpaipai <1124321458@qq.com> Date: Fri, 28 Jul 2023 16:25:23 +0800 Subject: [PATCH 14/14] polish --- ding/entry/utils.py | 3 +- ding/envs/env_wrappers/env_wrappers.py | 2 +- ding/model/template/vac.py | 2 +- ding/policy/mbpolicy/dreamer.py | 44 +++--- ding/policy/mbpolicy/utils.py | 1 - ding/torch_utils/network/dreamer.py | 138 +++++++++++++----- .../torch_utils/network/tests/test_dreamer.py | 6 +- ding/worker/replay_buffer/naive_buffer.py | 26 ++-- ding/world_model/model/networks.py | 3 +- 9 files changed, 140 insertions(+), 85 deletions(-) diff --git a/ding/entry/utils.py b/ding/entry/utils.py index 38c85889df..8ca1d52c72 100644 --- a/ding/entry/utils.py +++ b/ding/entry/utils.py @@ -60,7 +60,8 @@ def random_collect( new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs) else: new_data = collector.collect( - n_sample=policy_cfg.random_collect_size, random_collect=True, record_random_collect=False, policy_kwargs=collect_kwargs + n_sample=policy_cfg.random_collect_size, random_collect=True, + record_random_collect=False, policy_kwargs=collect_kwargs ) # 'record_random_collect=False' means random collect without output log if postprocess_data_fn is not None: new_data = postprocess_data_fn(new_data) diff --git a/ding/envs/env_wrappers/env_wrappers.py b/ding/envs/env_wrappers/env_wrappers.py index 4f0fc0eb89..17961bc2f7 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ b/ding/envs/env_wrappers/env_wrappers.py @@ -184,7 +184,7 @@ def observation(self, frame): sys.exit(1) # to do # channel_first - if frame.shape[0]<10: + if frame.shape[0] < 10: frame = frame.transpose(1, 2, 0) frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) frame = frame.transpose(2, 0, 1) diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 98b86614cf..24fe845b94 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -432,7 +432,7 @@ def __init__( actor_temp, outscale=1.0, unimix_ratio=action_unimix_ratio, - ) # action_dist -> action_disc? 
+ ) self.critic = DenseHead( feat_size, # pytorch version (255, ), diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py index 91f826caa7..43d3b88619 100644 --- a/ding/policy/mbpolicy/dreamer.py +++ b/ding/policy/mbpolicy/dreamer.py @@ -90,7 +90,7 @@ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]: log_vars = {} self._learn_model.train() self._update_slow_target() - + self._actor.requires_grad_(requires_grad=True) # start is dict of {stoch, deter, logit} if self._cuda: @@ -191,7 +191,7 @@ def _update_slow_target(self): for s, d in zip(self._critic.parameters(), self._slow_value.parameters()): d.data = mix * s.data + (1 - mix) * d.data self._updates += 1 - + def _state_dict_learn(self) -> Dict[str, Any]: ret = { 'model': self._learn_model.state_dict(), @@ -216,18 +216,16 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N if self._cuda: data = to_device(data, self._device) self._collect_model.eval() - + if state is None: batch_size = len(data_id) latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter} - action = torch.zeros((batch_size, self._cfg.collect.action_size)).to( - self._device - ) + action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device) else: #state = default_collate(list(state.values())) latent = to_device(default_collate(list(zip(*state))[0]), self._device) action = to_device(default_collate(list(zip(*state))[1]), self._device) - if len(action.shape)==1: + if len(action.shape) == 1: action = action.unsqueeze(-1) if reset.any(): mask = 1 - reset @@ -236,28 +234,26 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N latent[key][i] *= mask[i] for i in range(len(action)): action[i] *= mask[i] - + data = data - 0.5 embed = world_model.encoder(data) - latent, _ = world_model.dynamics.obs_step( - latent, action, embed, self._cfg.collect.collect_dyn_sample - ) + latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample) feat = world_model.dynamics.get_feat(latent) - + actor = self._actor(feat) action = actor.sample() logprob = actor.log_prob(action) latent = {k: v.detach() for k, v in latent.items()} action = action.detach() - + state = (latent, action) output = {"action": action, "logprob": logprob, "state": state} - + if self._cuda: output = to_device(output, 'cpu') output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: r""" Overview: @@ -294,18 +290,16 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict if self._cuda: data = to_device(data, self._device) self._eval_model.eval() - + if state is None: batch_size = len(data_id) latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter} - action = torch.zeros((batch_size, self._cfg.collect.action_size)).to( - self._device - ) + action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device) else: #state = default_collate(list(state.values())) latent = to_device(default_collate(list(zip(*state))[0]), self._device) action = to_device(default_collate(list(zip(*state))[1]), self._device) - if len(action.shape)==1: + if len(action.shape) == 1: action = action.unsqueeze(-1) if reset.any(): mask = 1 - reset @@ -314,14 +308,12 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict latent[key][i] *= mask[i] for i in 
range(len(action)): action[i] *= mask[i] - + data = data - 0.5 embed = world_model.encoder(data) - latent, _ = world_model.dynamics.obs_step( - latent, action, embed, self._cfg.collect.collect_dyn_sample - ) + latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample) feat = world_model.dynamics.get_feat(latent) - + actor = self._actor(feat) action = actor.mode() logprob = actor.log_prob(action) @@ -330,7 +322,7 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict state = (latent, action) output = {"action": action, "logprob": logprob, "state": state} - + if self._cuda: output = to_device(output, 'cpu') output = default_decollate(output) diff --git a/ding/policy/mbpolicy/utils.py b/ding/policy/mbpolicy/utils.py index 38c289811a..b17c36e47f 100644 --- a/ding/policy/mbpolicy/utils.py +++ b/ding/policy/mbpolicy/utils.py @@ -95,7 +95,6 @@ def compute_actor_loss( policy = actor(inp) actor_ent = policy.entropy() # Q-val for actor is not transformed using symlog - # target = torch.stack(target, dim=1) if cfg.reward_EMA: offset, scale = reward_ema(target) normed_target = (target - offset) / scale diff --git a/ding/torch_utils/network/dreamer.py b/ding/torch_utils/network/dreamer.py index 876f35f811..6f48ce085d 100644 --- a/ding/torch_utils/network/dreamer.py +++ b/ding/torch_utils/network/dreamer.py @@ -6,6 +6,8 @@ import torch.nn.functional as F from torch import distributions as torchd from ding.torch_utils import MLP +from ding.rl_utils import symlog, inv_symlog + class Conv2dSame(torch.nn.Conv2d): @@ -46,10 +48,16 @@ def forward(self, x): class DenseHead(nn.Module): + """ + Overview: + DenseHead Network for value head, reward head, and discount head of dreamerv3. + Interface: + ``__init__``, ``forward`` + """ def __init__( self, - inp_dim, # config.dyn_stoch * config.dyn_discrete + config.dyn_deter + inp_dim, shape, # (255,) layer_num, units, # 512 @@ -90,7 +98,7 @@ def __init__( self.std_layer = nn.Linear(self._units, np.prod(self._shape)) self.std_layer.apply(uniform_weight_init(outscale)) - def forward(self, features, dtype=None): + def forward(self, features): x = features out = self.mlp(x) # (batch, time, _units=512) mean = self.mean_layer(out) # (batch, time, 255) @@ -100,16 +108,22 @@ def forward(self, features, dtype=None): std = self._std if self._dist == "normal": return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape))) - if self._dist == "huber": + elif self._dist == "huber": return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape))) - if self._dist == "binary": + elif self._dist == "binary": return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape))) - if self._dist == "twohot_symlog": + elif self._dist == "twohot_symlog": return TwoHotDistSymlog(logits=mean, device=self._device) raise NotImplementedError(self._dist) class ActionHead(nn.Module): + """ + Overview: + ActionHead Network for action head of dreamerv3. 
+ Interface: + ``__init__``, ``forward`` + """ def __init__( self, @@ -158,7 +172,7 @@ def __init__( self._dist_layer = nn.Linear(self._units, self._size) self._dist_layer.apply(uniform_weight_init(outscale)) - def __call__(self, features, dtype=None): + def forward(self, features): x = features x = self._pre_layers(x) if self._dist == "tanh_normal": @@ -206,29 +220,20 @@ def __call__(self, features, dtype=None): else: raise NotImplementedError(self._dist) return dist - - -def symlog(x): - return torch.sign(x) * torch.log(torch.abs(x) + 1.0) - - -def symexp(x): - return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0) class SampleDist: + """ + Overview: + A kind of sample Dist for ActionHead of dreamerv3. + Interface: + ``__init__``, ``mean``, ``mode``, ``entropy`` + """ def __init__(self, dist, samples=100): self._dist = dist self._samples = samples - @property - def name(self): - return 'SampleDist' - - def __getattr__(self, name): - return getattr(self._dist, name) - def mean(self): samples = self._dist.sample(self._samples) return torch.mean(samples, 0) @@ -245,6 +250,12 @@ def entropy(self): class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): + """ + Overview: + A kind of onehot Dist for dreamerv3. + Interface: + ``__init__``, ``mode``, ``sample`` + """ def __init__(self, logits=None, probs=None, unimix_ratio=0.0): if logits is not None and unimix_ratio > 0.0: @@ -270,7 +281,13 @@ def sample(self, sample_shape=(), seed=None): return sample -class TwoHotDistSymlog(): +class TwoHotDistSymlog: + """ + Overview: + A kind of twohotsymlog Dist for dreamerv3. + Interface: + ``__init__``, ``mode``, ``mean``, ``log_prob``, ``log_prob_target`` + """ def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): self.logits = logits @@ -279,13 +296,12 @@ def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): self.width = (self.buckets[-1] - self.buckets[0]) / 255 def mean(self): - print("mean called") _mean = self.probs * self.buckets - return symexp(torch.sum(_mean, dim=-1, keepdim=True)) + return inv_symlog(torch.sum(_mean, dim=-1, keepdim=True)) def mode(self): _mode = self.probs * self.buckets - return symexp(torch.sum(_mode, dim=-1, keepdim=True)) + return inv_symlog(torch.sum(_mode, dim=-1, keepdim=True)) # Inside OneHotCategorical, log_prob is calculated using only max element in targets def log_prob(self, x): @@ -316,20 +332,26 @@ def log_prob_target(self, target): return (target * log_pred).sum(-1) -class SymlogDist(): +class SymlogDist: + """ + Overview: + A kind of Symlog Dist for dreamerv3. 
+ Interface: + ``__init__``, ``entropy``, ``mode``, ``mean``, ``log_prob`` + """ - def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): + def __init__(self, mode, dist='mse', aggregation='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): self._mode = mode self._dist = dist - self._agg = agg + self._aggregation = aggregation self._tol = tol self._dim_to_reduce = dim_to_reduce def mode(self): - return symexp(self._mode) + return inv_symlog(self._mode) def mean(self): - return symexp(self._mode) + return inv_symlog(self._mode) def log_prob(self, value): assert self._mode.shape == value.shape @@ -341,16 +363,22 @@ def log_prob(self, value): distance = torch.where(distance < self._tol, 0, distance) else: raise NotImplementedError(self._dist) - if self._agg == 'mean': + if self._aggregation == 'mean': loss = distance.mean(self._dim_to_reduce) - elif self._agg == 'sum': + elif self._aggregation == 'sum': loss = distance.sum(self._dim_to_reduce) else: - raise NotImplementedError(self._agg) + raise NotImplementedError(self._aggregation) return -loss class ContDist: + """ + Overview: + A kind of ordinary Dist for dreamerv3. + Interface: + ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob`` + """ def __init__(self, dist=None): super().__init__() @@ -374,6 +402,12 @@ def log_prob(self, x): class Bernoulli: + """ + Overview: + A kind of Bernoulli Dist for dreamerv3. + Interface: + ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob`` + """ def __init__(self, dist=None): super().__init__() @@ -402,6 +436,12 @@ def log_prob(self, x): class UnnormalizedHuber(torchd.normal.Normal): + """ + Overview: + A kind of UnnormalizedHuber Dist for dreamerv3. + Interface: + ``__init__``, ``mode``, ``log_prob`` + """ def __init__(self, loc, scale, threshold=1, **kwargs): super().__init__(loc, scale, **kwargs) @@ -415,6 +455,12 @@ def mode(self): class SafeTruncatedNormal(torchd.normal.Normal): + """ + Overview: + A kind of SafeTruncatedNormal Dist for dreamerv3. + Interface: + ``__init__``, ``sample`` + """ def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): super().__init__(loc, scale) @@ -434,6 +480,12 @@ def sample(self, sample_shape): class TanhBijector(torchd.Transform): + """ + Overview: + A kind of TanhBijector Dist for dreamerv3. 
+ Interface: + ``__init__``, ``_forward``, ``_inverse``, ``_forward_log_det_jacobian`` + """ def __init__(self, validate_args=False, name='tanh'): super().__init__() @@ -452,40 +504,44 @@ def _forward_log_det_jacobian(self, x): def static_scan(fn, inputs, start): - last = start # {logit:[batch_size, self._stoch, self._discrete], stoch:[batch_size, self._stoch, self._discrete], deter:[batch_size, self._deter]} + last = start # {logit, stoch, deter:[batch_size, self._deter]} indices = range(inputs[0].shape[0]) flag = True for index in indices: inp = lambda x: (_input[x] for _input in inputs) # inputs:(action:(time, batch, 6), embed:(time, batch, 4096)) last = fn(last, *inp(index)) # post, prior if flag: - if type(last) == type({}): + if isinstance(last, dict): outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()} else: outputs = [] for _last in last: - if type(_last) == type({}): + if isinstance(_last, dict): outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()}) else: outputs.append(_last.clone().unsqueeze(0)) flag = False else: - if type(last) == type({}): + if isinstance(last, dict): for key in last.keys(): outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0) else: for j in range(len(outputs)): - if type(last[j]) == type({}): + if isinstance(last[j], dict): for key in last[j].keys(): outputs[j][key] = torch.cat([outputs[j][key], last[j][key].unsqueeze(0)], dim=0) else: outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0) - if type(last) == type({}): + if isinstance(last, dict): outputs = [outputs] return outputs def weight_init(m): + """ + Overview: + weight_init for Linear, Conv2d, ConvTranspose2d, and LayerNorm. + """ if isinstance(m, nn.Linear): in_num = m.in_features out_num = m.out_features @@ -512,6 +568,10 @@ def weight_init(m): def uniform_weight_init(given_scale): + """ + Overview: + weight_init for Linear and LayerNorm. 
+ """ def f(m): if isinstance(m, nn.Linear): diff --git a/ding/torch_utils/network/tests/test_dreamer.py b/ding/torch_utils/network/tests/test_dreamer.py index 77ca133cc0..accfb1d8c5 100644 --- a/ding/torch_utils/network/tests/test_dreamer.py +++ b/ding/torch_utils/network/tests/test_dreamer.py @@ -3,8 +3,8 @@ import torch from torch import distributions as torchd from itertools import product -from ding.torch_utils.network.dreamer import DenseHead, SampleDist, OneHotDist, TwoHotDistSymlog, SymlogDist, ContDist, Bernoulli, UnnormalizedHuber, weight_init, uniform_weight_init - +from ding.torch_utils.network.dreamer import DenseHead, SampleDist, OneHotDist, TwoHotDistSymlog, \ + SymlogDist, ContDist, Bernoulli, UnnormalizedHuber, weight_init, uniform_weight_init # arguments shape = [255, (255, ), ()] @@ -29,7 +29,7 @@ def test_DenseHead(shape, dist): B, time = 16, 64 mean = torch.randn(B, time, 255) std = 1.0 -a = torch.randn(B, time, 1) # or torch.randn(B, time, 255) +a = torch.randn(B, time, 1) # or torch.randn(B, time, 255) sample_shape = torch.Size([]) diff --git a/ding/worker/replay_buffer/naive_buffer.py b/ding/worker/replay_buffer/naive_buffer.py index a40a456e8b..db06b2c6b8 100644 --- a/ding/worker/replay_buffer/naive_buffer.py +++ b/ding/worker/replay_buffer/naive_buffer.py @@ -478,12 +478,14 @@ class SequenceReplayBuffer(NaiveReplayBuffer): replay_buffer_size, push_count """ - def sample(self, - batch: int, - sequence: int, - cur_learner_iter: int, - sample_range: slice = None, - replace: bool = False) -> Optional[list]: + def sample( + self, + batch: int, + sequence: int, + cur_learner_iter: int, + sample_range: slice = None, + replace: bool = False + ) -> Optional[list]: """ Overview: Sample data with length ``size``. @@ -531,12 +533,12 @@ def _get_indices(self, size: int, sequence: int, sample_range: slice = None, rep if sample_range is None: while batch < size: episode = np.random.choice(episodes) - length = tail - episode*500 if tail - episode*500 < 500 else 500 + length = tail - episode * 500 if tail - episode * 500 < 500 else 500 available = length - sequence if available < 1: continue - list(range(episode*500, episode*500 + available)) - indices.append(np.random.randint(episode*500, episode*500 + available + 1)) + list(range(episode * 500, episode * 500 + available)) + indices.append(np.random.randint(episode * 500, episode * 500 + available + 1)) batch += 1 else: raise NotImplemented("sample_range is not implemented in this version") @@ -556,8 +558,8 @@ def _sample_with_indices(self, indices: List[int], sequence: int, cur_learner_it for idx in indices: assert self._data[idx] is not None, idx if self._deepcopy: - copy_data = copy.deepcopy(self._data[idx:idx+sequence]) + copy_data = copy.deepcopy(self._data[idx:idx + sequence]) else: - copy_data = self._data[idx:idx+sequence] + copy_data = self._data[idx:idx + sequence] data.append(copy_data) - return data \ No newline at end of file + return data diff --git a/ding/world_model/model/networks.py b/ding/world_model/model/networks.py index c753b16351..091fa4f827 100644 --- a/ding/world_model/model/networks.py +++ b/ding/world_model/model/networks.py @@ -6,7 +6,8 @@ import torch.nn.functional as F from torch import distributions as torchd -from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, OneHotDist, ContDist, SymlogDist, static_scan, DreamerLayerNorm +from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, static_scan, \ + OneHotDist, ContDist, SymlogDist, DreamerLayerNorm 
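As a side note on the ``SymlogDist`` changes above: the patch only shows the call sites of ``inv_symlog``, so the snippet below is a minimal sketch of the DreamerV3-style symlog pair it presumably relies on, not the repository's actual implementation.

import torch


def symlog(x: torch.Tensor) -> torch.Tensor:
    # Compress large magnitudes while staying roughly linear around zero.
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)


def inv_symlog(x: torch.Tensor) -> torch.Tensor:
    # Inverse transform; SymlogDist.mode()/mean() would apply this to map
    # network predictions back to the original target scale.
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)


# Round-trip sanity check on an arbitrary tensor.
x = torch.randn(4, 3) * 10
assert torch.allclose(inv_symlog(symlog(x)), x, atol=1e-4)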
class RSSM(nn.Module):
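    # Illustrative note (an assumption based on the ``static_scan`` comments above,
    # not part of this patch): the recurrent state handled by this class is a dict
    # with a deterministic part ``deter`` of shape [batch_size, deter_dim] and a
    # stochastic part ``stoch`` (plus ``logit`` in the discrete case).
    # ``static_scan`` unrolls a per-step function, e.g. a hypothetical
    # ``obs_step(prev_state, action_t, embed_t)``, over the leading time axis of
    # time-major ``action``/``embed`` tensors and stacks each step's outputs along
    # a new time dimension to produce the posterior/prior sequences.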