feature(gry): add MDQN algorithm (#590)
* draft runnable version for mdqn and config file

* fix style for mdqn

* fix style for mdqn

* update action_gap part for mdqn

* provide tau and alpha

* add clipfrac to mdqn

* add unit test for mdqn td error

* provide current exp parameter

* fix bug in mdqn td loss function and polish code

* revert useless change in dqn

* update readme for mdqn

* delete wrongly named folder

* rename asterix folder

* provide reasonable config for asterix

* fix style and unit test

* polish code under comment

* fix typo in dizoo asterix config

* fix style

* fix style

* provide is_dynamic_seed for collector env

* add unit test for mdqn in test_serial_entry with asterix

* change test for mdqn from asterix to cartpole because the platform test failed

* change is_dynamic structure because the unit test failed at test entry

* add comment for is_dynamic_seed

* add enduro and spaceinvaders mdqn config file && polish comments

* polish code under comment
ruoyuGao authored Mar 8, 2023
1 parent 55898a3 commit 741fe40
Showing 17 changed files with 642 additions and 36 deletions.
65 changes: 33 additions & 32 deletions README.md

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion ding/entry/serial_entry.py
@@ -22,6 +22,7 @@ def serial_pipeline(
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
dynamic_seed: Optional[bool] = True,
) -> 'Policy': # noqa
"""
Overview:
@@ -36,6 +37,7 @@ def serial_pipeline(
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
- dynamic_seed (:obj:`Optional[bool]`): Whether to use a dynamic seed for the collector environment.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
@@ -53,7 +55,7 @@
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
collector_env.seed(cfg.seed)
collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
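The dynamic_seed flag added here is forwarded straight to the collector env manager's seed() call, while the evaluator keeps dynamic_seed=False as before. As a rough sketch of the intent (not the actual BaseEnvManager implementation, which is outside this diff): with a dynamic seed, each reset re-derives a fresh seed from the base seed so collected episodes vary, whereas a fixed seed makes resets deterministic.

# Illustrative sketch only, assuming names `_base_seed` / `_dynamic_seed`; the real
# DI-engine env manager may derive per-reset seeds differently.
import numpy as np


class SeededEnvSketch:

    def seed(self, base_seed: int, dynamic_seed: bool = True) -> None:
        self._base_seed = base_seed
        self._dynamic_seed = dynamic_seed

    def seed_for_reset(self) -> int:
        if self._dynamic_seed:
            # New seed every episode: collected trajectories differ across resets.
            return self._base_seed + int(np.random.randint(0, 2 ** 20))
        # Fixed seed: deterministic resets, as used for the evaluator env.
        return self._base_seed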
15 changes: 15 additions & 0 deletions ding/entry/tests/test_serial_entry.py
@@ -52,6 +52,7 @@
from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import gym_hybrid_pdqn_config, gym_hybrid_pdqn_create_config
from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
from dizoo.classic_control.pendulum.config.pendulum_bdq_config import pendulum_bdq_config, pendulum_bdq_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_mdqn_config import cartpole_mdqn_config, cartpole_mdqn_create_config


@pytest.mark.platformtest
@@ -68,6 +69,20 @@ def test_dqn():
os.popen('rm -rf cartpole_dqn_unittest')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_mdqn():
config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'cartpole_mdqn_unittest'
try:
serial_pipeline(config, seed=0, max_train_iter=1, dynamic_seed=False)
except Exception:
assert False, "pipeline fail"
finally:
os.popen('rm -rf cartpole_mdqn_unittest')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_bdq():
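Outside the unit test, the same entry point trains MDQN end to end. A minimal usage sketch built from the imports above; max_env_step here is an illustrative value, not something set by this commit:

from copy import deepcopy

from ding.entry import serial_pipeline
from dizoo.classic_control.cartpole.config.cartpole_mdqn_config import (
    cartpole_mdqn_config,
    cartpole_mdqn_create_config,
)

if __name__ == "__main__":
    # Pass deep copies so the module-level configs stay untouched, mirroring the tests.
    cfg = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
    serial_pipeline(cfg, seed=0, max_env_step=int(1e5), dynamic_seed=False)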
12 changes: 12 additions & 0 deletions ding/entry/tests/test_serial_entry_algo.py
@@ -45,6 +45,7 @@
from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config # noqa
from dizoo.classic_control.cartpole.config import cartpole_mdqn_config, cartpole_mdqn_create_config

with open("./algo_record.log", "w+") as f:
f.write("ALGO TEST STARTS\n")
@@ -405,6 +406,17 @@ def test_wqmix():
f.write("28. wqmix\n")


@pytest.mark.algotest
def test_mdqn():
config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
try:
serial_pipeline(config, seed=0)
except Exception:
assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f:
f.write("29. mdqn\n")


# @pytest.mark.algotest
def test_td3_bc():
# train expert
1 change: 1 addition & 0 deletions ding/policy/__init__.py
@@ -1,6 +1,7 @@
from .base_policy import Policy, CommandModePolicy, create_policy, get_policy_cls
from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch
from .dqn import DQNSTDIMPolicy, DQNPolicy
from .mdqn import MDQNPolicy
from .iqn import IQNPolicy
from .fqf import FQFPolicy
from .qrdqn import QRDQNPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -3,6 +3,7 @@
from .base_policy import CommandModePolicy

from .dqn import DQNPolicy, DQNSTDIMPolicy
from .mdqn import MDQNPolicy
from .c51 import C51Policy
from .qrdqn import QRDQNPolicy
from .iqn import IQNPolicy
@@ -101,6 +102,11 @@ class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('mdqn_command')
class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('dqn_command')
class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
pass
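Registering MDQNCommandModePolicy on top of EpsCommandModePolicy means MDQN reuses the standard epsilon-greedy command mode: exploration follows the other.eps settings in the MDQN policy config further down (type='exp', start=0.95, end=0.1, decay=10000). A rough sketch of that exponential schedule, assuming the commonly used form (the exact scheduling helper in ding is not part of this diff):

import math


def exp_eps(step: int, start: float = 0.95, end: float = 0.1, decay: int = 10000) -> float:
    # Epsilon decays from `start` toward `end` over roughly `decay` env steps.
    return end + (start - end) * math.exp(-step / decay)


# exp_eps(0) == 0.95, exp_eps(10000) is about 0.41, and it approaches 0.1 for large step counts.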
243 changes: 243 additions & 0 deletions ding/policy/mdqn.py
@@ -0,0 +1,243 @@
from typing import List, Dict, Any
import copy
import torch

from ding.torch_utils import Adam, to_device
from ding.rl_utils import m_q_1step_td_data, m_q_1step_td_error
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY

from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('mdqn')
class MDQNPolicy(DQNPolicy):
"""
Overview:
Policy class of the Munchausen DQN (MDQN) algorithm, which extends DQN with an entropy-regularized target and a scaled log-policy (Munchausen) reward bonus.
Paper link: https://arxiv.org/abs/2007.14430
Config:
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============== ======================================== =======================
1 ``type`` str mdqn | RL policy register name, refer to | This arg is optional,
| registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
| erent from modes
3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
| or off-policy
4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
| update priority
5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
| ``_weight`` | to correct biased update. If True,
| priority must be True.
6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
| ``factor`` [0.95, 0.999] | gamma | reward env
7 ``nstep`` int 1, | N-step reward discount sum for target
[3, 5] | q_value estimation
8 | ``learn.update`` int 1 | How many updates(iterations) to train | This arg can vary
| ``per_collect`` | after collector's one collection. Only | across envs. Bigger val
| valid in serial training | means more off-policy
9 | ``learn.multi`` bool False | whether to use multi gpu during
| ``_gpu`` | training
10 | ``learn.batch_`` int 32 | The number of samples of an iteration
| ``size``
11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
| ``_rate``
12 | ``learn.target_`` int 2000 | Frequency of target network update. | Hard(assign) update
| ``update_freq``
13 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
| ``done`` | calculation. | fake termination env
14 ``collect.n_sample`` int 4 | The number of training samples of a | It varies from
| call of collector. | different envs
15 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
| ``_len``
16 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
| 'linear'].
17 | ``other.eps.`` float 0.01 | start value of exploration rate | [0,1]
| ``start``
18 | ``other.eps.`` float 0.001 | end value of exploration rate | [0,1]
| ``end``
19 | ``other.eps.`` int 250000 | decay length of exploration | greater than 0. set
| ``decay`` | decay=250000 means
| the exploration rate
| decay from start
| value to end value
| during decay length.
20 | ``entropy_tau`` float 0.003 | the ratio of entropy in the TD loss
21 | ``alpha`` float 0.9 | the ratio of the Munchausen term in
| the TD loss
== ==================== ======== ============== ======================================== =======================
"""
config = dict(
type='mdqn',
# (bool) Whether use cuda in policy
cuda=False,
# (bool) Whether learning policy is the same as collecting data policy(on-policy)
on_policy=False,
# (bool) Whether to enable prioritized experience sampling
priority=False,
# (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=False,
# (float) Discount factor(gamma) for returns
discount_factor=0.97,
# (float) Entropy factor (tau) for Munchausen DQN
entropy_tau=0.03,
# (float) Discount factor (alpha) for Munchausen term
m_alpha=0.9,
# (int) The number of steps for calculating the target q_value
nstep=1,
learn=dict(
# (bool) Whether to use multi gpu
multi_gpu=False,
# How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=3,
# (int) How many samples in a training batch
batch_size=64,
# (float) The step size of gradient descent
learning_rate=0.001,
# ==============================================================
# The following configs are algorithm-specific
# ==============================================================
# (int) Frequency of target network update.
target_update_freq=100,
# (bool) Whether ignore done(usually for max step termination env)
ignore_done=False,
),
# collect_mode config
collect=dict(
# (int) Only one of [n_sample, n_episode] should be set
n_sample=4,
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
eval=dict(),
# other config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# (str) Decay type. Support ['exp', 'linear'].
type='exp',
# (float) Epsilon start value
start=0.95,
# (float) Epsilon end value
end=0.1,
# (int) Decay length(env step)
decay=10000,
),
replay_buffer=dict(replay_buffer_size=10000, ),
),
)

def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``; initializes the optimizer, algorithm arguments, and the \
main and target models.
"""
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
# Optimizer
# set eps to be consistent with the original paper's implementation
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, eps=0.0003125)

self._gamma = self._cfg.discount_factor
self._nstep = self._cfg.nstep
self._entropy_tau = self._cfg.entropy_tau
self._m_alpha = self._cfg.m_alpha

# use model_wrapper for specialized demands of different modes
self._target_model = copy.deepcopy(self._model)
if 'target_update_freq' in self._cfg.learn:
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='assign',
update_kwargs={'freq': self._cfg.learn.target_update_freq}
)
elif 'target_theta' in self._cfg.learn:
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='momentum',
update_kwargs={'theta': self._cfg.learn.target_theta}
)
else:
raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta")
self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._learn_model.reset()
self._target_model.reset()

def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
Forward computation graph of learn mode(updating policy).
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
np.ndarray or dict/list combinations.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, which will be \
recorded in the text log and tensorboard; values are python scalars or lists of scalars.
ArgumentsKeys:
- necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
- optional: ``value_gamma``, ``IS``
ReturnsKeys:
- necessary: ``cur_lr``, ``total_loss``, ``priority``, ``action_gap``, ``clip_frac``
"""
data = default_preprocess_learn(
data,
use_priority=self._priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=True
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
q_value = self._learn_model.forward(data['obs'])['logit']
# Target q values: the target network is evaluated on the current obs as well as the next obs,
# because the Munchausen bonus needs the target policy's log-probability of the taken action
with torch.no_grad():
target_q_value_current = self._target_model.forward(data['obs'])['logit']
target_q_value = self._target_model.forward(data['next_obs'])['logit']
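# The call below packs everything m_q_1step_td_error consumes: current-net Q(s, .),
# target-net Q(s, .) for the Munchausen bonus, target-net Q(s', .), the taken action,
# the 1-step reward, the done flag and the importance-sampling weight.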

data_m = m_q_1step_td_data(
q_value, target_q_value_current, target_q_value, data['action'], data['reward'].squeeze(0), data['done'],
data['weight']
)

loss, td_error_per_sample, action_gap, clipfrac = m_q_1step_td_error(
data_m, self._gamma, self._entropy_tau, self._m_alpha
)
# ====================
# Q-learning update
# ====================
self._optimizer.zero_grad()
loss.backward()
if self._cfg.learn.multi_gpu:
self.sync_gradients(self._learn_model)
self._optimizer.step()

# =============
# after update
# =============
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
'q_value': q_value.mean().item(),
'target_q_value': target_q_value.mean().item(),
'priority': td_error_per_sample.abs().tolist(),
'action_gap': action_gap.item(),
'clip_frac': clipfrac.mean().item(),
}

def _monitor_vars_learn(self) -> List[str]:
return ['cur_lr', 'total_loss', 'q_value', 'action_gap', 'clip_frac']
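m_q_1step_td_error itself is implemented in ding/rl_utils/td.py, whose diff is not rendered on this page. As background, the sketch below reproduces the Munchausen 1-step target from the paper (https://arxiv.org/abs/2007.14430), using gamma/tau/alpha with the same meanings as discount_factor/entropy_tau/m_alpha in the config above; it is an illustration of the math, not the DI-engine implementation, and lower_clip=-1.0 is the paper's default rather than a value taken from this commit. The real function also returns the per-sample TD error (used as priority), the action gap, and the clip fraction logged in _forward_learn.

import torch
import torch.nn.functional as F


def munchausen_td_target(target_q_current, target_q_next, action, reward, done,
                         gamma=0.97, tau=0.03, alpha=0.9, lower_clip=-1.0):
    # Munchausen bonus: alpha * tau * log pi(a|s) under the target network,
    # clipped to [lower_clip, 0] for numerical stability.
    log_pi_s = F.log_softmax(target_q_current / tau, dim=-1)
    munchausen = alpha * torch.clamp(
        tau * log_pi_s.gather(-1, action.unsqueeze(-1)).squeeze(-1),
        min=lower_clip, max=0.0,
    )
    # Soft (entropy-regularized) value of the next state under the target network.
    pi_next = F.softmax(target_q_next / tau, dim=-1)
    log_pi_next = F.log_softmax(target_q_next / tau, dim=-1)
    soft_v_next = (pi_next * (target_q_next - tau * log_pi_next)).sum(dim=-1)
    # 1-step target; the TD loss compares Q(s, a) from the learn model against this value.
    return reward + munchausen + gamma * (1.0 - done.float()) * soft_v_next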
3 changes: 2 additions & 1 deletion ding/rl_utils/__init__.py
@@ -5,7 +5,8 @@
from .gae import gae_data, gae
from .a2c import a2c_data, a2c_error
from .coma import coma_data, coma_error
from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data, td_lambda_error,\
from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, \
q_1step_td_error, m_q_1step_td_data, m_q_1step_td_error, td_lambda_data, td_lambda_error,\
q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \
generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \
nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error,\
