polish(pu): add resume_training option to allow the envstep and train_iter resume seamlessly
puyuan1996 committed Oct 29, 2024
1 parent 6b9f509 commit b789608
Showing 12 changed files with 100 additions and 26 deletions.
14 changes: 4 additions & 10 deletions ding/entry/serial_entry.py
@@ -47,7 +47,8 @@ def serial_pipeline(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -60,15 +61,6 @@ def serial_pipeline(
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Load pretrained model if specified
if cfg.policy.load_path is not None:
logging.info(f'Loading model from {cfg.policy.load_path} begin...')
if cfg.policy.cuda and torch.cuda.is_available():
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cuda'))
else:
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
logging.info(f'Loading model from {cfg.policy.load_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
@@ -95,6 +87,8 @@
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if resume_training:
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
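Every serial entry in this commit follows the same pattern: read resume_training from cfg.policy.learn (defaulting to False), pass renew_dir=not resume_training to compile_config so the existing experiment directory with its logs and checkpoints is reused rather than renewed, and, after the learner's before_run hook has restored the checkpoint, copy the restored envstep back onto the collector. The explicit cfg.policy.load_path loading block is dropped; loading now goes through the learner's load_ckpt_before_run hook (see the cartpole config at the end of this diff). A hypothetical helper condensing the sync step, illustrative only; learner and collector are the instances built by the pipeline:

    def sync_resume_counters(learner, collector, resume_training: bool) -> None:
        # load_ckpt_before_run restores last_iter and the saved collector envstep
        # onto the learner (see base_learner.py and learner_hook.py below).
        learner.call_hook('before_run')
        if resume_training:
            # Continue the collector's step counter from the checkpoint instead of 0.
            collector.envstep = learner.collector_envstep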
12 changes: 9 additions & 3 deletions ding/entry/serial_entry_mbrl.py
@@ -30,7 +30,8 @@ def mbrl_entry_setup(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)

if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -81,6 +82,7 @@ def mbrl_entry_setup(
evaluator,
commander,
tb_logger,
resume_training
)


@@ -125,12 +127,14 @@ def serial_pipeline_dyna(
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
mbrl_entry_setup(input_cfg, seed, env_setting, model)

img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)

learner.call_hook('before_run')
if resume_training:
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
@@ -198,10 +202,12 @@ def serial_pipeline_dream(
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
mbrl_entry_setup(input_cfg, seed, env_setting, model)

learner.call_hook('before_run')
if resume_training:
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
5 changes: 4 additions & 1 deletion ding/entry/serial_entry_ngu.py
@@ -47,7 +47,8 @@ def serial_pipeline_ngu(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -89,6 +90,8 @@ def serial_pipeline_ngu(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if resume_training:
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
14 changes: 4 additions & 10 deletions ding/entry/serial_entry_onpolicy.py
@@ -45,7 +45,8 @@ def serial_pipeline_onpolicy(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -58,15 +59,6 @@ def serial_pipeline_onpolicy(
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Load pretrained model if specified
if cfg.policy.load_path is not None:
logging.info(f'Loading model from {cfg.policy.load_path} begin...')
if cfg.policy.cuda and torch.cuda.is_available():
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cuda'))
else:
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
logging.info(f'Loading model from {cfg.policy.load_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
@@ -89,6 +81,8 @@
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if resume_training:
collector.envstep = learner.collector_envstep

while True:
collect_kwargs = commander.step()
5 changes: 4 additions & 1 deletion ding/entry/serial_entry_onpolicy_ppg.py
@@ -45,7 +45,8 @@ def serial_pipeline_onpolicy_ppg(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
resume_training = cfg.policy.learn.get('resume_training', False)
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not resume_training)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -80,6 +81,8 @@ def serial_pipeline_onpolicy_ppg(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if resume_training:
collector.envstep = learner.collector_envstep

while True:
collect_kwargs = commander.step()
10 changes: 10 additions & 0 deletions ding/worker/collector/battle_episode_serial_collector.py
@@ -166,6 +166,16 @@ def envstep(self) -> int:
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): the total envstep count
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
10 changes: 10 additions & 0 deletions ding/worker/collector/battle_sample_serial_collector.py
@@ -179,6 +179,16 @@ def envstep(self) -> int:
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): the total envstep count
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
10 changes: 10 additions & 0 deletions ding/worker/collector/episode_serial_collector.py
@@ -161,6 +161,16 @@ def envstep(self) -> int:
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): the total envstep count
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
10 changes: 10 additions & 0 deletions ding/worker/collector/sample_serial_collector.py
@@ -189,6 +189,16 @@ def envstep(self) -> int:
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): the total envstep count
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
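All four collectors (battle episode, battle sample, episode, and sample serial collectors) gain the same write access to their step counter: envstep used to be a read-only property over _total_envstep_count, and the new setter lets an entry script restore it when resuming. A condensed, illustrative sketch of the shared pattern; the real classes carry far more state:

    class CollectorEnvstepSketch:
        """Minimal stand-in showing only the envstep property pair added in this commit."""

        def __init__(self) -> None:
            # Total number of environment steps collected so far.
            self._total_envstep_count = 0

        @property
        def envstep(self) -> int:
            return self._total_envstep_count

        @envstep.setter
        def envstep(self, value: int) -> None:
            # Allows entry scripts to restore the counter when resuming, e.g.
            # collector.envstep = learner.collector_envstep
            self._total_envstep_count = value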
22 changes: 22 additions & 0 deletions ding/worker/learner/base_learner.py
@@ -122,6 +122,8 @@ def __init__(
self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
# Last iteration. Used to record current iter.
self._last_iter = CountVar(init_val=0)
# Collector envstep. Used to record current envstep.
self._collector_envstep = 0

# Setup time wrapper and hook.
self._setup_wrapper()
@@ -177,6 +179,26 @@ def register_hook(self, hook: LearnerHook) -> None:
"""
add_learner_hook(self._hooks, hook)

@property
def collector_envstep(self) -> int:
"""
Overview:
Get current collector envstep.
Returns:
- collector_envstep (:obj:`int`): Current collector envstep.
"""
return self._collector_envstep

@collector_envstep.setter
def collector_envstep(self, value: int) -> None:
"""
Overview:
Set current collector envstep.
Arguments:
- value (:obj:`int`): Current collector envstep.
"""
self._collector_envstep = value

def train(self, data: dict, envstep: int = -1, policy_kwargs: Optional[dict] = None) -> None:
"""
Overview:
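BaseLearner mirrors the collectors with a collector_envstep property and setter around the new _collector_envstep field, which starts at 0 and is filled in by the load-checkpoint hook (next file). A hypothetical helper showing how the pieces connect after the before_run hook has run, assuming learner is a BaseLearner and collector a serial collector from the pipelines above:

    def report_restored_state(learner, collector) -> None:
        # Assumes learner.call_hook('before_run') has already executed the
        # load_ckpt_before_run hook and restored both counters.
        print('restored train_iter:', learner.last_iter.val)
        print('restored envstep:', learner.collector_envstep)
        # Hand the counter back to the collector so it continues from the checkpoint.
        collector.envstep = learner.collector_envstep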
4 changes: 4 additions & 0 deletions ding/worker/learner/learner_hook.py
@@ -117,6 +117,9 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa
if 'last_iter' in state_dict:
last_iter = state_dict.pop('last_iter')
engine.last_iter.update(last_iter)
if 'last_step' in state_dict:
last_step = state_dict.pop('last_step')
engine._collector_envstep = last_step
engine.policy.load_state_dict(state_dict)
engine.info('{} load ckpt in {}'.format(engine.instance_name, path))

@@ -166,6 +169,7 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa
path = os.path.join(dirname, ckpt_name)
state_dict = engine.policy.state_dict()
state_dict.update({'last_iter': engine.last_iter.val})
state_dict.update({'last_step': engine._collector_envstep})
save_file(path, state_dict)
engine.info('{} save ckpt in {}'.format(engine.instance_name, path))

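With these hook changes, every checkpoint written by the save hook carries both counters next to the policy weights: last_iter (training iteration) and the new last_step (collector envstep); the load hook pops them back out on resume. A quick way to inspect a saved checkpoint; the path below is a placeholder and assumes the default layout where checkpoints land under the experiment directory:

    import torch

    # Placeholder path: substitute a checkpoint produced by your own run.
    ckpt = torch.load('./cartpole_ppo_seed0/ckpt/iteration_100.pth.tar', map_location='cpu')
    print(ckpt['last_iter'])  # training iteration at save time
    print(ckpt['last_step'])  # collector envstep at save time (added by this commit)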
10 changes: 9 additions & 1 deletion dizoo/classic_control/cartpole/config/cartpole_ppo_config.py
@@ -26,7 +26,15 @@
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
learner=dict(hook=dict(save_ckpt_after_iter=100)),
# Path to the pretrained checkpoint (ckpt).
# If set to an empty string (''), no pretrained model will be loaded.
# To load a pretrained ckpt, specify the path like this:
# learner=dict(hook=dict(load_ckpt_before_run='/path/to/your/ckpt/iteration_100.pth.tar')),

# If True, the environment step count (collector.envstep) and training iteration (train_iter)
# will be loaded from the pretrained checkpoint, allowing training to resume seamlessly
# from where the ckpt left off.
resume_training=False,
),
collect=dict(
n_sample=256,
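Putting the config comments into practice, a hedged sketch of the learn block with resuming switched on; only the resume-related fields are shown, and the experiment name, checkpoint path, and iteration number are placeholders:

    resume_learn_cfg = dict(
        value_weight=0.5,
        entropy_weight=0.01,
        clip_ratio=0.2,
        # Placeholder path: point this at a checkpoint from the run being resumed.
        learner=dict(hook=dict(load_ckpt_before_run='./cartpole_ppo_seed0/ckpt/iteration_100.pth.tar')),
        # Restore train_iter and collector.envstep from that checkpoint and keep the
        # existing experiment directory instead of creating a renewed one.
        resume_training=True,
    )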
