Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(zt): add metadrive-simulator env and related onppo config #574

Merged
merged 10 commits into from
Feb 15, 2023
98 changes: 98 additions & 0 deletions dizoo/metadrive/config/test_ppo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@

timothijoe marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add empty __init__.py in each dir

from easydict import EasyDict
timothijoe marked this conversation as resolved.
Show resolved Hide resolved
from functools import partial
from tensorboardX import SummaryWriter

from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
timothijoe marked this conversation as resolved.
Show resolved Hide resolved
from ding.config import compile_config
from ding.model.template import VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper

metadrive_basic_config = dict(
exp_name='feb03_test',
env=dict(
metadrive=dict(
use_render = True,
traffic_density=0.10,
map = 'OSXS',
horizon = 4000, #20000
driving_reward = 0.15,
speed_reward = 0.15,
use_lateral_reward=False,
out_of_route_done = True,
),
manager=dict(
shared_memory=False,
max_retry=2,
context='spawn',
),
n_evaluator_episode=16,
stop_value=99999,
collector_env_num=1,
timothijoe marked this conversation as resolved.
Show resolved Hide resolved
evaluator_env_num=1,
),
policy=dict(
cuda=True,
action_space='continuous',
model=dict(
#obs_shape=[5, 200, 200],
timothijoe marked this conversation as resolved.
Show resolved Hide resolved
obs_shape=[5, 84, 84],
action_shape=2,
action_space='continuous',
bound_type='tanh',
encoder_hidden_size_list=[128, 128, 64],
),
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
entropy_weight = 0.001,
value_weight=0.5,
clip_ratio = 0.02,
adv_norm=False,
timothijoe marked this conversation as resolved.
Show resolved Hide resolved
value_norm=True,
grad_clip_value=10,
),
collect=dict(
n_sample=1000,
),
eval=dict(
evaluator=dict(
eval_freq=1000,
),
),
),
)

main_config = EasyDict(metadrive_basic_config)
def wrapped_env(env_cfg, wrapper_cfg=None):
return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)

def main(cfg):
cfg = compile_config(
cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
evaluator_env = BaseEnvManager(
env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(evaluator_env_num)],
cfg=cfg.env.manager,
)
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
timothijoe marked this conversation as resolved.
Show resolved Hide resolved
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
learner.call_hook('before_run')

stop, rate = evaluator.eval()
evaluator.close()
learner.close()


if __name__ == '__main__':
main(main_config)
119 changes: 119 additions & 0 deletions dizoo/metadrive/config/train_ppo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import metadrive
import gym
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter

from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
from ding.config import compile_config
from ding.model.template import QAC, VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
metadrive_basic_config = dict(
exp_name='zt_nov22_ppo1',
env=dict(
metadrive=dict(
use_render = False,
Copy link
Collaborator

@puyuan1996 puyuan1996 Feb 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no space near =, execute bash format.sh diizoo/metadrive to reformat the files

traffic_density=0.10,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some anatations about the key parameters in metadrive env?

map = 'OSXS',
horizon = 4000, #20000
driving_reward = 0.15,
speed_reward = 0.15,
use_lateral_reward=False,
out_of_route_done = True,
),
manager=dict(
shared_memory=False,
max_retry=2,
context='spawn',
),
n_evaluator_episode=16,
stop_value=99999,
collector_env_num=1,
evaluator_env_num=1,
),
policy=dict(
cuda=True,
action_space='continuous',
model=dict(
#obs_shape=[5, 200, 200],
obs_shape=[5, 84, 84],
action_shape=2,
action_space='continuous',
bound_type='tanh',
encoder_hidden_size_list=[128, 128, 64],
),
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
entropy_weight = 0.001,
value_weight=0.5,
clip_ratio = 0.02,
adv_norm=False,
value_norm=True,
grad_clip_value=10,
),
collect=dict(
n_sample=1000,
),
eval=dict(
evaluator=dict(
eval_freq=1000,
),
),
),
)

main_config = EasyDict(metadrive_basic_config)

def wrapped_env(env_cfg, wrapper_cfg=None):
return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)


def main(cfg):
cfg = compile_config(
cfg, SyncSubprocessEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
)

collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = SyncSubprocessEnvManager(
env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(collector_env_num)],
cfg=cfg.env.manager,
)
evaluator_env = SyncSubprocessEnvManager(
env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(evaluator_env_num)],
cfg=cfg.env.manager,
)


model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)

learner.call_hook('before_run')

while True:
if evaluator.should_eval(learner.train_iter):
stop, rate = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Sampling data from environments
new_data = collector.collect(cfg.policy.collect.n_sample, train_iter=learner.train_iter)
learner.train(new_data, collector.envstep)
timothijoe marked this conversation as resolved.
Show resolved Hide resolved
collector.close()
evaluator.close()
learner.close()


if __name__ == '__main__':
main(main_config)
Loading