
feature(lxy): add procedure cloning model #573

Merged
15 commits merged on Feb 8, 2023
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -22,3 +22,4 @@
from .madqn import MADQN
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloning
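
With this export in place, the new model is importable from the template package, which is how the unit test at the bottom of this diff references it:

from ding.model.template import ProcedureCloning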
120 changes: 120 additions & 0 deletions ding/model/template/procedure_cloning.py
@@ -0,0 +1,120 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn

from ding.utils import MODEL_REGISTRY, SequenceType
from ding.torch_utils.network.transformer import Attention
from ding.torch_utils.network.nn_module import fc_block
from ..common import FCEncoder, ConvEncoder


class Block(nn.Module):

    def __init__(
            self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int,
            feedforward_hidden: int, n_feedforward: int
    ) -> None:
        super().__init__()
        self.n_att = n_att
        self.n_feedforward = n_feedforward

        # Use ModuleList so that sub-module parameters are registered and moved with the model.
        self.attention_layer = nn.ModuleList()
        self.attention_layer.append(Attention(cnn_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
        for i in range(n_att - 1):
            self.attention_layer.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))

        # One LayerNorm per attention layer (distinct instances rather than a repeated reference).
        self.norm_layer = nn.ModuleList([nn.LayerNorm(att_hidden) for _ in range(n_att)])

        self.att_drop = nn.Dropout(drop_p)

        self.fc_blocks = nn.ModuleList()
        self.fc_blocks.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
        for i in range(n_feedforward - 1):
            self.fc_blocks.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
        self.norm_layer.extend([nn.LayerNorm(feedforward_hidden) for _ in range(n_feedforward)])

        # Causal (lower-triangular) attention mask, registered as a buffer so it follows the module's device.
        self.register_buffer('mask', torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Slice the causal mask to the actual sequence length so shorter sequences also work.
        mask = self.mask[..., :x.shape[1], :x.shape[1]]
        for i in range(self.n_att):
            x = self.att_drop(self.attention_layer[i](x, mask))
            x = self.norm_layer[i](x)
        for i in range(self.n_feedforward):
            x = self.fc_blocks[i](x)
            x = self.norm_layer[i + self.n_att](x)
        return x


@MODEL_REGISTRY.register('pc')
class ProcedureCloning(nn.Module):

def __init__(
self,
obs_shape: SequenceType,
action_dim: int,
cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256],
cnn_activation: Optional[nn.Module] = nn.ReLU(),
cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3],
cnn_stride: SequenceType = [1, 1, 1, 1, 1],
cnn_padding: Optional[SequenceType] = ['same', 'same', 'same', 'same', 'same'],
mlp_hidden_list: SequenceType = [256, 256],
mlp_activation: Optional[nn.Module] = nn.ReLU(),
att_heads: int = 8,
att_hidden: int = 128,
n_att: int = 4,
n_feedforward: int = 2,
feedforward_hidden: int = 256,
drop_p: float = 0.5,
augment: bool = True,
max_T: int = 17
) -> None:
super().__init__()

        # Conv encoder for image observations; states and goals share the same encoder.
        self.embed_state = ConvEncoder(
            obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding
        )
        # MLP encoder for the action sequence.
        self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation)

self.cnn_hidden_list = cnn_hidden_list
self.augment = augment

        assert cnn_hidden_list[-1] == mlp_hidden_list[-1]

        # Causal transformer backbone: n_att attention layers followed by n_feedforward fc blocks.
        self.transformer = Block(
            cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward
        )

self.predict_goal = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim)

    def forward(self, states: torch.Tensor, goals: torch.Tensor,
                actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        B, T, _ = actions.shape

        # shape: (B, 1, h_dim)
        state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1])
        goal_embeddings = self.embed_state(goals).reshape(B, 1, self.cnn_hidden_list[-1])
        # shape: (B, T, h_dim)
        actions_embeddings = self.embed_action(actions)

        # Token layout: [state, goal, action_1, ..., action_T], shape (B, T + 2, h_dim).
        h = torch.cat((state_embeddings, goal_embeddings, actions_embeddings), dim=1)
        h = self.transformer(h)

        # The state token predicts the goal embedding; the remaining T + 1 tokens predict actions.
        goal_preds = self.predict_goal(h[:, 0, :])
        action_preds = self.predict_action(h[:, 1:, :])

        return goal_preds, action_preds
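
For reference, a minimal forward-pass sketch with the default hyperparameters; the batch size, sequence length, observation shape and action dimension are illustrative values taken from the unit test below:

import torch
from ding.model.template import ProcedureCloning

B, T = 4, 15
obs_shape, action_dim = (64, 64, 3), 9

model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim)
states = torch.randn(B, *obs_shape)      # current observations
goals = torch.randn(B, *obs_shape)       # goal observations, encoded by the same ConvEncoder
actions = torch.randn(B, T, action_dim)  # action sequence

goal_preds, action_preds = model(states, goals, actions)
assert goal_preds.shape == (B, 256)                  # 256 == cnn_hidden_list[-1]
assert action_preds.shape == (B, T + 1, action_dim)  # goal token + T action tokens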
31 changes: 31 additions & 0 deletions ding/model/template/tests/test_procedure_cloning.py
@@ -0,0 +1,31 @@
import torch
import pytest
from itertools import product

from ding.model.template import ProcedureCloning

B = 4
T = 15
obs_shape = [(64, 64, 3)]
action_dim = [9]
obs_embeddings = 256
args = list(product(*[obs_shape, action_dim]))


@pytest.mark.unittest
@pytest.mark.parametrize('obs_shape, action_dim', args)
class TestProcedureCloning:

    def test_procedure_cloning(self, obs_shape, action_dim):
        inputs = {
            'states': torch.randn(B, *obs_shape),
            'goals': torch.randn(B, *obs_shape),
            'actions': torch.randn(B, T, action_dim)
        }
        model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim)

print(model)

goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
assert goal_preds.shape == (B, obs_embeddings)
assert action_preds.shape == (B, T + 1, action_dim)
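
Assuming the repository's standard pytest setup, where the unittest marker used above is registered, the test can be run from the project root with:

pytest -m unittest ding/model/template/tests/test_procedure_cloning.py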