feature(nyz): add policy gradient algo implementation (#544)
* feature(nyz): add policy gradient algo implementation

* demo(nyz): add lunarlander pg demo

* style(nyz): add pg link in readme

* fix(nyz): fix config conflict with data generation

* fix(nyz): fix action space error type in model
PaParaZz1 authored Nov 24, 2022
1 parent bceb05b commit 4c607d4
Showing 19 changed files with 530 additions and 69 deletions.
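Note: for orientation, the algorithm added here is vanilla policy gradient (REINFORCE). A minimal sketch of the objective it optimizes, with hypothetical names and no claim to match DI-engine's exact learn-mode code:

import torch

def pg_loss(dist: torch.distributions.Distribution, action: torch.Tensor, return_: torch.Tensor) -> torch.Tensor:
    # REINFORCE: maximize E[log pi(a|s) * G] by minimizing the negative
    log_prob = dist.log_prob(action)
    return -(log_prob * return_).mean()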
83 changes: 42 additions & 41 deletions README.md

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions ding/config/config.py
@@ -306,7 +306,7 @@ def compile_collector_config(
     other=dict(replay_buffer=dict()),
 )
 policy_config_template = EasyDict(policy_config_template)
-env_config_template = dict(manager=dict(), )
+env_config_template = dict(manager=dict(), stop_value=int(1e10))
 env_config_template = EasyDict(env_config_template)


@@ -449,11 +449,12 @@ def compile_config(
         default_config['reward_model'] = reward_model_config
     if len(world_model_config) > 0:
         default_config['world_model'] = world_model_config
+    stop_value_flag = 'stop_value' in cfg.env
     cfg = deep_merge_dicts(default_config, cfg)
     cfg.seed = seed
     # check important key in config
     if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]:  # env interaction evaluation
-        if 'stop_value' in cfg.env:  # data generation task doesn't need these fields
+        if stop_value_flag:  # data generation task doesn't need these fields
            cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
            cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
     if 'exp_name' not in cfg:
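Note: the ordering in the hunk above is the whole fix. Because the env template now carries a default stop_value=int(1e10), checking 'stop_value' in cfg.env after deep_merge_dicts would always succeed, so the flag must be captured from the user config before the merge. A toy reproduction, using a plain dict merge as a stand-in for deep_merge_dicts:

from easydict import EasyDict

default_env = EasyDict(manager={}, stop_value=int(1e10))
user_env = EasyDict(manager={})                  # e.g. a data-generation config with no stop_value

stop_value_flag = 'stop_value' in user_env       # False: captured before merging
merged = EasyDict({**default_env, **user_env})   # defaults fill in whatever the user omitted
assert 'stop_value' in merged                    # True: the key is now always present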
12 changes: 12 additions & 0 deletions ding/entry/tests/test_serial_entry_algo.py
@@ -11,6 +11,7 @@
 from dizoo.classic_control.cartpole.config.cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
 from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
 from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
+from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config
 from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
 from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config  # noqa
 from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config  # noqa
@@ -170,6 +171,17 @@ def test_r2d2():
         f.write("11. r2d2\n")


+@pytest.mark.algotest
+def test_pg():
+    config = [deepcopy(cartpole_pg_config), deepcopy(cartpole_pg_create_config)]
+    try:
+        serial_pipeline_onpolicy(config, seed=0)
+    except Exception:
+        assert False, "pipeline fail"
+    with open("./algo_record.log", "a+") as f:
+        f.write("12. pg\n")
+
+
 # @pytest.mark.algotest
 def test_atoc():
     config = [deepcopy(ptz_simple_spread_atoc_config), deepcopy(ptz_simple_spread_atoc_create_config)]
11 changes: 11 additions & 0 deletions ding/entry/tests/test_serial_entry_onpolicy.py
@@ -4,6 +4,7 @@
 from copy import deepcopy

 from ding.entry import serial_pipeline_onpolicy
+from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config
 from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
 from dizoo.classic_control.cartpole.config.cartpole_ppopg_config import cartpole_ppopg_config, cartpole_ppopg_create_config  # noqa
 from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
@@ -12,6 +13,16 @@
 from dizoo.classic_control.cartpole.config.cartpole_ppo_stdim_config import cartpole_ppo_stdim_config, cartpole_ppo_stdim_create_config  # noqa


+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_pg():
+    config = [deepcopy(cartpole_pg_config), deepcopy(cartpole_pg_create_config)]
+    try:
+        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
+    except Exception:
+        assert False, "pipeline fail"
+
+
 @pytest.mark.platformtest
 @pytest.mark.unittest
 def test_a2c():
3 changes: 2 additions & 1 deletion ding/model/common/__init__.py
@@ -1,4 +1,5 @@
 from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \
-    QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, head_cls_map
+    QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, head_cls_map, \
+    independent_normal_dist
 from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
 from .utils import create_model
11 changes: 10 additions & 1 deletion ding/model/common/head.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, Union, List

 import math
 import torch
@@ -1116,6 +1116,15 @@ def forward(self, x: torch.Tensor) -> Dict:
         return lists_to_dicts([m(x) for m in self.pred])


+def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Distribution:
+    if isinstance(logits, (list, tuple)):
+        return Independent(Normal(*logits), 1)
+    elif isinstance(logits, dict):
+        return Independent(Normal(logits['mu'], logits['sigma']), 1)
+    else:
+        raise TypeError("invalid logits type: {}".format(type(logits)))
+
+
 head_cls_map = {
     # discrete
     'discrete': DiscreteHead,
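Note: wrapping Normal in Independent(..., 1) reinterprets the last dimension as the event dimension, so log_prob sums over action dimensions and yields one joint log-probability per batch element. A usage sketch with illustrative shapes:

import torch
from torch.distributions import Independent, Normal

mu, sigma = torch.zeros(4, 6), torch.ones(4, 6)    # batch of 4, 6-dim action
dist = Independent(Normal(mu, sigma), 1)           # what independent_normal_dist builds
action = dist.sample()
assert action.shape == (4, 6)
assert dist.log_prob(action).shape == (4, )        # one value per batch element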
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -4,6 +4,7 @@
 from .pdqn import PDQN
 from .vac import VAC
 from .bc import DiscreteBC, ContinuousBC
+from .pg import PG
 # algorithm-specific
 from .ppg import PPG
 from .qmix import Mixer, QMix
67 changes: 67 additions & 0 deletions ding/model/template/pg.py
@@ -0,0 +1,67 @@
from typing import Union, Optional, Dict, Callable, List
import torch
import torch.nn as nn
from easydict import EasyDict

from ding.torch_utils import get_lstm
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \
    MultiHead, RegressionHead, ReparameterizationHead, independent_normal_dist


@MODEL_REGISTRY.register('pg')
class PG(nn.Module):

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            action_space: str = 'discrete',
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        super(PG, self).__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]
        # FC Encoder
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        # Conv Encoder
        elif len(obs_shape) == 3:
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own PG".format(obs_shape)
            )
        self.action_space = action_space
        # Head
        if self.action_space == 'discrete':
            self.head = DiscreteHead(
                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
            )
        elif self.action_space == 'continuous':
            self.head = ReparameterizationHead(
                head_hidden_size,
                action_shape,
                head_layer_num,
                activation=activation,
                norm_type=norm_type,
                sigma_type='independent'
            )
        else:
            raise KeyError("not support action space: {}".format(self.action_space))

    def forward(self, x: torch.Tensor) -> Dict:
        x = self.encoder(x)
        x = self.head(x)
        if self.action_space == 'discrete':
            x['dist'] = torch.distributions.Categorical(logits=x['logit'])
        elif self.action_space == 'continuous':
            x = {'logit': {'mu': x['mu'], 'sigma': x['sigma']}}
            x['dist'] = independent_normal_dist(x['logit'])
        return x
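Note: a usage sketch for the model above; shapes and hyperparameters are illustrative, not taken from the cartpole/lunarlander demo configs in this PR:

import torch
from ding.model.template import PG

model = PG(obs_shape=4, action_shape=2, action_space='discrete')
obs = torch.randn(8, 4)                     # batch of 8 CartPole-like observations
out = model(obs)
action = out['dist'].sample()               # shape (8, )
log_prob = out['dist'].log_prob(action)     # feeds a REINFORCE loss like the one sketched earlier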
61 changes: 61 additions & 0 deletions ding/model/template/tests/test_pg.py
@@ -0,0 +1,61 @@
import torch
import numpy as np
import pytest
from itertools import product

from ding.model.template import PG
from ding.torch_utils import is_differentiable
from ding.utils import squeeze

B = 4


@pytest.mark.unittest
class TestDiscretePG:

    def output_check(self, model, outputs):
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum([t.sum() for t in outputs])
        elif isinstance(outputs, dict):
            loss = sum([v.sum() for v in outputs.values()])
        is_differentiable(loss, model)

    def test_discrete_pg(self):
        obs_shape = (4, 84, 84)
        action_shape = 5
        model = PG(
            obs_shape,
            action_shape,
        )
        inputs = torch.randn(B, 4, 84, 84)

        outputs = model(inputs)
        assert isinstance(outputs, dict)
        assert outputs['logit'].shape == (B, action_shape)
        assert outputs['dist'].sample().shape == (B, )
        self.output_check(model, outputs['logit'])

    def test_continuous_pg(self):
        N = 32
        action_shape = (6, )
        inputs = {'obs': torch.randn(B, N), 'action': torch.randn(B, squeeze(action_shape))}
        model = PG(
            obs_shape=(N, ),
            action_shape=action_shape,
            action_space='continuous',
        )
        # compute_action
        print(model)
        outputs = model(inputs['obs'])
        assert isinstance(outputs, dict)
        dist = outputs['dist']
        action = dist.sample()
        assert action.shape == (B, *action_shape)

        logit = outputs['logit']
        mu, sigma = logit['mu'], logit['sigma']
        assert mu.shape == (B, *action_shape)
        assert sigma.shape == (B, *action_shape)
        is_differentiable(mu.sum() + sigma.sum(), model)
3 changes: 2 additions & 1 deletion ding/policy/__init__.py
@@ -10,8 +10,9 @@
 from .d4pg import D4PGPolicy
 from .td3 import TD3Policy
 from .td3_vae import TD3VAEPolicy
-
 from .td3_bc import TD3BCPolicy
+
+from .pg import PGPolicy
 from .a2c import A2CPolicy
 from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy
 from .sac import SACPolicy, SACDiscretePolicy, SQILSACPolicy
8 changes: 6 additions & 2 deletions ding/policy/base_policy.py
@@ -224,10 +224,14 @@ def _monitor_vars_learn(self) -> List[str]:
         return ['cur_lr', 'total_loss']

     def _state_dict_learn(self) -> Dict[str, Any]:
-        return {'model': self._learn_model.state_dict()}
+        return {
+            'model': self._learn_model.state_dict(),
+            'optimizer': self._optimizer.state_dict(),
+        }

     def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
-        self._learn_model.load_state_dict(state_dict['model'], strict=True)
+        self._learn_model.load_state_dict(state_dict['model'])
+        self._optimizer.load_state_dict(state_dict['optimizer'])

     def _get_batch_size(self) -> Union[int, Dict[str, int]]:
         return self._cfg.learn.batch_size
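Note: checkpointing the optimizer alongside the model matters for resuming: Adam's running moment estimates would otherwise restart from zero and perturb training right after a reload. A sketch of the round-trip these hooks enable, assuming policy is an instantiated DI-engine policy:

import torch

state = policy._state_dict_learn()          # now {'model': ..., 'optimizer': ...}
torch.save(state, 'ckpt.pth.tar')

policy._load_state_dict_learn(torch.load('ckpt.pth.tar'))  # restores model and optimizer together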
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -15,6 +15,7 @@
 from .ppo import PPOPolicy, PPOOffPolicy, PPOPGPolicy, PPOSTDIMPolicy
 from .offppo_collect_traj import OffPPOCollectTrajPolicy
 from .ppg import PPGPolicy, PPGOffPolicy
+from .pg import PGPolicy
 from .a2c import A2CPolicy
 from .impala import IMPALAPolicy
 from .ngu import NGUPolicy
@@ -189,6 +190,11 @@ class PPOOffCollectTrajCommandModePolicy(OffPPOCollectTrajPolicy, DummyCommandModePolicy):
     pass


+@POLICY_REGISTRY.register('pg_command')
+class PGCommandModePolicy(PGPolicy, DummyCommandModePolicy):
+    pass
+
+
 @POLICY_REGISTRY.register('a2c_command')
 class A2CCommandModePolicy(A2CPolicy, DummyCommandModePolicy):
     pass
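Note: registering the command-mode wrapper is what lets a config refer to the policy by name. A hedged sketch of the lookup, assuming the registry supports a dict-style get as DI-engine's Registry is dict-based:

from ding.policy.command_mode_policy_instance import PGCommandModePolicy
from ding.utils import POLICY_REGISTRY

policy_cls = POLICY_REGISTRY.get('pg_command')   # resolved from cfg.policy.type at entry time
assert policy_cls is PGCommandModePolicy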