
feature(cy): add dreamerV3 + MiniGrid code #725

Merged: 12 commits, Feb 1, 2024
4 changes: 1 addition & 3 deletions ding/model/template/vac.py
@@ -366,7 +366,6 @@ class DREAMERVAC(nn.Module):

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType, EasyDict],
dyn_stoch=32,
dyn_deter=512,
@@ -391,9 +390,8 @@ def __init__(
- action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
"""
super(DREAMERVAC, self).__init__()
obs_shape: int = squeeze(obs_shape)
action_shape = squeeze(action_shape)
self.obs_shape, self.action_shape = obs_shape, action_shape
self.action_shape = action_shape

if dyn_discrete:
feat_size = dyn_stoch * dyn_discrete + dyn_deter
18 changes: 15 additions & 3 deletions ding/policy/mbpolicy/dreamer.py
@@ -235,7 +235,8 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N
for i in range(len(action)):
action[i] *= mask[i]

data = data - 0.5
if type(world_model.state_size) != int and len(world_model.state_size) == 3:
data = data - 0.5
embed = world_model.encoder(data)
latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
feat = world_model.dynamics.get_feat(latent)
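The guard added above appears to restrict the `- 0.5` centering to 3-dimensional (C, H, W) image observations; flat vector observations (an int or 1-D state_size, presumably the MiniGrid case this PR adds) are passed to the encoder without the shift.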
@@ -247,11 +248,16 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N
action = action.detach()

state = (latent, action)
if world_model.action_type == 'discrete':
Review comment (Member): add an assertion at the first occurrence of this string comparison. (A sketch of this suggestion follows at the end of this hunk.)

action = torch.where(action == 1)[1]
output = {"action": action, "logprob": logprob, "state": state}

if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
if world_model.action_type == 'discrete':
for l in range(len(output)):
output[l]['action'] = output[l]['action'].squeeze(0)
return {i: d for i, d in zip(data_id, output)}
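A minimal sketch of the assertion the reviewer asks for above; the placement (where action_type is first assigned, for example in DREAMERWorldModel.__init__ later in this diff) and the message are illustrative, not part of the PR:

```python
# Sketch of the suggested check, as it might appear where action_type is first set;
# failing fast here means later comparisons like `action_type == 'discrete'`
# cannot silently fall through on an unexpected value.
self.action_type = self._cfg.action_type
assert self.action_type in ('discrete', 'continuous'), \
    "action_type must be 'discrete' or 'continuous', got {!r}".format(self.action_type)
```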

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
@@ -272,7 +278,7 @@ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple
# TODO(zp) random_collect just have action
#'logprob': model_output['logprob'],
'reward': timestep.reward,
'discount': timestep.info['discount'],
'discount': 1. - timestep.done, # timestep.info['discount'],
'done': timestep.done,
}
return transition
@@ -309,7 +315,8 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict
for i in range(len(action)):
action[i] *= mask[i]

data = data - 0.5
if type(world_model.state_size) != int and len(world_model.state_size) == 3:
data = data - 0.5
embed = world_model.encoder(data)
latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
feat = world_model.dynamics.get_feat(latent)
@@ -321,11 +328,16 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict
action = action.detach()

state = (latent, action)
if world_model.action_type == 'discrete':
action = torch.where(action == 1)[1]
output = {"action": action, "logprob": logprob, "state": state}

if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
if world_model.action_type == 'discrete':
for l in range(len(output)):
output[l]['action'] = output[l]['action'].squeeze(0)
return {i: d for i, d in zip(data_id, output)}

def _monitor_vars_learn(self) -> List[str]:
6 changes: 3 additions & 3 deletions ding/torch_utils/network/dreamer.py
@@ -178,7 +178,7 @@ def forward(self, features):
elif self._dist == "binary":
return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)))
elif self._dist == "twohot_symlog":
return TwoHotDistSymlog(logits=mean, device=self._device)
return TwoHotDistSymlog(logits=mean, low=-1., high=1., device=self._device)
raise NotImplementedError(self._dist)


@@ -475,8 +475,8 @@ def log_prob(self, x):
above = torch.clip(above, 0, len(self.buckets) - 1)
equal = (below == above)

dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
dist_to_below = torch.where(equal, torch.tensor(1).to(x), torch.abs(self.buckets[below] - x))
dist_to_above = torch.where(equal, torch.tensor(1).to(x), torch.abs(self.buckets[above] - x))
total = dist_to_below + dist_to_above
weight_below = dist_to_above / total
weight_above = dist_to_below / total
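A note on the torch.where change above: torch.tensor(1).to(x) builds the constant with the same dtype and device as x, which avoids the dtype/device mismatches torch.where can raise when a Python int is mixed with a float (possibly CUDA) tensor on some PyTorch versions. A minimal illustration with arbitrary values:

```python
import torch

x = torch.tensor([0.3, -1.2, 0.0])            # stand-in for the symlog-encoded target
equal = torch.tensor([False, False, True])    # stand-in for the `below == above` mask
one = torch.tensor(1).to(x)                   # scalar with x's dtype and device
dist = torch.where(equal, one, torch.abs(x))  # tensor([0.3000, 1.2000, 1.0000])
```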
91 changes: 67 additions & 24 deletions ding/world_model/dreamer.py
@@ -5,7 +5,7 @@

from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts
from ding.utils.data import default_collate
from ding.model import ConvEncoder
from ding.model import ConvEncoder, FCEncoder
from ding.world_model.base_world_model import WorldModel
from ding.world_model.model.networks import RSSM, ConvDecoder
from ding.torch_utils import to_device
@@ -37,6 +37,7 @@ class DREAMERWorldModel(WorldModel, nn.Module):
norm='LayerNorm',
grad_heads=['image', 'reward', 'discount'],
units=512,
image_dec_layers=2,
reward_layers=2,
discount_layers=2,
value_layers=2,
@@ -72,21 +73,26 @@ def __init__(self, cfg, env, tb_logger):
self._cfg.norm = nn.modules.normalization.LayerNorm # nn.LayerNorm
self.state_size = self._cfg.state_size
self.action_size = self._cfg.action_size
self.action_type = self._cfg.action_type
self.reward_size = self._cfg.reward_size
self.hidden_size = self._cfg.hidden_size
self.batch_size = self._cfg.batch_size
if type(self.state_size) == int or len(self.state_size) == 1:
self.encoder = FCEncoder(self.state_size, self._cfg.encoder_hidden_size_list, activation=torch.nn.SiLU())
self.embed_size = self._cfg.encoder_hidden_size_list[-1]
elif len(self.state_size) == 3:
self.encoder = ConvEncoder(
self.state_size,
hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128?
activation=torch.nn.SiLU(),
kernel_size=self._cfg.encoder_kernels,
layer_norm=True
)
self.embed_size = (
(self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth *
2 ** (len(self._cfg.encoder_kernels) - 1)
)
Review comment (Member): add else branch for better error hints. (A sketch of this suggestion follows at the end of this hunk.)


self.encoder = ConvEncoder(
self.state_size,
hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128?
activation=torch.nn.SiLU(),
kernel_size=self._cfg.encoder_kernels,
layer_norm=True
)
self.embed_size = (
(self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth *
2 ** (len(self._cfg.encoder_kernels) - 1)
)
self.dynamics = RSSM(
self._cfg.dyn_stoch,
self._cfg.dyn_deter,
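A minimal sketch of the else branch the reviewer asks for above; the exact message is illustrative, the point is simply to fail loudly on observation shapes that neither branch supports:

```python
# Hypothetical stand-in for self.state_size, which the diff reads from the config.
state_size = (3, 64, 64)

if type(state_size) == int or len(state_size) == 1:
    pass  # FCEncoder branch, as in the diff
elif len(state_size) == 3:
    pass  # ConvEncoder branch, as in the diff
else:
    raise ValueError(
        "unsupported state_size {}: expected an int, a 1-D shape, or a 3-D "
        "(C, H, W) image shape".format(state_size)
    )
```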
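For orientation, the embed_size expression in the convolutional branch works out as follows under one plausible configuration (the encoder kernel list and cnn_depth are not shown in this diff, so the concrete numbers are assumptions): with state_size=(3, 64, 64), four encoder kernels and cnn_depth=32, it gives (64 // 2**4) ** 2 * 32 * 2**3 = 16 * 32 * 8 = 4096, which matches the last entry of hidden_size_list=[32, 64, 128, 256, 4096] passed to ConvEncoder.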
@@ -113,14 +119,28 @@ def __init__(self, cfg, env, tb_logger):
feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter
else:
feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter
self.heads["image"] = ConvDecoder(
feat_size, # pytorch version
self._cfg.cnn_depth,
self._cfg.act,
self._cfg.norm,
self.state_size,
self._cfg.decoder_kernels,
)

if type(self.state_size) == int or len(self.state_size) == 1:
self.heads['image'] = DenseHead(
feat_size,
(self.state_size, ),
self._cfg.image_dec_layers,
self._cfg.units,
'SiLU', # self._cfg.act
'LN', # self._cfg.norm
dist='binary',
outscale=0.0,
device=self._cfg.device,
)
elif len(self.state_size) == 3:
self.heads["image"] = ConvDecoder(
feat_size, # pytorch version
self._cfg.cnn_depth,
self._cfg.act,
self._cfg.norm,
self.state_size,
self._cfg.decoder_kernels,
)
self.heads["reward"] = DenseHead(
feat_size, # dyn_stoch * dyn_discrete + dyn_deter
(255, ),
@@ -172,9 +192,32 @@ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
data = {k: torch.stack(data[k], dim=1) for k in data} # -> {dict_key: Tensor([B, T, any_dims])}

data['discount'] = data.get('discount', 1.0 - data['done'].float())
data['discount'] *= 0.997
# data['discount'] *= 0.997
data['weight'] = data.get('weight', None)
data['image'] = data['obs'] - 0.5
if type(self.state_size) != int and len(self.state_size) == 3:
data['image'] = data['obs'] - 0.5
else:
data['image'] = data['obs']
if self.action_type == 'continuous':
data['action'] *= (1.0 / torch.clip(torch.abs(data['action']), min=1.0))
else:

def make_one_hot(x, num_classes):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [bs, 1, *]
num_classes: An int of number of class
Returns:
A tensor of shape [bs, num_classes, *]
"""
x = x.to(torch.int64)
shape = (*tuple(x.shape), num_classes)
x = x.unsqueeze(-1)
res = torch.zeros(shape).to(x)
res = res.scatter_(-1, x, 1)
return res.float()

data['action'] = make_one_hot(data['action'], self.action_size)
data = to_device(data, self._cfg.device)
if len(data['reward'].shape) == 2:
data['reward'] = data['reward'].unsqueeze(-1)
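As a quick sanity check of the make_one_hot helper added above (illustrative only; in the diff the helper is defined locally inside train(), so this assumes it is accessible for the example):

```python
import torch

actions = torch.tensor([2, 0, 1])               # a batch of discrete action indices
one_hot = make_one_hot(actions, num_classes=4)
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.]])
```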
@@ -185,9 +228,9 @@ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):

self.requires_grad_(requires_grad=True)

image = data['image'].reshape([-1] + list(data['image'].shape[-3:]))
image = data['image'].reshape([-1] + list(data['image'].shape[2:]))
embed = self.encoder(image)
embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]])
embed = embed.reshape(list(data['image'].shape[:2]) + [embed.shape[-1]])

post, prior = self.dynamics.observe(embed, data["action"])
kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
14 changes: 8 additions & 6 deletions ding/world_model/model/networks.py
@@ -1,11 +1,12 @@
import math
import numpy as np
from typing import Optional, Dict, Union, List

import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd

from ding.utils import SequenceType
from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, static_scan, \
OneHotDist, ContDist, SymlogDist, DreamerLayerNorm

@@ -179,7 +180,7 @@ def get_dist(self, state, dtype=None):
def obs_step(self, prev_state, prev_action, embed, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
# prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
@@ -202,7 +203,7 @@ def obs_step(self, prev_state, prev_action, embed, sample=True):
# this is used for making future image
def img_step(self, prev_state, prev_action, embed=None, sample=True):
# (batch, stoch, discrete_num)
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
# prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prev_stoch = prev_state["stoch"]
if self._discrete:
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@@ -282,8 +283,9 @@ def kl_loss(self, post, prior, forward, free, lscale, rscale):
dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
dist(rhs) if self._discrete else dist(rhs)._dist,
)
loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
# free bits
loss_lhs = torch.mean(torch.clip(value_lhs, min=free))
loss_rhs = torch.mean(torch.clip(value_rhs, min=free))
loss = lscale * loss_lhs + rscale * loss_rhs

return loss, value, loss_lhs, loss_rhs
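The reordering above changes where the free-bits floor is applied: the old line computes max(mean(KL), free), while the new line computes mean(max(KL, free)), so elements whose KL already exceeds the threshold keep contributing gradient even when the batch mean falls below it. A tiny numeric illustration with made-up values:

```python
import torch

kl = torch.tensor([0.2, 3.0])               # per-element KL values (made up)
free = 1.0
old = torch.clip(torch.mean(kl), min=free)  # floor the mean:       max(1.6, 1.0) = 1.6
new = torch.mean(torch.clip(kl, min=free))  # floor, then average:  (1.0 + 3.0) / 2 = 2.0
```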
@@ -357,7 +359,7 @@ def calc_same_pad(self, k, s, d):
outpad = pad * 2 - val
return pad, outpad

def __call__(self, features, dtype=None):
def __call__(self, features):
x = self._linear_layer(features) # feature:[batch, time, stoch*discrete + deter]
x = x.reshape([-1, 4, 4, self._embed_size // 16])
x = x.permute(0, 3, 1, 2)
@@ -28,7 +28,6 @@
# it is better to put random_collect_size in policy.other
random_collect_size=2500,
model=dict(
obs_shape=(3, 64, 64),
action_shape=6,
actor_dist='normal',
),
@@ -61,6 +60,7 @@
model=dict(
state_size=(3, 64, 64), # has to be specified
action_size=6, # has to be specified
action_type='continuous',
reward_size=1,
batch_size=16,
),
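The fragment above shows the continuous-action configuration; a discrete-action (for example MiniGrid-style) world-model block enabled by this PR might look like the sketch below, where every value is a placeholder rather than something taken from this PR:

```python
model = dict(
    state_size=1344,          # placeholder: size of a flattened 1-D observation
    action_size=7,            # placeholder: MiniGrid's 7 primitive actions
    action_type='discrete',
    reward_size=1,
    batch_size=16,
)
```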