Merge pull request #350 from lixl-st/dev-league-lxl
feature(nyz, lxl): merge from dev-distar-learn
hiha3456 authored Jun 13, 2022
2 parents e3c95ff + 7ad26ef commit 400a19d
Showing 48 changed files with 5,547 additions and 121 deletions.
2 changes: 2 additions & 0 deletions ding/framework/context.py
@@ -76,7 +76,9 @@ def __init__(self, *args, **kwargs) -> None:

        self.keep('train_iter', 'last_eval_iter')


class BattleContext(Context):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.__dict__ = self
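
BattleContext subclasses the dict-backed Context, and the assignment self.__dict__ = self makes attribute access and item access interchangeable. A minimal sketch of that behaviour (not part of the diff; assumes Context is a dict subclass, as elsewhere in ding.framework):

from ding.framework.context import BattleContext

ctx = BattleContext()
ctx.train_data = [[], []]                    # attribute write ...
assert ctx['train_data'] is ctx.train_data   # ... doubles as a dict entry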
1 change: 1 addition & 0 deletions ding/framework/middleware/__init__.py
@@ -4,3 +4,4 @@
from .ckpt_handler import CkptSaver
from .league_actor import LeagueActor
from .league_coordinator import LeagueCoordinator
from .league_learner import LeagueLearner
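
With LeagueLearner now exported alongside LeagueActor and LeagueCoordinator, the three middlewares can be composed into one league-training pipeline. A minimal wiring sketch, assuming the usual ding.framework task API (task.start / task.use / task.run) and caller-provided cfg, env_fn, policy_fn and league objects (league.active_players is likewise assumed); in practice each role would usually run in its own process:

from ding.framework import task
from ding.framework.context import BattleContext
from ding.framework.middleware import LeagueActor, LeagueCoordinator, LeagueLearner

with task.start(async_mode=True, ctx=BattleContext()):
    task.use(LeagueCoordinator(league))                                # dispatch jobs, track payoff
    task.use(LeagueLearner(cfg, policy_fn, league.active_players[0]))  # train one active player
    task.use(LeagueActor(cfg, env_fn, policy_fn))                      # collect battle data
    task.run()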
15 changes: 7 additions & 8 deletions ding/framework/middleware/collector.py
@@ -1,5 +1,6 @@
from distutils.log import info
from easydict import EasyDict
from ding import policy
from ding.policy import Policy, get_random_policy
from ding.envs import BaseEnvManager
from ding.framework import task, EventEnum
@@ -11,9 +12,12 @@

from ding.worker.collector.base_serial_collector import CachePool


class BattleCollector:

    def __init__(self, cfg: EasyDict, env: BaseEnvManager, n_rollout_samples: int, model_dict: Dict, all_policies: Dict):
    def __init__(
            self, cfg: EasyDict, env: BaseEnvManager, n_rollout_samples: int, model_dict: Dict, all_policies: Dict
    ):
        self.cfg = cfg
        self.end_flag = False
        # self._reset(env)
@@ -36,7 +40,6 @@ def __init__(self, cfg: EasyDict, env: BaseEnvManager, n_rollout_samples: int, m
        self._battle_rolloutor = task.wrap(battle_rolloutor(self.cfg, self.env, self.obs_pool, self.policy_output_pool))
        self._job_data_sender = task.wrap(job_data_sender(self.streaming_sampling_flag, self.n_rollout_samples))


    def __del__(self) -> None:
        """
        Overview:
@@ -47,9 +50,9 @@ def __del__(self) -> None:
            return
        self.end_flag = True
        self.env.close()

    def _update_policies(self, job) -> None:
        job_player_id_list = [player.player_id for player in job.players]
        job_player_id_list = [player.player_id for player in job.players]

        for player_id in job_player_id_list:
            if self.model_dict.get(player_id) is None:
@@ -62,8 +65,6 @@ def _update_policies(self, job) -> None:
                policy.load_state_dict(learner_model.state_dict)
                self.model_dict[player_id] = None



    def __call__(self, ctx: "BattleContext") -> None:
        """
        Input of ctx:
@@ -107,8 +108,6 @@ def __call__(self, ctx: "BattleContext") -> None:
                break




class StepCollector:
    """
    Overview:
3 changes: 2 additions & 1 deletion ding/framework/middleware/functional/actor_data.py
@@ -1,7 +1,8 @@
from typing import Any
from dataclasses import dataclass


@dataclass
class ActorData:
    train_data: Any
    env_step: int = 0
    env_step: int = 0
6 changes: 3 additions & 3 deletions ding/framework/middleware/functional/collector.py
@@ -163,8 +163,9 @@ def _policy_resetter(ctx: "BattleContext"):

    return _policy_resetter


def job_data_sender(streaming_sampling_flag: bool, n_rollout_samples: int):

    def _job_data_sender(ctx: "BattleContext"):
        if not ctx.job.is_eval and streaming_sampling_flag is True and len(ctx.train_data[0]) >= n_rollout_samples:
            actor_data = ActorData(env_step=ctx.envstep, train_data=ctx.train_data[0])
@@ -178,9 +179,8 @@ def _job_data_sender(ctx: "BattleContext"):
            actor_data = ActorData(env_step=ctx.envstep, train_data=ctx.train_data[0])
            task.emit(EventEnum.ACTOR_SEND_DATA.format(player=ctx.job.launch_player), actor_data)
            ctx.train_data = [[] for _ in range(ctx.agent_num)]

        return _job_data_sender

    return _job_data_sender


def battle_inferencer(cfg: EasyDict, env: BaseEnvManager, obs_pool: CachePool, policy_output_pool: CachePool):
20 changes: 13 additions & 7 deletions ding/framework/middleware/league_actor.py
@@ -19,6 +19,7 @@
from threading import Lock
import queue


class LeagueActor:

    def __init__(self, cfg: EasyDict, env_fn: Callable, policy_fn: Callable):
@@ -36,11 +37,13 @@ def __init__(self, cfg: EasyDict, env_fn: Callable, policy_fn: Callable):
        self.job_queue = queue.Queue()
        self.model_dict = {}
        self.model_dict_lock = Lock()
        self._step = 0

    def _on_learner_model(self, learner_model: "LearnerModel"):
        """
        If get newest learner model, put it inside model_queue.
        """
        print("receive model from learner")
        with self.model_dict_lock:
            self.model_dict[learner_model.player_id] = learner_model

@@ -55,7 +58,11 @@ def _get_collector(self, player_id: str):
            return self._collectors.get(player_id)
        cfg = self.cfg
        env = self.env_fn()
        collector = task.wrap(BattleCollector(cfg.policy.collect.collector, env, self.n_rollout_samples, self.model_dict, self.all_policies))
        collector = task.wrap(
            BattleCollector(
                cfg.policy.collect.collector, env, self.n_rollout_samples, self.model_dict, self.all_policies
            )
        )
        self._collectors[player_id] = collector
        return collector

@@ -79,8 +86,8 @@ def _get_job(self):
            job = self.job_queue.get(timeout=10)
        except queue.Empty:
            logging.warning("For actor_{}, no Job get from coordinator".format(task.router.node_id))
            return job

        return job

    def _get_current_policies(self, job):
        current_policies = []
@@ -91,13 +98,12 @@ def _get_current_policies(self, job):
                main_player = player
        return main_player, current_policies


    def __call__(self, ctx: "BattleContext"):

        ctx.job = self._get_job()
        if ctx.job is None:
            return

        collector = self._get_collector(ctx.job.launch_player)

        main_player, ctx.current_policies = self._get_current_policies(ctx.job)
@@ -110,5 +116,5 @@ def __call__(self, ctx: "BattleContext"):
        ctx.policy_kwargs = None

        collector(ctx)

        logging.info("{} Step: {}".format(self.__class__, self._step))
        self._step += 1
13 changes: 10 additions & 3 deletions ding/framework/middleware/league_coordinator.py
@@ -1,12 +1,14 @@
from collections import defaultdict
from time import sleep
from threading import Lock
from dataclasses import dataclass
from typing import TYPE_CHECKING
from ding.framework import task, EventEnum
import logging

if TYPE_CHECKING:
    from ding.framework import Task, Context
    from ding.league import BaseLeague
    from ding.league.v2 import BaseLeague


class LeagueCoordinator:
@@ -16,7 +18,8 @@ def __init__(self, league: "BaseLeague") -> None:
        self._lock = Lock()
        self._total_send_jobs = 0
        self._eval_frequency = 10

        self._step = 0

        task.on(EventEnum.ACTOR_GREETING, self._on_actor_greeting)
        task.on(EventEnum.LEARNER_SEND_META, self._on_learner_meta)
        task.on(EventEnum.ACTOR_FINISH_JOB, self._on_actor_job)
@@ -34,11 +37,15 @@ def _on_actor_greeting(self, actor_id):
        task.emit(EventEnum.COORDINATOR_DISPATCH_ACTOR_JOB.format(actor_id=actor_id), job)

    def _on_learner_meta(self, player_meta: "PlayerMeta"):
        print("on_learner_meta {}".format(player_meta))
        self.league.update_active_player(player_meta)
        self.league.create_historical_player(player_meta)

    def _on_actor_job(self, job: "Job"):
        print("on_actor_job {}".format(job.launch_player))  # right
        self.league.update_payoff(job)

    def __call__(self, ctx: "Context") -> None:
        sleep(1)
        logging.info("{} Step: {}".format(self.__class__, self._step))
        self._step += 1
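
The coordinator replies to each ACTOR_GREETING by emitting a Job on the greeting actor's own dispatch channel (see _on_actor_greeting above). A rough sketch of the actor side of that handshake, reusing only event names and calls that appear in this diff and assuming the greeting payload is the actor's node id:

import queue
from ding.framework import task, EventEnum

job_queue = queue.Queue()

# Subscribe to this node's dispatch channel, then announce the node to the coordinator.
task.on(
    EventEnum.COORDINATOR_DISPATCH_ACTOR_JOB.format(actor_id=task.router.node_id),
    lambda job: job_queue.put(job)
)
task.emit(EventEnum.ACTOR_GREETING, task.router.node_id)

job = job_queue.get(timeout=10)  # mirrors LeagueActor._get_job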
76 changes: 76 additions & 0 deletions ding/framework/middleware/league_learner.py
@@ -1,8 +1,84 @@
import os
import logging
from dataclasses import dataclass
from threading import Lock
from time import sleep
from typing import TYPE_CHECKING, Callable, Optional

from ding.framework import task, EventEnum
from ding.framework.storage import Storage, FileStorage
from ding.league.player import PlayerMeta
from ding.worker.learner.base_learner import BaseLearner

if TYPE_CHECKING:
    from ding.framework import Context
    from ding.framework.middleware.league_actor import ActorData
    from ding.league import ActivePlayer


@dataclass
class LearnerModel:
    player_id: str
    state_dict: dict
    train_iter: int = 0


class LeagueLearner:

    def __init__(self, cfg: dict, policy_fn: Callable, player: "ActivePlayer") -> None:
        self.cfg = cfg
        self.policy_fn = policy_fn
        self.player = player
        self.player_id = player.player_id
        self.checkpoint_prefix = cfg.policy.other.league.path_policy
        self._learner = self._get_learner()
        self._lock = Lock()
        task.on(EventEnum.ACTOR_SEND_DATA.format(player=self.player_id), self._on_actor_send_data)
        self._step = 0

    def _on_actor_send_data(self, actor_data: "ActorData"):
        print("receive data from actor!")
        with self._lock:
            cfg = self.cfg
            for _ in range(cfg.policy.learn.update_per_collect):
                print("train model")
                self._learner.train(actor_data.train_data, actor_data.env_step)

        self.player.total_agent_step = self._learner.train_iter
        print("save checkpoint")
        checkpoint = self._save_checkpoint() if self.player.is_trained_enough() else None
        task.emit(
            EventEnum.LEARNER_SEND_META,
            PlayerMeta(player_id=self.player_id, checkpoint=checkpoint, total_agent_step=self._learner.train_iter)
        )

        print("pack model")
        learner_model = LearnerModel(
            player_id=self.player_id, state_dict=self._learner.policy.state_dict(), train_iter=self._learner.train_iter
        )
        task.emit(EventEnum.LEARNER_SEND_MODEL, learner_model)

    def _get_learner(self) -> BaseLearner:
        policy = self.policy_fn().learn_mode
        learner = BaseLearner(
            self.cfg.policy.learn.learner,
            policy,
            exp_name=self.cfg.exp_name,
            instance_name=self.player_id + '_learner'
        )
        return learner

    def _save_checkpoint(self) -> Optional[Storage]:
        if not os.path.exists(self.checkpoint_prefix):
            os.makedirs(self.checkpoint_prefix)
        storage = FileStorage(
            path=os.path.join(self.checkpoint_prefix, "{}_{}_ckpt.pth".format(self.player_id, self._learner.train_iter))
        )
        storage.save(self._learner.policy.state_dict())
        return storage

    def __call__(self, _: "Context") -> None:
        sleep(1)
        logging.info("{} Step: {}".format(self.__class__, self._step))
        self._step += 1
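
On the receiving end of LEARNER_SEND_MODEL, LeagueActor._on_learner_model stores the incoming LearnerModel so collectors can reload the newest weights before the next rollout. A standalone subscriber would look roughly like this sketch, which reuses only the event and dataclass shown in this diff:

from ding.framework import task, EventEnum
from ding.framework.middleware.league_learner import LearnerModel

latest_models = {}  # player_id -> newest LearnerModel

def on_learner_model(learner_model: LearnerModel) -> None:
    latest_models[learner_model.player_id] = learner_model

task.on(EventEnum.LEARNER_SEND_MODEL, on_learner_model)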
3 changes: 2 additions & 1 deletion ding/framework/middleware/tests/__init__.py
@@ -1 +1,2 @@
from .mock_for_test import MockEnv, MockPolicy, MockHerRewardModel, CONFIG
from .mock_for_test import *
from .league_config import cfg
8 changes: 4 additions & 4 deletions ding/framework/middleware/tests/league_config.py
@@ -13,8 +13,8 @@
                'cfg_type': 'BaseEnvManagerDict',
                'shared_memory': False
            },
            'collector_env_num': 2,
            'evaluator_env_num': 10,
            'collector_env_num': 1,
            'evaluator_env_num': 1,
            'n_evaluator_episode': 100,
            'env_type': 'prisoner_dilemma',
            'stop_value': [-10.1, -5.05]
@@ -46,7 +46,7 @@
                },
                'multi_gpu': False,
                'epoch_per_collect': 10,
                'batch_size': 32,
                'batch_size': 16,
                'learning_rate': 1e-05,
                'value_weight': 0.5,
                'entropy_weight': 0.0,
@@ -99,7 +99,7 @@
                },
                'league': {
                    'player_category': ['default'],
                    'path_policy': 'league_demo/policy',
                    'path_policy': 'league_demo/ckpt',
                    'active_players': {
                        'main_player': 2
                    },