From 304c7ee8c798dc64b604465f8f4cbb13cc071da0 Mon Sep 17 00:00:00 2001
From: karroyan
Date: Tue, 14 Jun 2022 15:39:28 +0800
Subject: [PATCH 1/6] fix import path error in lunarlander

---
 dizoo/box2d/lunarlander/config/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dizoo/box2d/lunarlander/config/__init__.py b/dizoo/box2d/lunarlander/config/__init__.py
index 687e1c5269..200970a208 100644
--- a/dizoo/box2d/lunarlander/config/__init__.py
+++ b/dizoo/box2d/lunarlander/config/__init__.py
@@ -1,5 +1,5 @@
 from .lunarlander_dqn_config import lunarlander_dqn_config, lunarlander_dqn_create_config
-from .lunarlander_dqn_gail_config import lunarlander_dqn_gail_create_config, lunarlander_dqn_gail_config
+from .lunarlander_gail_dqn_config import lunarlander_dqn_gail_create_config, lunarlander_dqn_gail_config
 from .lunarlander_dqfd_config import lunarlander_dqfd_config, lunarlander_dqfd_create_config
 from .lunarlander_qrdqn_config import lunarlander_qrdqn_config, lunarlander_qrdqn_create_config
 from .lunarlander_trex_dqn_config import lunarlander_trex_dqn_config, lunarlander_trex_dqn_create_config

From fba4401ee9a76ea09008ea9774f9d42089e1b1e3 Mon Sep 17 00:00:00 2001
From: karroyan
Date: Thu, 2 Feb 2023 11:43:52 +0800
Subject: [PATCH 2/6] add procedure cloning model
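
Procedure Cloning (PC) imitates not only expert actions but the
computation that produced them: a ConvEncoder embeds the current
observation and the goal observation, an FCEncoder embeds the action
sequence, and a transformer stack over the concatenated tokens predicts
the goal embedding and the action sequence.

A rough usage sketch (sizes mirror the unit test added in this patch;
the (64, 64, 3) observation shape and action_dim=9 are example values,
not requirements):

    import torch
    from ding.model.template import ProcedureCloning

    B, T = 4, 15
    model = ProcedureCloning(obs_shape=(64, 64, 3), action_dim=9)
    states = torch.randn(B, 64, 64, 3)   # current observation
    goals = torch.randn(B, 64, 64, 3)    # goal observation
    actions = torch.randn(B, T, 9)       # length-T action sequence
    goal_preds, action_preds = model(states, goals, actions)
    # goal_preds: (B, 256), i.e. (B, cnn_hidden_list[-1])
    # action_preds: (B, T + 1, 9)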
---
 ding/model/template/__init__.py              |  1 +
 ding/model/template/procedure_cloning.py     | 78 +++++++++++++++++++
 .../template/tests/test_procedure_cloning.py | 31 ++++++++
 3 files changed, 110 insertions(+)
 create mode 100644 ding/model/template/procedure_cloning.py
 create mode 100644 ding/model/template/tests/test_procedure_cloning.py

diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py
index d4907a510d..11f7aa35b5 100644
--- a/ding/model/template/__init__.py
+++ b/ding/model/template/__init__.py
@@ -22,3 +22,4 @@
 from .madqn import MADQN
 from .vae import VanillaVAE
 from .decision_transformer import DecisionTransformer
+from .procedure_cloning import ProcedureCloning
diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py
new file mode 100644
index 0000000000..996c8e04cc
--- /dev/null
+++ b/ding/model/template/procedure_cloning.py
@@ -0,0 +1,78 @@
+from typing import Optional
+from ding.utils import MODEL_REGISTRY
+from typing import Tuple
+from ding.torch_utils.network.transformer import Attention
+from ..common import FCEncoder, ConvEncoder
+from ding.torch_utils.network.nn_module import fc_block, build_normalization
+from ding.utils import SequenceType
+import torch
+import torch.nn as nn
+
+
+@MODEL_REGISTRY.register('pc')
+class ProcedureCloning(nn.Module):
+
+    def __init__(
+        self,
+        obs_shape: SequenceType,
+        action_dim: int,
+        cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256],
+        cnn_activation: Optional[nn.Module] = nn.ReLU(),
+        cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3],
+        cnn_stride: SequenceType = [1, 1, 1, 1, 1],
+        cnn_padding: Optional[SequenceType] = ['same', 'same', 'same', 'same', 'same'],
+        mlp_hidden_list: SequenceType = [256, 256],
+        mlp_activation: Optional[nn.Module] = nn.ReLU(),
+        att_heads: int = 8,
+        att_hidden: int = 128,
+        n_att: int = 4,
+        n_feedforward: int = 2,
+        feedforward_hidden: int = 256,
+        drop_p: float = 0.5
+    ) -> None:
+        super().__init__()
+
+        # Conv Encoder
+        self.embed_state = ConvEncoder(
+            obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding
+        )
+        self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation)
+
+        self.cnn_hidden_list = cnn_hidden_list
+
+        assert cnn_hidden_list[-1] == mlp_hidden_list[-1]
+        layers = []
+        for i in range(n_att):
+            layers.append(Attention(cnn_hidden_list[-1], att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+            layers.append(build_normalization('LN')(att_hidden))
+        for j in range(n_feedforward):
+            if j == 0:
+                layers.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
+            else:
+                layers.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
+        self.layernorm2 = build_normalization('LN')(feedforward_hidden)
+        self.transformer = nn.Sequential(*layers)
+
+        self.predict_goal = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
+        self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim)
+
+    def forward(self, states: torch.Tensor, goals: torch.Tensor,
+                actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        B, T, _ = actions.shape
+
+        # shape: (B, h_dim)
+        state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1])
+        goal_embeddings = self.embed_state(goals).reshape(B, 1, self.cnn_hidden_list[-1])
+        # shape: (B, context_len, h_dim)
+        actions_embeddings = self.embed_action(actions)
+
+        h = torch.cat((state_embeddings, goal_embeddings, actions_embeddings), dim=1)
+        print(h.shape)
+        h = self.transformer(h)
+        h = h.reshape(B, T + 2, self.cnn_hidden_list[-1])
+
+        goal_preds = self.predict_goal(h[:, 0, :])
+        action_preds = self.predict_action(h[:, 1:, :])
+
+        return goal_preds, action_preds
diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py
new file mode 100644
index 0000000000..c3d3c896db
--- /dev/null
+++ b/ding/model/template/tests/test_procedure_cloning.py
@@ -0,0 +1,31 @@
+import torch
+import numpy as np
+import pytest
+from itertools import product
+
+from ding.model.template import ProcedureCloning
+from ding.torch_utils import is_differentiable
+from ding.utils import squeeze
+
+B = 4
+T = 15
+obs_shape = [(64, 64, 3)]
+action_dim = [9]
+obs_embeddings = 256
+args = list(product(*[obs_shape, action_dim]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('obs_shape, action_dim', args)
+class TestProcedureCloning:
+
+    def test_procedure_cloning(self, obs_shape, action_dim):
+        inputs = {'states': torch.randn(B, *obs_shape), 'goals': torch.randn(B, *obs_shape),\
+            'actions': torch.randn(B, T, action_dim)}
+        model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim)
+
+        print(model)
+
+        goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
+        assert goal_preds.shape == (B, obs_embeddings)
+        assert action_preds.shape == (B, T + 1, action_dim)

From d523f4a85996dbdc8d9b458d46531dd37573d3b4 Mon Sep 17 00:00:00 2001
From: karroyan
Date: Thu, 2 Feb 2023 12:40:44 +0800
Subject: [PATCH 3/6] modify attention and feedforward network order
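
In the previous commit every Attention layer was built with
cnn_hidden_list[-1] (256 by default) as its input width, but from the
second layer on the running feature width is att_hidden (128 by
default), so stacking n_att > 1 layers could not line up. Keep the
encoder width only for the first Attention layer and build the rest
from att_hidden; also rename the feedforward loop index and move the
trailing LayerNorm assignment into the loop.

A quick shape check of the corrected first layer (sizes follow the
model defaults, with 17 = T + 2 tokens; illustrative only):

    import torch
    import torch.nn as nn
    from ding.torch_utils.network.transformer import Attention

    att = Attention(256, 128, 128, 8, nn.Dropout(0.5))
    out = att(torch.randn(4, 17, 256))   # (batch, tokens, encoder width)
    assert out.shape == (4, 17, 128)     # projected down to att_hidden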
---
 ding/model/template/procedure_cloning.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py
index 996c8e04cc..5c1cf8c149 100644
--- a/ding/model/template/procedure_cloning.py
+++ b/ding/model/template/procedure_cloning.py
@@ -43,14 +43,17 @@ def __init__(
         assert cnn_hidden_list[-1] == mlp_hidden_list[-1]
         layers = []
         for i in range(n_att):
-            layers.append(Attention(cnn_hidden_list[-1], att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+            if i == 0:
+                layers.append(Attention(cnn_hidden_list[-1], att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+            else:
+                layers.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
             layers.append(build_normalization('LN')(att_hidden))
-        for j in range(n_feedforward):
-            if j == 0:
-                layers.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
-            else:
-                layers.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
-        self.layernorm2 = build_normalization('LN')(feedforward_hidden)
+        for i in range(n_feedforward):
+            if i == 0:
+                layers.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
+            else:
+                layers.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
+            self.layernorm2 = build_normalization('LN')(feedforward_hidden)
         self.transformer = nn.Sequential(*layers)
 
         self.predict_goal = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])

From 1d83d7cfa88405c0c1bfb6c66c56671c78fcfeee Mon Sep 17 00:00:00 2001
From: karroyan
Date: Mon, 6 Feb 2023 14:36:49 +0800
Subject: [PATCH 4/6] add causal mask
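
Wrap the attention / feedforward stack in a Block module so that a
lower-triangular (causal) mask can be passed to every Attention call:
position i may only attend to positions j <= i, so an action token sees
the state token, the goal token and earlier actions, but never future
ones. The mask is precomputed once in Block.__init__ (shown here with
max_T shortened to 5 for display; the default max_T = 17 covers one
state token, one goal token and T = 15 action tokens):

    import torch

    max_T = 5
    mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T)
    # mask[0, 0]:
    # [[1, 0, 0, 0, 0],
    #  [1, 1, 0, 0, 0],
    #  [1, 1, 1, 0, 0],
    #  [1, 1, 1, 1, 0],
    #  [1, 1, 1, 1, 1]]

The new augment flag is stored on the model but not used yet.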
---
 ding/model/template/procedure_cloning.py | 43 ++++++++++++++++++++++--
 1 file changed, 41 insertions(+), 2 deletions(-)

diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py
index 5c1cf8c149..f987c4dc22 100644
--- a/ding/model/template/procedure_cloning.py
+++ b/ding/model/template/procedure_cloning.py
@@ -9,6 +9,39 @@
 import torch.nn as nn
 
 
+class Block(nn.Module):
+
+    def __init__(self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int, \
+            feedforward_hidden: int, n_feedforward: int) -> None:
+        super().__init__()
+        self.n_att = n_att
+        self.n_feedforward = n_feedforward
+        self.attention_layer = []
+
+        self.norm_layer = [nn.LayerNorm(att_hidden)] * n_att
+        self.attention_layer.append(Attention(cnn_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+        for i in range(n_att - 1):
+            self.attention_layer.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+
+        self.att_drop = nn.Dropout(drop_p)
+
+        self.fc_blocks = []
+        self.fc_blocks.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
+        for i in range(n_feedforward - 1):
+            self.fc_blocks.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
+        self.norm_layer.extend([nn.LayerNorm(feedforward_hidden)] * n_feedforward)
+        self.mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T)
+
+    def forward(self, x: torch.Tensor):
+        for i in range(self.n_att):
+            x = self.att_drop(self.attention_layer[i](x, self.mask))
+            x = self.norm_layer[i](x)
+        for i in range(self.n_feedforward):
+            x = self.fc_blocks[i](x)
+            x = self.norm_layer[i + self.n_att](x)
+        return x
+
+
 @MODEL_REGISTRY.register('pc')
 class ProcedureCloning(nn.Module):
 
@@ -28,7 +61,9 @@ def __init__(
         n_att: int = 4,
         n_feedforward: int = 2,
         feedforward_hidden: int = 256,
-        drop_p: float = 0.5
+        drop_p: float = 0.5,
+        augment: bool = True,
+        max_T: int = 17
     ) -> None:
         super().__init__()
 
@@ -39,6 +74,7 @@ def __init__(
         self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation)
 
         self.cnn_hidden_list = cnn_hidden_list
+        self.augment = augment
 
         assert cnn_hidden_list[-1] == mlp_hidden_list[-1]
         layers = []
@@ -54,7 +90,10 @@ def __init__(
             else:
                 layers.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
             self.layernorm2 = build_normalization('LN')(feedforward_hidden)
-        self.transformer = nn.Sequential(*layers)
+
+        self.transformer = Block(
+            cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward
+        )
 
         self.predict_goal = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
         self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim)

From d9c65661efad7c5c42b347d5eef0655d84ea706a Mon Sep 17 00:00:00 2001
From: karroyan
Date: Wed, 8 Feb 2023 16:15:17 +0800
Subject: [PATCH 5/6] polish
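
Merge the scattered typing / ding imports into a single group and drop
the debug print from forward(). For reference, the tensor it printed is
the token sequence fed to the transformer:

    # token layout after the concatenation in forward():
    #   h = [state_token, goal_token, action_token_1, ..., action_token_T]
    #   h.shape == (B, T + 2, cnn_hidden_list[-1])

The test file import order is also adjusted for style.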
---
 ding/model/template/procedure_cloning.py            | 13 +++++--------
 ding/model/template/tests/test_procedure_cloning.py |  2 +-
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py
index f987c4dc22..b023ec615f 100644
--- a/ding/model/template/procedure_cloning.py
+++ b/ding/model/template/procedure_cloning.py
@@ -1,12 +1,10 @@
-from typing import Optional
-from ding.utils import MODEL_REGISTRY
-from typing import Tuple
-from ding.torch_utils.network.transformer import Attention
-from ..common import FCEncoder, ConvEncoder
-from ding.torch_utils.network.nn_module import fc_block, build_normalization
-from ding.utils import SequenceType
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn
+from ding.utils import MODEL_REGISTRY, SequenceType
+from ding.torch_utils.network.transformer import Attention
+from ding.torch_utils.network.nn_module import fc_block, build_normalization
+from ..common import FCEncoder, ConvEncoder
 
 
 class Block(nn.Module):
@@ -110,7 +108,6 @@ def forward(self, states: torch.Tensor, goals: torch.Tensor,
         actions_embeddings = self.embed_action(actions)
 
         h = torch.cat((state_embeddings, goal_embeddings, actions_embeddings), dim=1)
-        print(h.shape)
         h = self.transformer(h)
         h = h.reshape(B, T + 2, self.cnn_hidden_list[-1])
 
diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py
index c3d3c896db..c60057a857 100644
--- a/ding/model/template/tests/test_procedure_cloning.py
+++ b/ding/model/template/tests/test_procedure_cloning.py
@@ -1,6 +1,6 @@
 import torch
-import numpy as np
 import pytest
+import numpy as np
 from itertools import product
 
 from ding.model.template import ProcedureCloning

From d1df66c60d8d5294f7b6cab3173a5d864d3b2e4e Mon Sep 17 00:00:00 2001
From: karroyan
Date: Wed, 8 Feb 2023 16:32:22 +0800
Subject: [PATCH 6/6] polish style

---
 ding/model/template/procedure_cloning.py            | 6 ++++--
 ding/model/template/tests/test_procedure_cloning.py | 7 +++++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py
index b023ec615f..a86e813933 100644
--- a/ding/model/template/procedure_cloning.py
+++ b/ding/model/template/procedure_cloning.py
@@ -9,8 +9,10 @@
 
 class Block(nn.Module):
 
-    def __init__(self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int, \
-            feedforward_hidden: int, n_feedforward: int) -> None:
+    def __init__(
+        self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int,
+        feedforward_hidden: int, n_feedforward: int
+    ) -> None:
         super().__init__()
         self.n_att = n_att
         self.n_feedforward = n_feedforward
diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py
index c60057a857..e169ec2cee 100644
--- a/ding/model/template/tests/test_procedure_cloning.py
+++ b/ding/model/template/tests/test_procedure_cloning.py
@@ -20,8 +20,11 @@
 class TestProcedureCloning:
 
     def test_procedure_cloning(self, obs_shape, action_dim):
-        inputs = {'states': torch.randn(B, *obs_shape), 'goals': torch.randn(B, *obs_shape),\
-            'actions': torch.randn(B, T, action_dim)}
+        inputs = {
+            'states': torch.randn(B, *obs_shape),
+            'goals': torch.randn(B, *obs_shape),
+            'actions': torch.randn(B, T, action_dim)
+        }
         model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim)
 
         print(model)