Dev distar collector merge policy #411

Merged
merged 23 commits on Jul 11, 2022
Commits (23)
9c0d6f3
polish(nyz): polish parse_new_game and add transform_obs
PaParaZz1 Jun 23, 2022
d6ef348
update train iter
hiha3456 Jun 28, 2022
0cb987f
change logic of update train_iter
hiha3456 Jun 28, 2022
7012966
add check of main player
hiha3456 Jun 28, 2022
7912613
fix bug
hiha3456 Jun 28, 2022
2b51afc
fix bug
hiha3456 Jun 28, 2022
83d93d0
test(nyz): add naive distar policy collect test
PaParaZz1 Jun 28, 2022
d96f953
Merge branch 'dev-distar' into dev-distar-collector
hiha3456 Jun 29, 2022
8408399
merge dev-distar-nyz
hiha3456 Jun 29, 2022
7b5b233
to run in k8s
hiha3456 Jun 29, 2022
0b985d2
change num workers
hiha3456 Jun 29, 2022
57d3a43
to run real policy forward_collect
hiha3456 Jun 29, 2022
584fd82
print exception
hiha3456 Jun 30, 2022
df21a9a
reformat test
hiha3456 Jun 30, 2022
72803ef
fix bug
hiha3456 Jun 30, 2022
e8e7551
tools to do serialization and test whether two objects are the same
hiha3456 Jul 6, 2022
2da4010
changes in the model to correctly make actions using pretrained model
hiha3456 Jul 6, 2022
a0efe1c
changes to run the test using pretrained model
hiha3456 Jul 6, 2022
a3c9e39
tests to test the performance against bot using pretrained model
hiha3456 Jul 6, 2022
43f86a1
changes in the policy(agent) to correctly make actions using pretrain…
hiha3456 Jul 6, 2022
a99c8bb
move GLU and build_activation in action_type_head.py to ding/torch_ut…
hiha3456 Jul 6, 2022
a11e146
change default value of build_activation to False
hiha3456 Jul 10, 2022
ca651b5
add util to change ia's model
hiha3456 Jul 10, 2022
33 changes: 21 additions & 12 deletions ding/framework/middleware/collector.py
@@ -7,19 +7,20 @@
from ding.envs import BaseEnvManager
from ding.utils import log_every_sec
from ding.framework import task
from ding.framework.middleware.functional import PlayerModelInfo
from .functional import inferencer, rolloutor, TransitionList, BattleTransitionList, \
battle_inferencer, battle_rolloutor

if TYPE_CHECKING:
from ding.framework import OnlineRLContext, BattleContext

WAIT_MODEL_TIME = 60
WAIT_MODEL_TIME = 600000


class BattleStepCollector:

def __init__(
self, cfg: EasyDict, env: BaseEnvManager, unroll_len: int, model_dict: Dict, model_time_dict: Dict,
self, cfg: EasyDict, env: BaseEnvManager, unroll_len: int, model_dict: Dict, model_info_dict: Dict,
all_policies: Dict, agent_num: int
):
self.cfg = cfg
@@ -31,15 +32,17 @@ def __init__(
self.total_envstep_count = 0
self.unroll_len = unroll_len
self.model_dict = model_dict
self.model_time_dict = model_time_dict
self.model_info_dict = model_info_dict
self.all_policies = all_policies
self.agent_num = agent_num

self._battle_inferencer = task.wrap(battle_inferencer(self.cfg, self.env))
self._transitions_list = [
BattleTransitionList(self.env.env_num, self.unroll_len) for _ in range(self.agent_num)
]
self._battle_rolloutor = task.wrap(battle_rolloutor(self.cfg, self.env, self._transitions_list))
self._battle_rolloutor = task.wrap(
battle_rolloutor(self.cfg, self.env, self._transitions_list, self.model_info_dict)
)

def __del__(self) -> None:
"""
@@ -54,21 +57,23 @@ def __del__(self) -> None:

def _update_policies(self, player_id_list) -> None:
for player_id in player_id_list:
# for this player, actor didn't recieve any new model, use initial model instead.
if self.model_time_dict.get(player_id) is None:
self.model_time_dict[player_id] = time.time()
# for this player, at the beginning of the actor's lifetime, the actor hasn't received any new model yet, so use the initial model instead.
if self.model_info_dict.get(player_id) is None:
self.model_info_dict[player_id] = PlayerModelInfo(
get_new_model_time=time.time(), update_new_model_time=None
)

while True:
time_now = time.time()
time_list = [time_now - self.model_time_dict[player_id] for player_id in player_id_list]
time_list = [time_now - self.model_info_dict[player_id].get_new_model_time for player_id in player_id_list]
if any(x >= WAIT_MODEL_TIME for x in time_list):
for player_id in player_id_list:
if time_now - self.model_time_dict[player_id] >= WAIT_MODEL_TIME:
for index, player_id in enumerate(player_id_list):
if time_list[index] >= WAIT_MODEL_TIME:
# TODO: log_every_sec can only print the first model that is not updated
log_every_sec(
logging.WARNING, 5,
'In actor {}, model for {} has not been updated for {} seconds and needs a new model'.format(
task.router.node_id, player_id, time_now - self.model_time_dict[player_id]
task.router.node_id, player_id, time_list[index]
)
)
time.sleep(1)
@@ -84,6 +89,8 @@ def _update_policies(self, player_id_list) -> None:
assert policy, "for player{}, policy should have been initialized already"
# update policy model
policy.load_state_dict(learner_model.state_dict)
self.model_info_dict[player_id].update_new_model_time = time.time()
self.model_info_dict[player_id].update_train_iter = learner_model.train_iter
self.model_dict[player_id] = None

def __call__(self, ctx: "BattleContext") -> None:
@@ -103,6 +110,9 @@ def __call__(self, ctx: "BattleContext") -> None:
while True:
if self.env.closed:
self.env.launch()
# TODO(zms): only runnable when 1 actor has exactly one env; needs a more general implementation
for policy_id, policy in enumerate(ctx.current_policies):
policy.reset(self.env.ready_obs[0][policy_id])
self._update_policies(ctx.player_id_list)
self._battle_inferencer(ctx)
self._battle_rolloutor(ctx)
@@ -156,7 +166,6 @@ def __call__(self, ctx: "BattleContext") -> None:
# self.env.close()

# def _update_policies(self, player_id_list) -> None:
# # TODO(zms): update train_iter, update train_iter and player_id inside policy is a good idea
# for player_id in player_id_list:
# if self.model_dict.get(player_id) is None:
# continue
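For reference, a condensed sketch of the waiting logic this file now builds around PlayerModelInfo (a simplified illustration only: the exact break condition and per-player warning logging are elided in the hunks above, so this is not the DI-engine implementation verbatim):

import time
from ding.framework.middleware.functional import PlayerModelInfo

WAIT_MODEL_TIME = 600000  # raised from 60 by this PR, i.e. effectively "wait until a model arrives"

def update_policies(player_id_list, model_info_dict, model_dict, all_policies):
    # every player gets a PlayerModelInfo entry the first time the collector sees it
    for player_id in player_id_list:
        if model_info_dict.get(player_id) is None:
            model_info_dict[player_id] = PlayerModelInfo(get_new_model_time=time.time(), update_new_model_time=None)

    # block while any player's model has been stale longer than WAIT_MODEL_TIME;
    # _on_learner_model (league_actor.py) refreshes get_new_model_time from another thread
    while any(time.time() - model_info_dict[p].get_new_model_time >= WAIT_MODEL_TIME for p in player_id_list):
        time.sleep(1)  # the real middleware also logs a warning for each stale player

    # load any newly received model and record when / at which train_iter it was applied
    for player_id in player_id_list:
        learner_model = model_dict.get(player_id)
        if learner_model is None:
            continue
        policy = all_policies[player_id]
        policy.load_state_dict(learner_model.state_dict)
        model_info_dict[player_id].update_new_model_time = time.time()
        model_info_dict[player_id].update_train_iter = learner_model.train_iter
        model_dict[player_id] = None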
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/__init__.py
@@ -11,4 +11,4 @@
from .explorer import eps_greedy_handler, eps_greedy_masker
from .advantage_estimator import gae_estimator
from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer
from .actor_data import ActorData, ActorDataMeta, ActorEnvTrajectories
from .actor_data import ActorData, ActorDataMeta, ActorEnvTrajectories, PlayerModelInfo
7 changes: 7 additions & 0 deletions ding/framework/middleware/functional/actor_data.py
@@ -21,3 +21,10 @@ class ActorEnvTrajectories:
class ActorData:
meta: ActorDataMeta
train_data: List[ActorEnvTrajectories] = field(default_factory=[])


@dataclass
class PlayerModelInfo:
get_new_model_time: float
update_new_model_time: float
update_train_iter: int = 0
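A brief illustration of how these fields are used (field semantics inferred from the collector and actor changes elsewhere in this PR; the concrete values below are hypothetical):

import time
from ding.framework.middleware.functional import PlayerModelInfo

# created when the actor first receives (or first waits for) a model for a player
info = PlayerModelInfo(get_new_model_time=time.time(), update_new_model_time=None)

# later, when the collector actually loads that model into the policy
info.update_new_model_time = time.time()
info.update_train_iter = 100  # hypothetical train_iter taken from the learner model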
51 changes: 30 additions & 21 deletions ding/framework/middleware/functional/collector.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Callable, List, Tuple, Any
from typing import TYPE_CHECKING, Optional, Callable, List, Tuple, Any, Dict
from easydict import EasyDict
from functools import reduce
import treetensor.torch as ttorch
@@ -128,11 +131,14 @@ def _cut_trajectory_from_episode(self, episode: list) -> List[List]:

def clear_newest_episode(self, env_id: int) -> None:
# Use it when env.step raises some error
newest_episode = self._transitions[env_id].pop()
len_newest_episode = len(newest_episode)
newest_episode.clear()
self._done_episode[env_id].pop()
return len_newest_episode
if len(self._transitions[env_id]) > 0:
newest_episode = self._transitions[env_id].pop()
len_newest_episode = len(newest_episode)
newest_episode.clear()
self._done_episode[env_id].pop()
return len_newest_episode
else:
return 0

def append(self, env_id: int, transition: Any) -> bool:
# If previous episode is done, we create a new episode
@@ -275,24 +278,26 @@ def _battle_inferencer(ctx: "BattleContext"):
return _battle_inferencer


def battle_rolloutor(cfg: EasyDict, env: BaseEnvManager, transitions_list: List):
def battle_rolloutor(cfg: EasyDict, env: BaseEnvManager, transitions_list: List, model_info_dict: Dict):

def _battle_rolloutor(ctx: "BattleContext"):
timesteps = env.step(ctx.actions)
ctx.total_envstep_count += len(timesteps)
ctx.env_step += len(timesteps)
for env_id, timestep in timesteps.items():
for policy_id, _ in enumerate(ctx.current_policies):
for policy_id, policy in enumerate(ctx.current_policies):
policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep]
policy_timestep = type(timestep)(*policy_timestep_data)
transition = ctx.current_policies[policy_id].process_transition(
transition = policy.process_transition(
ctx.obs[policy_id][env_id], ctx.inference_output[policy_id][env_id], policy_timestep
)
transition = ttorch.as_tensor(transition)
transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
transition.collect_train_iter = ttorch.as_tensor(
[model_info_dict[ctx.player_id_list[policy_id]].update_train_iter]
)
transitions_list[policy_id].append(env_id, transition)
if timestep.done:
ctx.current_policies[policy_id].reset([env_id])
policy.reset([env_id])
ctx.episode_info[policy_id].append(timestep.info[policy_id])

if timestep.done:
@@ -329,7 +334,7 @@ def _battle_inferencer(ctx: "BattleContext"):
return _battle_inferencer


def battle_rolloutor_for_distar(cfg: EasyDict, env: BaseEnvManager, transitions_list: List):
def battle_rolloutor_for_distar(cfg: EasyDict, env: BaseEnvManager, transitions_list: List, model_info_dict: Dict):

def _battle_rolloutor(ctx: "BattleContext"):
timesteps = env.step(ctx.actions)
@@ -348,24 +353,28 @@ def _battle_rolloutor(ctx: "BattleContext"):
# ctx.env_step -= transitions_list[0].length(env_id)

# 1st case: env.step raised an error and the env needs a reset.
for policy_id, _ in enumerate(ctx.current_policies):

# TODO(zms): if it is the first step of the episode, do not delete the last episode in the TransitionList
for policy_id, policy in enumerate(ctx.current_policies):
transitions_list[policy_id].clear_newest_episode(env_id)
ctx.current_policies[policy_id].reset([env_id])
policy.reset(env.ready_obs[0][policy_id])
continue

append_succeed = True
for policy_id, _ in enumerate(ctx.current_policies):
transition = ctx.current_policies[policy_id].process_transition(timestep)
episode_long_enough = True
for policy_id, policy in enumerate(ctx.current_policies):
transition = policy.process_transition(timestep)
transition = EasyDict(transition)
transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
transition.collect_train_iter = ttorch.as_tensor(
[model_info_dict[ctx.player_id_list[policy_id]].update_train_iter]
)

# 2nd case: one of the episodes has fewer transitions than unroll_len
append_succeed = append_succeed and transitions_list[policy_id].append(env_id, transition)
episode_long_enough = episode_long_enough and transitions_list[policy_id].append(env_id, transition)
if timestep.done:
ctx.current_policies[policy_id].reset([env_id])
policy.reset(env.ready_obs[0][policy_id])
ctx.episode_info[policy_id].append(timestep.info[policy_id])

if not append_succeed:
if not episode_long_enough:
for policy_id, _ in enumerate(ctx.current_policies):
transitions_list[policy_id].clear_newest_episode(env_id)
ctx.episode_info[policy_id].pop()
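One net effect of this hunk, shown as a standalone sketch (illustrative only; tag_transition is a hypothetical helper, not middleware code): each transition is now stamped with the train_iter of the model that actually produced it, looked up per player in model_info_dict, instead of a single ctx.train_iter shared by every policy.

import treetensor.torch as ttorch

def tag_transition(transition, policy_id, ctx, model_info_dict):
    # player that owns this policy in the current battle
    player_id = ctx.player_id_list[policy_id]
    # record which training iteration generated this transition
    transition.collect_train_iter = ttorch.as_tensor([model_info_dict[player_id].update_train_iter])
    return transition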
77 changes: 44 additions & 33 deletions ding/framework/middleware/league_actor.py
@@ -4,11 +4,12 @@
from easydict import EasyDict
import time
from ditk import logging
import torch

from ding.policy import Policy
from ding.framework import task, EventEnum
from ding.framework.middleware import BattleStepCollector
from ding.framework.middleware.functional import ActorData, ActorDataMeta
from ding.framework.middleware.functional import ActorData, ActorDataMeta, PlayerModelInfo
from ding.league.player import PlayerMeta
from ding.utils.sparse_logging import log_every_sec

@@ -33,7 +34,7 @@ def __init__(self, cfg: EasyDict, env_fn: Callable, policy_fn: Callable):
self.job_queue = queue.Queue()
self.model_dict = {}
self.model_dict_lock = Lock()
self.model_time_dict = {}
self.model_info_dict = {}

self.agent_num = 2

@@ -50,7 +51,13 @@ def _on_learner_model(self, learner_model: "LearnerModel"):
)
with self.model_dict_lock:
self.model_dict[learner_model.player_id] = learner_model
self.model_time_dict[learner_model.player_id] = time.time()
if self.model_info_dict.get(learner_model.player_id):
self.model_info_dict[learner_model.player_id].get_new_model_time = time.time()
self.model_info_dict[learner_model.player_id].update_new_model_time = None
else:
self.model_info_dict[learner_model.player_id] = PlayerModelInfo(
get_new_model_time=time.time(), update_new_model_time=None
)

def _on_league_job(self, job: "Job"):
"""
@@ -65,7 +72,7 @@ def _get_collector(self, player_id: str):
env = self.env_fn()
collector = task.wrap(
BattleStepCollector(
cfg.policy.collect.collector, env, self.unroll_len, self.model_dict, self.model_time_dict,
cfg.policy.collect.collector, env, self.unroll_len, self.model_dict, self.model_info_dict,
self.all_policies, self.agent_num
)
)
@@ -106,11 +113,11 @@ def _get_current_policies(self, job):
assert main_player, "[Actor {}] can not find active player.".format(task.router.node_id)

if current_policies is not None:
assert len(current_policies) > 1, "[Actor {}] battle collector needs more than 1 policies".format(
task.router.node_id
)
for p in current_policies:
p.reset()
# TODO(zms): make it more general; should we keep the restriction of needing more than 1 policy?
# assert len(current_policies) > 1, "[Actor {}] battle collector needs more than 1 policies".format(
# task.router.node_id
# )
pass
else:
raise RuntimeError('[Actor {}] current_policies should not be None'.format(task.router.node_id))

@@ -124,13 +131,20 @@ def __call__(self, ctx: "BattleContext"):
log_every_sec(
logging.INFO, 5, '[Actor {}] job of player {} begins.'.format(task.router.node_id, job.launch_player)
)

# TODO(zms): when getting a job, update the policies to the checkpoint specified in the job
ctx.player_id_list = [player.player_id for player in job.players]
main_player_idx = [idx for idx, player in enumerate(job.players) if player.player_id == job.launch_player]
self.agent_num = len(job.players)
collector = self._get_collector(job.launch_player)

main_player, ctx.current_policies = self._get_current_policies(job)

# TODO(zms): only for testing the pretrained model
rl_model = torch.load('./rl_model.pth')
for policy in ctx.current_policies:
policy.load_state_dict(rl_model)
print('load state_dict okay')

ctx.n_episode = self.cfg.policy.collect.n_episode
assert ctx.n_episode >= self.env_num, "[Actor {}] Please make sure n_episode >= env_num".format(
task.router.node_id
@@ -139,36 +153,31 @@
ctx.n_episode = self.cfg.policy.collect.n_episode
assert ctx.n_episode >= self.env_num, "Please make sure n_episode >= env_num"

ctx.train_iter = main_player.total_agent_step
ctx.episode_info = [[] for _ in range(self.agent_num)]

while True:
time_begin = time.time()
old_envstep = ctx.total_envstep_count
collector(ctx)

# ctx.trajectories_list[0] for policy_id 0
# ctx.trajectories_list[0][0] for first env
if len(ctx.trajectories_list[0]) > 0:
for traj in ctx.trajectories_list[0][0].trajectories:
assert len(traj) == self.unroll_len + 1
if ctx.job_finish is True:
logging.info('[Actor {}] finish current job !'.format(task.router.node_id))
assert len(ctx.trajectories_list[0][0].trajectories) > 0
# TODO(zms): check whether it is the main_player
if not job.is_eval and len(ctx.trajectories_list[0]) > 0:
trajectories = ctx.trajectories_list[0]
log_every_sec(
logging.INFO, 5, '[Actor {}] send {} trajectories.'.format(task.router.node_id, len(trajectories))
)
meta_data = ActorDataMeta(
player_total_env_step=ctx.total_envstep_count,
actor_id=task.router.node_id,
send_wall_time=time.time()
)
actor_data = ActorData(meta=meta_data, train_data=trajectories)
task.emit(EventEnum.ACTOR_SEND_DATA.format(player=job.launch_player), actor_data)
log_every_sec(logging.INFO, 5, '[Actor {}] send data\n'.format(task.router.node_id))

for idx in main_player_idx:
if not job.is_eval and len(ctx.trajectories_list[idx]) > 0:
trajectories = ctx.trajectories_list[idx]
log_every_sec(
logging.INFO, 5,
'[Actor {}] send {} trajectories.'.format(task.router.node_id, len(trajectories))
)
meta_data = ActorDataMeta(
player_total_env_step=ctx.total_envstep_count,
actor_id=task.router.node_id,
send_wall_time=time.time()
)
actor_data = ActorData(meta=meta_data, train_data=trajectories)
task.emit(EventEnum.ACTOR_SEND_DATA.format(player=job.launch_player), actor_data)
log_every_sec(logging.INFO, 5, '[Actor {}] send data\n'.format(task.router.node_id))

ctx.trajectories_list = []
time_end = time.time()
@@ -186,7 +195,10 @@
)

if ctx.job_finish is True:
job.result = [e['result'] for e in ctx.episode_info[0]]
job.result = []
for idx in main_player_idx:
for e in ctx.episode_info[idx]:
job.result.append(e['result'])
task.emit(EventEnum.ACTOR_FINISH_JOB, job)
ctx.episode_info = [[] for _ in range(self.agent_num)]
logging.info('[Actor {}] job finish, send job\n'.format(task.router.node_id))
@@ -299,7 +311,6 @@ def __call__(self, ctx: "BattleContext"):
# task.router.node_id
# )

# ctx.train_iter = main_player.total_agent_step
# ctx.episode_info = [[] for _ in range(self.agent_num)]
# ctx.remain_episode = ctx.n_episode
# while True:
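The data-sending change in this file can be summarized as the sketch below (an illustration under the assumption that send_to_learner stands in for the task.emit(EventEnum.ACTOR_SEND_DATA...) call shown above): only the launch player's trajectories and results are forwarded now, instead of hard-coding index 0.

main_player_idx = [idx for idx, player in enumerate(job.players) if player.player_id == job.launch_player]

for idx in main_player_idx:
    if not job.is_eval and len(ctx.trajectories_list[idx]) > 0:
        send_to_learner(ctx.trajectories_list[idx])  # hypothetical wrapper around task.emit(...)

if ctx.job_finish:
    job.result = [e['result'] for idx in main_player_idx for e in ctx.episode_info[idx]]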
2 changes: 1 addition & 1 deletion ding/framework/middleware/league_learner.py
@@ -32,7 +32,7 @@ class LeagueLearnerCommunicator:

def __init__(self, cfg: dict, policy: "Policy", player: "ActivePlayer") -> None:
self.cfg = cfg
self._cache = deque(maxlen=1000)
self._cache = deque(maxlen=50)
self.player = player
self.player_id = player.player_id
self.policy = policy
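For context on the maxlen change: self._cache is a bounded collections.deque, so once more than 50 actor-data items are queued the oldest are silently evicted and the learner keeps only the freshest trajectories. A quick illustration of that standard deque behavior (not DI-engine code):

from collections import deque

cache = deque(maxlen=50)
for i in range(60):
    cache.append(i)

assert len(cache) == 50      # capacity is capped
assert list(cache)[0] == 10  # items 0..9 were evicted oldest-first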