feature(luyd): add collector logging in new pipeline #735
@@ -1,6 +1,8 @@
-from typing import TYPE_CHECKING, Callable, List, Tuple, Any
+from typing import TYPE_CHECKING, Callable, List, Tuple, Any, Optional
 from functools import reduce
 import treetensor.torch as ttorch
+import numpy as np
+from ding.utils import EasyTimer, build_logger
 from ding.envs import BaseEnvManager
 from ding.policy import Policy
 from ding.torch_utils import to_ndarray, get_shape0
@@ -83,7 +85,15 @@ def _inference(ctx: "OnlineRLContext"):
     return _inference


-def rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList) -> Callable:
+def rolloutor(
+        policy: Policy,
+        env: BaseEnvManager,
+        transitions: TransitionList,
+        collect_print_freq=100,
+        tb_logger=None,
+        exp_name: Optional[str] = 'default_experiment',
+        instance_name: Optional[str] = 'collector'
+) -> Callable:
     """
     Overview:
         The middleware that executes the transition process in the env.

Review comment: no tensorboard.
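For orientation, a rough usage sketch of how the extended middleware could be assembled in a new-pipeline entry script. This is not part of the diff: env, policy and the task wiring are assumed to be built elsewhere, as in other DI-engine entry scripts.

# Rough sketch (assumptions noted): passing the new logging arguments when
# the collector middleware is set up in a new-pipeline entry script.
transitions = TransitionList(env.env_num)      # assumed: env is a started BaseEnvManager
task.use(
    rolloutor(
        policy.collect_mode,                   # assumed: a collect-mode policy built elsewhere
        env,
        transitions,
        collect_print_freq=100,                # emit a summary at most every 100 train iterations
        exp_name='my_experiment',              # text log written under ./my_experiment/log/collector
        instance_name='collector',
    )
)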
@@ -98,6 +108,59 @@ def rolloutor(policy: Policy, env: BaseEnvManager, transitions: TransitionList)
     env_episode_id = [_ for _ in range(env.env_num)]
     current_id = env.env_num
+    timer = EasyTimer()
+    last_train_iter = 0
+    total_envstep_count = 0

Review comment: no tensorboard.
+    total_episode_count = 0
+    total_duration = 0
+    total_train_sample_count = 0
+    env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)}
+    episode_info = []
+
+    if tb_logger is not None:
+        logger, _ = build_logger(path='./{}/log/{}'.format(exp_name, instance_name), name=instance_name, need_tb=False)
+        tb_logger = tb_logger
+    else:
+        logger, tb_logger = build_logger(path='./{}/log/{}'.format(exp_name, instance_name), name=instance_name)
+
+    def output_log(train_iter: int) -> None:
+        """

Review comment: move to outside of function.
+        Overview:
+            Print the output log information. You can refer to the docs of `Best Practice` to understand \
+            the training generated logs and tensorboards.
+        Arguments:
+            - train_iter (:obj:`int`): the number of training iteration.
+        """
+        nonlocal episode_info, timer, total_episode_count, total_duration, \
+            total_envstep_count, total_train_sample_count, last_train_iter
+        if (train_iter - last_train_iter) >= collect_print_freq and len(episode_info) > 0:
+            last_train_iter = train_iter
+            episode_count = len(episode_info)
+            envstep_count = sum([d['step'] for d in episode_info])
+            train_sample_count = sum([d['train_sample'] for d in episode_info])
+            duration = sum([d['time'] for d in episode_info])
+            episode_return = [d['reward'].item() for d in episode_info]
+            print(episode_return)
+            info = {
+                'episode_count': episode_count,
+                'envstep_count': envstep_count,
+                'train_sample_count': train_sample_count,
+                'avg_envstep_per_episode': envstep_count / episode_count,
+                'avg_sample_per_episode': train_sample_count / episode_count,
+                'avg_envstep_per_sec': envstep_count / duration,
+                'avg_train_sample_per_sec': train_sample_count / duration,
+                'avg_episode_per_sec': episode_count / duration,
+                'reward_mean': np.mean(episode_return),
+                'reward_std': np.std(episode_return),
+                'reward_max': np.max(episode_return),
+                'reward_min': np.min(episode_return),
+                'total_envstep_count': total_envstep_count,
+                'total_train_sample_count': total_train_sample_count,
+                'total_episode_count': total_episode_count,
+                # 'each_reward': episode_return,
+            }
+            episode_info.clear()
+            logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
+
     def _rollout(ctx: "OnlineRLContext"):
         """
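On the "move to outside of function" comment: one reading is to lift output_log to module level and pass its state explicitly instead of closing over rolloutor's locals. A rough sketch under that assumption; the helper name, the totals dict, and the guard are illustrative and not taken from the diff.

import numpy as np

def _log_collect_summary(logger, episode_info: list, totals: dict) -> None:
    # Illustrative module-level variant of output_log: aggregates the per-episode
    # records gathered by _rollout and writes one text-log summary.
    if not episode_info:
        return
    episode_count = len(episode_info)
    envstep_count = sum(d['step'] for d in episode_info)
    train_sample_count = sum(d['train_sample'] for d in episode_info)
    duration = sum(d['time'] for d in episode_info)
    episode_return = [d['reward'].item() for d in episode_info]  # rewards are stored as tensors
    info = {
        'episode_count': episode_count,
        'envstep_count': envstep_count,
        'train_sample_count': train_sample_count,
        'avg_envstep_per_episode': envstep_count / episode_count,
        'avg_envstep_per_sec': envstep_count / duration,
        'reward_mean': np.mean(episode_return),
        'reward_std': np.std(episode_return),
        # ... remaining fields as in the closure version above ...
        'total_envstep_count': totals.get('envstep', 0),
        'total_episode_count': totals.get('episode', 0),
    }
    episode_info.clear()
    logger.info("collect end:\n{}".format('\n'.join('{}: {}'.format(k, v) for k, v in info.items())))

rolloutor would then only keep the collect_print_freq check and call this helper with its counters.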
@@ -113,22 +176,52 @@ def _rollout(ctx: "OnlineRLContext"):
             trajectory stops.
         """

-        nonlocal current_id
+        nonlocal current_id, env_info, episode_info, timer, \
+            total_episode_count, total_duration, total_envstep_count, total_train_sample_count, last_train_iter
         timesteps = env.step(ctx.action)
         ctx.env_step += len(timesteps)
         timesteps = [t.tensor() for t in timesteps]
         # TODO abnormal env step

+        collected_sample = 0
+        collected_step = 0
+        collected_episode = 0
+        interaction_duration = timer.value / len(timesteps)
         for i, timestep in enumerate(timesteps):
-            transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
-            transition = ttorch.as_tensor(transition)  # TBD
-            transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
-            transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
-            transitions.append(timestep.env_id, transition)
+            with timer:
+                transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
+                transition = ttorch.as_tensor(transition)  # TBD
+                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
+                transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
+                transitions.append(timestep.env_id, transition)
+
+            collected_step += 1
+            collected_sample += len(transition.obs)
+            env_info[timestep.env_id.item()]['step'] += 1
+            env_info[timestep.env_id.item()]['train_sample'] += len(transition.obs)
+
+            env_info[timestep.env_id.item()]['time'] += timer.value + interaction_duration
             if timestep.done:
-                policy.reset([timestep.env_id])
-                env_episode_id[timestep.env_id] = current_id
+                info = {
+                    'reward': timestep.info['eval_episode_return'],
+                    'time': env_info[timestep.env_id.item()]['time'],
+                    'step': env_info[timestep.env_id.item()]['step'],
+                    'train_sample': env_info[timestep.env_id.item()]['train_sample'],
+                }
+
+                episode_info.append(info)
+                policy.reset([timestep.env_id.item()])
+                env_episode_id[timestep.env_id.item()] = current_id
+                collected_episode += 1
                 current_id += 1
                 ctx.env_episode += 1
-                # TODO log
+
+        collected_duration = sum([d['time'] for d in episode_info])
+        total_envstep_count += collected_step
+        total_episode_count += collected_episode
+        total_duration += collected_duration
+        total_train_sample_count += collected_sample
+
+        output_log(ctx.train_iter)

     return _rollout
Review comment: typing.
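The closing "typing" comment presumably asks for annotations on the new keyword arguments. One possible annotated signature, written as a suggestion rather than the merged result, and shown as it would read inside the modified module, where Policy, BaseEnvManager and TransitionList are already in scope; the SummaryWriter annotation is an assumption, and the reviewer has separately questioned keeping the tensorboard argument at all.

def rolloutor(
        policy: Policy,
        env: BaseEnvManager,
        transitions: TransitionList,
        collect_print_freq: int = 100,
        tb_logger: Optional['SummaryWriter'] = None,   # string forward reference; only if tensorboard stays
        exp_name: Optional[str] = 'default_experiment',
        instance_name: Optional[str] = 'collector',
) -> Callable:
    ...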