Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(nyz): add MADDPG pettingzoo example #774

Merged
merged 2 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 27 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
| 28 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html)<br>[policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
| 29 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html)<br>[policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
| 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)<br>[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 |
| 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)<br>[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ptz_simple_spread_maddpg_config.py -s 0 |
| 31 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html)<br>[reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
| 32 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html)<br>[entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 33 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html)<br>[policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
Expand Down
36 changes: 23 additions & 13 deletions ding/policy/policy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import namedtuple
from easydict import EasyDict
import gym
import gymnasium
import torch

from ding.torch_utils import to_device
Expand Down Expand Up @@ -49,26 +50,35 @@ def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:

actions = {}
for env_id in data:
if not isinstance(action_space, list):
if isinstance(action_space, list):
if 'global_state' in data[env_id].keys():
# for smac
logit = torch.ones_like(data[env_id]['action_mask'])
logit[data[env_id]['action_mask'] == 0.0] = -1e8
dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
else:
# for gfootball
actions[env_id] = {
'action': torch.as_tensor(
[action_space_agent.sample() for action_space_agent in action_space]
),
'logit': torch.ones([len(action_space), action_space[0].n])
}
elif isinstance(action_space, gymnasium.spaces.Dict): # pettingzoo
actions[env_id] = {
'action': torch.as_tensor(
[action_space_agent.sample() for action_space_agent in action_space.values()]
)
}
else:
if isinstance(action_space, gym.spaces.Discrete):
action = torch.LongTensor([action_space.sample()])
elif isinstance(action_space, gym.spaces.MultiDiscrete):
action = [torch.LongTensor([v]) for v in action_space.sample()]
else:
action = torch.as_tensor(action_space.sample())
actions[env_id] = {'action': action}
elif 'global_state' in data[env_id].keys():
# for smac
logit = torch.ones_like(data[env_id]['action_mask'])
logit[data[env_id]['action_mask'] == 0.0] = -1e8
dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
else:
# for gfootball
actions[env_id] = {
'action': torch.as_tensor([action_space_agent.sample() for action_space_agent in action_space]),
'logit': torch.ones([len(action_space), action_space[0].n])
}
return actions

def reset(*args, **kwargs) -> None:
Expand Down
9 changes: 5 additions & 4 deletions ding/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,11 @@ def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
global _innest_error
if _innest_error:
argspec = inspect.getfullargspec(build_fn)
message = 'for {}(alias={})'.format(build_fn, obj_type)
message += '\nExpected args are:{}'.format(argspec)
message += '\nGiven args are:{}/{}'.format(argspec, obj_kwargs.keys())
message += '\nGiven args details are:{}/{}'.format(argspec, obj_kwargs)
message = 'Hint: for {}(alias={})'.format(build_fn, obj_type)
message += '\n\nExpected args are:\n {}\nGiven arguments keys are:\n{}\n'.format(
argspec, obj_kwargs.keys()
)
print(message)
_innest_error = False
raise e

Expand Down
2 changes: 1 addition & 1 deletion dizoo/gfootball/config/gfootball_counter_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
import_names=['dizoo.gfootball.envs.gfootball_academy_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac'),
)
gfootball_keeper_masac_default_create_config = EasyDict(gfootball_keeper_masac_default_create_config)
create_config = gfootball_keeper_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/gfootball/config/gfootball_keeper_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
import_names=['dizoo.gfootball.envs.gfootball_academy_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac'),
)
gfootball_keeper_masac_default_create_config = EasyDict(gfootball_keeper_masac_default_create_config)
create_config = gfootball_keeper_masac_default_create_config
Expand Down
81 changes: 81 additions & 0 deletions dizoo/petting_zoo/config/ptz_simple_spread_maddpg_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from easydict import EasyDict

# MADDPG example config for the PettingZoo MPE ``simple_spread_v2`` task.
# MADDPG is run here through the ``ddpg`` policy with ``multi_agent=True``.
n_agent = 3
n_landmark = n_agent
collector_env_num = 8
evaluator_env_num = 8
main_config = dict(
    exp_name='ptz_simple_spread_maddpg_seed0',
    env=dict(
        env_family='mpe',
        env_id='simple_spread_v2',
        n_agent=n_agent,
        n_landmark=n_landmark,
        max_cycles=25,
        agent_obs_only=False,
        agent_specific_global_state=True,
        continuous_actions=True,  # DDPG only supports continuous action spaces
        act_scale=True,  # necessary for continuous action space
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        stop_value=0,
    ),
    policy=dict(
        cuda=True,
        multi_agent=True,
        random_collect_size=5000,
        model=dict(
            # NOTE(review): both shapes are derived from n_agent / n_landmark and must match
            # what the env wrapper actually emits -- confirm against the env when changing
            # the number of agents or landmarks.
            agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
            # The first five terms repeat ``agent_obs_shape``; the rest are the extra global part.
            global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) +
            n_landmark * 2 + n_agent * (n_agent - 1) * 2,
            action_shape=5,
            action_space='regression',
            twin_critic=False,
        ),
        learn=dict(
            update_per_collect=50,
            batch_size=320,
            # learning_rates
            learning_rate_q=5e-4,
            learning_rate_policy=5e-4,
            target_theta=0.005,
            discount_factor=0.99,
        ),
        collect=dict(
            n_sample=1600,
            env_num=collector_env_num,
        ),
        eval=dict(
            env_num=evaluator_env_num,
            evaluator=dict(eval_freq=500, ),
        ),
        other=dict(
            eps=dict(
                type='linear',
                start=1,
                end=0.05,
                decay=100000,
            ),
            replay_buffer=dict(replay_buffer_size=int(1e6), )
        ),
    ),
)

main_config = EasyDict(main_config)
create_config = dict(
    env=dict(
        import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
        type='petting_zoo',
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='ddpg'),
)
create_config = EasyDict(create_config)
# Aliases named after this config file, alongside the generic main_config / create_config.
ptz_simple_spread_maddpg_config = main_config
ptz_simple_spread_maddpg_create_config = create_config

if __name__ == '__main__':
    # or you can enter `ding -m serial -c ptz_simple_spread_maddpg_config.py -s 0`
    # Imported lazily so that merely loading this config does not pull in the training entry.
    from ding.entry import serial_pipeline
    serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e6))
2 changes: 1 addition & 1 deletion dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
type='petting_zoo',
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete'),
policy=dict(type='discrete_sac'),
)
create_config = EasyDict(create_config)
ptz_simple_spread_masac_config = main_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_10m11m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_10m11m_masac_default_create_config = EasyDict(SMAC_10m11m_masac_default_create_config)
create_config = SMAC_10m11m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_25m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_25m_masac_default_create_config = EasyDict(SMAC_25m_masac_default_create_config)
create_config = SMAC_25m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_2c64zg_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_2c64zg_masac_default_create_config = EasyDict(SMAC_2c64zg_masac_default_create_config)
create_config = SMAC_2c64zg_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_3m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_3m_masac_default_create_config = EasyDict(SMAC_3m_masac_default_create_config)
create_config = SMAC_3m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_3s5z_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
smac_3s5z_masac_default_create_config = EasyDict(smac_3s5z_masac_default_create_config)
create_config = smac_3s5z_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_3s5zvs3s6z_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
smac_3s5zvs3s6z_masac_default_create_config = EasyDict(smac_3s5zvs3s6z_masac_default_create_config)
create_config = smac_3s5zvs3s6z_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_5m6m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_5m6m_masac_default_create_config = EasyDict(SMAC_5m6m_masac_default_create_config)
create_config = SMAC_5m6m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_8m9m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_8m9m_masac_default_create_config = EasyDict(SMAC_8m9m_masac_default_create_config)
create_config = SMAC_8m9m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_MMM2_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_MMM2_masac_default_create_config = EasyDict(SMAC_MMM2_masac_default_create_config)
create_config = SMAC_MMM2_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_MMM_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
MMM_masac_default_create_config = EasyDict(MMM_masac_default_create_config)
create_config = MMM_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_corridor_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
smac_corridor_masac_default_create_config = EasyDict(smac_corridor_masac_default_create_config)
create_config = smac_corridor_masac_default_create_config
Expand Down
Loading