diff --git a/README.md b/README.md
index aea65f743b..34b25c7339 100644
--- a/README.md
+++ b/README.md
@@ -285,6 +285,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)<br>environment guide |
| 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)<br>environment guide |
| 34 |[box2d/car_racing](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/carracing/car_racing.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/carracing/envs)<br>environment guide |
+| 35 |[metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)<br>environment guide |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
diff --git a/dizoo/metadrive/__init__.py b/dizoo/metadrive/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dizoo/metadrive/config/__init__.py b/dizoo/metadrive/config/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dizoo/metadrive/config/metadrive_onppo_config.py b/dizoo/metadrive/config/metadrive_onppo_config.py
new file mode 100644
index 0000000000..7a891da636
--- /dev/null
+++ b/dizoo/metadrive/config/metadrive_onppo_config.py
@@ -0,0 +1,111 @@
+from easydict import EasyDict
+from functools import partial
+from tensorboardX import SummaryWriter
+import metadrive
+import gym
+from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
+from ding.config import compile_config
+from ding.model.template import VAC
+from ding.policy import PPOPolicy
+from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
+from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
+from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
+
+metadrive_basic_config = dict(
+ exp_name='metadrive_onppo_seed0',
+ env=dict(
+ metadrive=dict(
+ use_render=False,
+ traffic_density=0.10, # Density of vehicles occupying the roads, range in [0,1]
+ map='XSOS', # Int or string: an easy way to fill map_config
+ horizon=4000, # Max step number
+ driving_reward=1.0, # Reward to encourage agent to move forward.
+ speed_reward=0.1, # Reward to encourage agent to drive at a high speed
+ use_lateral_reward=False, # reward for lane keeping
+ out_of_road_penalty=40.0, # Penalty to discourage driving out of road
+ crash_vehicle_penalty=40.0, # Penalty to discourage collision
+ decision_repeat=20, # Reciprocal of decision frequency
+ out_of_route_done=True, # Game over if driving out of road
+ ),
+ manager=dict(
+ shared_memory=False,
+ max_retry=2,
+ context='spawn',
+ ),
+ n_evaluator_episode=16,
+ stop_value=255,
+ collector_env_num=8,
+ evaluator_env_num=8,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='continuous',
+ model=dict(
+ obs_shape=[5, 84, 84],
+ action_shape=2,
+ action_space='continuous',
+ bound_type='tanh',
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ value_weight=0.5,
+ clip_ratio=0.02,
+ adv_norm=False,
+ value_norm=True,
+ grad_clip_value=10,
+ ),
+ collect=dict(n_sample=3000, ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ ),
+)
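+# Note: ``obs_shape=[5, 84, 84]`` matches the 5-channel 84x84 bird-view observation produced by
+# ``TopDownMultiChannel`` (see dizoo/metadrive/env/drive_env.py), and ``action_shape=2`` is
+# MetaDrive's continuous [steering, acceleration] action.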
+main_config = EasyDict(metadrive_basic_config)
+
+
+def wrapped_env(env_cfg, wrapper_cfg=None):
+ return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)
+
+
+def main(cfg):
+ cfg = compile_config(
+ cfg, SyncSubprocessEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = SyncSubprocessEnvManager(
+ env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(collector_env_num)],
+ cfg=cfg.env.manager,
+ )
+ evaluator_env = SyncSubprocessEnvManager(
+ env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(evaluator_env_num)],
+ cfg=cfg.env.manager,
+ )
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ learner.call_hook('before_run')
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, rate = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Sampling data from environments
+ new_data = collector.collect(cfg.policy.collect.n_sample, train_iter=learner.train_iter)
+ learner.train(new_data, collector.envstep)
+ learner.call_hook('after_run')
+ collector.close()
+ evaluator.close()
+ learner.close()
+
+
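+# Launch training (assumes the MetaDrive simulator from https://github.com/metadriverse/metadrive
+# is installed):
+#   python dizoo/metadrive/config/metadrive_onppo_config.py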
+if __name__ == '__main__':
+ main(main_config)
diff --git a/dizoo/metadrive/config/metadrive_onppo_eval_config.py b/dizoo/metadrive/config/metadrive_onppo_eval_config.py
new file mode 100644
index 0000000000..c9dab89ed2
--- /dev/null
+++ b/dizoo/metadrive/config/metadrive_onppo_eval_config.py
@@ -0,0 +1,96 @@
+from easydict import EasyDict
+from functools import partial
+from tensorboardX import SummaryWriter
+import torch
+from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
+from ding.config import compile_config
+from ding.model.template import VAC
+from ding.policy import PPOPolicy
+from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
+from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
+from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
+
+# Path of the trained model to load; if None, the policy is initialized from scratch
+model_dir = None
+metadrive_basic_config = dict(
+ exp_name='metadrive_onppo_eval_seed0',
+ env=dict(
+ metadrive=dict(
+ use_render=True,
+ traffic_density=0.10, # Density of vehicles occupying the roads, range in [0,1]
+ map='XSOS', # Int or string: an easy way to fill map_config
+ horizon=4000, # Max step number
+ driving_reward=1.0, # Reward to encourage agent to move forward.
+ speed_reward=0.10, # Reward to encourage agent to drive at a high speed
+ use_lateral_reward=False, # reward for lane keeping
+ out_of_road_penalty=40.0, # Penalty to discourage driving out of road
+ crash_vehicle_penalty=40.0, # Penalty to discourage collision
+ decision_repeat=20, # Reciprocal of decision frequency
+ out_of_route_done=True, # Game over if driving out of road
+ show_bird_view=False, # Only used to evaluate, whether to draw five channels of bird-view image
+ ),
+ manager=dict(
+ shared_memory=False,
+ max_retry=2,
+ context='spawn',
+ ),
+ n_evaluator_episode=16,
+ stop_value=255,
+ collector_env_num=1,
+ evaluator_env_num=1,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='continuous',
+ model=dict(
+ obs_shape=[5, 84, 84],
+ action_shape=2,
+ action_space='continuous',
+ bound_type='tanh',
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ value_weight=0.5,
+ clip_ratio=0.02,
+ adv_norm=False,
+ value_norm=True,
+ grad_clip_value=10,
+ ),
+ collect=dict(n_sample=1000, ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ ),
+)
+main_config = EasyDict(metadrive_basic_config)
+
+
+def wrapped_env(env_cfg, wrapper_cfg=None):
+ return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)
+
+
+def main(cfg):
+ cfg = compile_config(cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator)
+ evaluator_env_num = cfg.env.evaluator_env_num
+ show_bird_view = cfg.env.metadrive.show_bird_view
+ wrapper_cfg = {'show_bird_view': show_bird_view}
+ evaluator_env = BaseEnvManager(
+ env_fn=[partial(wrapped_env, cfg.env.metadrive, wrapper_cfg) for _ in range(evaluator_env_num)],
+ cfg=cfg.env.manager,
+ )
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ if model_dir is not None:
+ policy._load_state_dict_collect(torch.load(model_dir, map_location='cpu'))
+ tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ stop, rate = evaluator.eval()
+ evaluator.close()
+
+
+if __name__ == '__main__':
+ main(main_config)
diff --git a/dizoo/metadrive/env/__init__.py b/dizoo/metadrive/env/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dizoo/metadrive/env/drive_env.py b/dizoo/metadrive/env/drive_env.py
new file mode 100644
index 0000000000..b14676adfe
--- /dev/null
+++ b/dizoo/metadrive/env/drive_env.py
@@ -0,0 +1,341 @@
+import copy
+import numpy as np
+from ditk import logging
+from typing import Union, Dict, AnyStr, Tuple, Optional
+from gym.envs.registration import register
+from metadrive.manager.traffic_manager import TrafficMode
+from metadrive.obs.top_down_obs_multi_channel import TopDownMultiChannel
+from metadrive.constants import RENDER_MODE_NONE, DEFAULT_AGENT, REPLAY_DONE, TerminationState
+from metadrive.envs.base_env import BaseEnv, BASE_DEFAULT_CONFIG
+from metadrive.component.map.base_map import BaseMap
+from metadrive.component.map.pg_map import parse_map_config, MapGenerateMethod
+from metadrive.component.pgblock.first_block import FirstPGBlock
+from metadrive.component.vehicle.base_vehicle import BaseVehicle
+from metadrive.utils import Config, merge_dicts, get_np_random, clip
+from metadrive.component.road_network import Road
+from metadrive.component.algorithm.blocks_prob_dist import PGBlockDistConfig
+
+METADRIVE_DEFAULT_CONFIG = dict(
+ # ===== Generalization =====
+ start_seed=0,
+ environment_num=10,
+ decision_repeat=20,
+ block_dist_config=PGBlockDistConfig,
+
+ # ===== Map Config =====
+ map=3, # int or string: an easy way to fill map_config
+ random_lane_width=False,
+ random_lane_num=False,
+ map_config={
+ BaseMap.GENERATE_TYPE: MapGenerateMethod.BIG_BLOCK_NUM,
+ BaseMap.GENERATE_CONFIG: None, # it can be a file path / block num / block ID sequence
+ BaseMap.LANE_WIDTH: 3.5,
+ BaseMap.LANE_NUM: 3,
+ "exit_length": 50,
+ },
+
+ # ===== Traffic =====
+ traffic_density=0.1,
+ need_inverse_traffic=False,
+ traffic_mode=TrafficMode.Trigger, # "Respawn", "Trigger"
+    random_traffic=False,  # Whether to randomize the spawned traffic (disabled by default)
+ traffic_vehicle_config=dict(
+ show_navi_mark=False,
+ show_dest_mark=False,
+ enable_reverse=False,
+ show_lidar=False,
+ show_lane_line_detector=False,
+ show_side_detector=False,
+ ),
+
+ # ===== Object =====
+    accident_prob=0.,  # An accident may happen on each block with this probability, except multi-exit blocks
+
+ # ===== Others =====
+ use_AI_protector=False,
+ save_level=0.5,
+ is_multi_agent=False,
+ vehicle_config=dict(spawn_lane_index=(FirstPGBlock.NODE_1, FirstPGBlock.NODE_2, 0)),
+
+ # ===== Agent =====
+ random_spawn_lane_index=True,
+ target_vehicle_configs={
+ DEFAULT_AGENT: dict(
+ use_special_color=True,
+ spawn_lane_index=(FirstPGBlock.NODE_1, FirstPGBlock.NODE_2, 0),
+ )
+ },
+
+ # ===== Reward Scheme =====
+ # See: https://github.com/decisionforce/metadrive/issues/283
+ success_reward=10.0,
+ out_of_road_penalty=5.0,
+ crash_vehicle_penalty=5.0,
+ crash_object_penalty=5.0,
+ driving_reward=1.0,
+ speed_reward=0.1,
+ use_lateral_reward=False,
+
+ # ===== Cost Scheme =====
+ crash_vehicle_cost=1.0,
+ crash_object_cost=1.0,
+ out_of_road_cost=1.0,
+
+ # ===== Termination Scheme =====
+ out_of_route_done=False,
+ on_screen=False,
+ show_bird_view=False,
+)
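+# Any key above can be overridden per experiment through ``cfg.env.metadrive`` (e.g. in
+# dizoo/metadrive/config/metadrive_onppo_config.py); the user dict is merged into these defaults
+# by ``_merge_extra_config`` below.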
+
+
+class MetaDrivePPOOriginEnv(BaseEnv):
+
+ @classmethod
+ def default_config(cls) -> "Config":
+ config = super(MetaDrivePPOOriginEnv, cls).default_config()
+ config.update(METADRIVE_DEFAULT_CONFIG)
+ config.register_type("map", str, int)
+ config["map_config"].register_type("config", None)
+ return config
+
+ def __init__(self, config: dict = None):
+ self.default_config_copy = Config(self.default_config(), unchangeable=True)
+ super(MetaDrivePPOOriginEnv, self).__init__(config)
+ self.start_seed = self.config["start_seed"]
+ self.env_num = self.config["environment_num"]
+
+ def _merge_extra_config(self, config: Union[dict, "Config"]) -> "Config":
+ config = self.default_config().update(config, allow_add_new_key=False)
+ if config["vehicle_config"]["lidar"]["distance"] > 50:
+ config["max_distance"] = config["vehicle_config"]["lidar"]["distance"]
+ return config
+
+ def _post_process_config(self, config):
+ config = super(MetaDrivePPOOriginEnv, self)._post_process_config(config)
+ if not config["rgb_clip"]:
+ logging.warning(
+ "You have set rgb_clip = False, which means the observation will be uint8 values in [0, 255]. "
+ "Please make sure you have parsed them later before feeding them to network!"
+ )
+ config["map_config"] = parse_map_config(
+ easy_map_config=config["map"], new_map_config=config["map_config"], default_config=self.default_config_copy
+ )
+ config["vehicle_config"]["rgb_clip"] = config["rgb_clip"]
+ config["vehicle_config"]["random_agent_model"] = config["random_agent_model"]
+ if config.get("gaussian_noise", 0) > 0:
+ assert config["vehicle_config"]["lidar"]["gaussian_noise"] == 0, "You already provide config!"
+ assert config["vehicle_config"]["side_detector"]["gaussian_noise"] == 0, "You already provide config!"
+ assert config["vehicle_config"]["lane_line_detector"]["gaussian_noise"] == 0, "You already provide config!"
+ config["vehicle_config"]["lidar"]["gaussian_noise"] = config["gaussian_noise"]
+ config["vehicle_config"]["side_detector"]["gaussian_noise"] = config["gaussian_noise"]
+ config["vehicle_config"]["lane_line_detector"]["gaussian_noise"] = config["gaussian_noise"]
+ if config.get("dropout_prob", 0) > 0:
+ assert config["vehicle_config"]["lidar"]["dropout_prob"] == 0, "You already provide config!"
+ assert config["vehicle_config"]["side_detector"]["dropout_prob"] == 0, "You already provide config!"
+ assert config["vehicle_config"]["lane_line_detector"]["dropout_prob"] == 0, "You already provide config!"
+ config["vehicle_config"]["lidar"]["dropout_prob"] = config["dropout_prob"]
+ config["vehicle_config"]["side_detector"]["dropout_prob"] = config["dropout_prob"]
+ config["vehicle_config"]["lane_line_detector"]["dropout_prob"] = config["dropout_prob"]
+ target_v_config = copy.deepcopy(config["vehicle_config"])
+ if not config["is_multi_agent"]:
+ target_v_config.update(config["target_vehicle_configs"][DEFAULT_AGENT])
+ config["target_vehicle_configs"][DEFAULT_AGENT] = target_v_config
+ return config
+
+ def step(self, actions: Union[np.ndarray, Dict[AnyStr, np.ndarray]]):
+ actions = self._preprocess_actions(actions)
+ engine_info = self._step_simulator(actions)
+ o, r, d, i = self._get_step_return(actions, engine_info=engine_info)
+ return o, r, d, i
+
+ def _get_observations(self):
+ return {DEFAULT_AGENT: self.get_single_observation(self.config["vehicle_config"])}
+
+ def cost_function(self, vehicle_id: str):
+ vehicle = self.vehicles[vehicle_id]
+ step_info = dict()
+ step_info["cost"] = 0
+ if self._is_out_of_road(vehicle):
+ step_info["cost"] = self.config["out_of_road_cost"]
+ elif vehicle.crash_vehicle:
+ step_info["cost"] = self.config["crash_vehicle_cost"]
+ elif vehicle.crash_object:
+ step_info["cost"] = self.config["crash_object_cost"]
+ return step_info['cost'], step_info
+
+ def _is_out_of_road(self, vehicle):
+ ret = vehicle.on_yellow_continuous_line or vehicle.on_white_continuous_line or \
+ (not vehicle.on_lane) or vehicle.crash_sidewalk
+ if self.config["out_of_route_done"]:
+ ret = ret or vehicle.out_of_route
+ return ret
+
+ def done_function(self, vehicle_id: str):
+ vehicle = self.vehicles[vehicle_id]
+ done = False
+ done_info = {
+ TerminationState.CRASH_VEHICLE: False,
+ TerminationState.CRASH_OBJECT: False,
+ TerminationState.CRASH_BUILDING: False,
+ TerminationState.OUT_OF_ROAD: False,
+ TerminationState.SUCCESS: False,
+ TerminationState.MAX_STEP: False,
+ TerminationState.ENV_SEED: self.current_seed,
+ }
+ if self._is_arrive_destination(vehicle):
+ done = True
+ logging.info("Episode ended! Reason: arrive_dest.")
+ done_info[TerminationState.SUCCESS] = True
+ if self._is_out_of_road(vehicle):
+ done = True
+ logging.info("Episode ended! Reason: out_of_road.")
+ done_info[TerminationState.OUT_OF_ROAD] = True
+ if vehicle.crash_vehicle:
+ done = True
+ logging.info("Episode ended! Reason: crash vehicle ")
+ done_info[TerminationState.CRASH_VEHICLE] = True
+ if vehicle.crash_object:
+ done = True
+ done_info[TerminationState.CRASH_OBJECT] = True
+ logging.info("Episode ended! Reason: crash object ")
+ if vehicle.crash_building:
+ done = True
+ done_info[TerminationState.CRASH_BUILDING] = True
+ logging.info("Episode ended! Reason: crash building ")
+ if self.config["max_step_per_agent"] is not None and \
+ self.episode_lengths[vehicle_id] >= self.config["max_step_per_agent"]:
+ done = True
+ done_info[TerminationState.MAX_STEP] = True
+ logging.info("Episode ended! Reason: max step ")
+
+ if self.config["horizon"] is not None and \
+ self.episode_lengths[vehicle_id] >= self.config["horizon"] and not self.is_multi_agent:
+ # single agent horizon has the same meaning as max_step_per_agent
+ done = True
+ done_info[TerminationState.MAX_STEP] = True
+ logging.info("Episode ended! Reason: max step ")
+
+ done_info[TerminationState.CRASH] = (
+ done_info[TerminationState.CRASH_VEHICLE] or done_info[TerminationState.CRASH_OBJECT]
+ or done_info[TerminationState.CRASH_BUILDING]
+ )
+ return done, done_info
+
+ def reward_function(self, vehicle_id: str):
+ """
+ Override this func to get a new reward function
+ :param vehicle_id: id of BaseVehicle
+ :return: reward
+ """
+ vehicle = self.vehicles[vehicle_id]
+ step_info = dict()
+
+ # Reward for moving forward in current lane
+ if vehicle.lane in vehicle.navigation.current_ref_lanes:
+ current_lane = vehicle.lane
+ positive_road = 1
+ else:
+ current_lane = vehicle.navigation.current_ref_lanes[0]
+ current_road = vehicle.navigation.current_road
+ positive_road = 1 if not current_road.is_negative_road() else -1
+ long_last, _ = current_lane.local_coordinates(vehicle.last_position)
+ long_now, lateral_now = current_lane.local_coordinates(vehicle.position)
+
+ # reward for lane keeping, without it vehicle can learn to overtake but fail to keep in lane
+ if self.config["use_lateral_reward"]:
+ lateral_factor = clip(1 - 2 * abs(lateral_now) / vehicle.navigation.get_current_lane_width(), 0.0, 1.0)
+ else:
+ lateral_factor = 1.0
+
+ reward = 0.0
+ reward += self.config["driving_reward"] * (long_now - long_last) * lateral_factor * positive_road
+ reward += self.config["speed_reward"] * (vehicle.speed / vehicle.max_speed) * positive_road
+
+ step_info["step_reward"] = reward
+
+ if self._is_arrive_destination(vehicle):
+ reward = +self.config["success_reward"]
+ elif self._is_out_of_road(vehicle):
+ reward = -self.config["out_of_road_penalty"]
+ elif vehicle.crash_vehicle:
+ reward = -self.config["crash_vehicle_penalty"]
+ elif vehicle.crash_object:
+ reward = -self.config["crash_object_penalty"]
+ return reward, step_info
+
+ def _get_reset_return(self):
+ ret = {}
+ self.engine.after_step()
+ for v_id, v in self.vehicles.items():
+ self.observations[v_id].reset(self, v)
+ ret[v_id] = self.observations[v_id].observe(v)
+ return ret if self.is_multi_agent else self._wrap_as_single_agent(ret)
+
+ def switch_to_third_person_view(self) -> (str, BaseVehicle):
+ if self.main_camera is None:
+ return
+ self.main_camera.reset()
+ if self.config["prefer_track_agent"] is not None and self.config["prefer_track_agent"] in self.vehicles.keys():
+ new_v = self.vehicles[self.config["prefer_track_agent"]]
+ current_track_vehicle = new_v
+ else:
+ if self.main_camera.is_bird_view_camera():
+ current_track_vehicle = self.current_track_vehicle
+ else:
+ vehicles = list(self.engine.agents.values())
+ if len(vehicles) <= 1:
+ return
+ if self.current_track_vehicle in vehicles:
+ vehicles.remove(self.current_track_vehicle)
+ new_v = get_np_random().choice(vehicles)
+ current_track_vehicle = new_v
+ self.main_camera.track(current_track_vehicle)
+ return
+
+ def switch_to_top_down_view(self):
+ self.main_camera.stop_track()
+
+ def setup_engine(self):
+ super(MetaDrivePPOOriginEnv, self).setup_engine()
+ self.engine.accept("b", self.switch_to_top_down_view)
+ self.engine.accept("q", self.switch_to_third_person_view)
+ from metadrive.manager.traffic_manager import TrafficManager
+ from metadrive.manager.map_manager import MapManager
+ self.engine.register_manager("map_manager", MapManager())
+ self.engine.register_manager("traffic_manager", TrafficManager())
+
+ def _is_arrive_destination(self, vehicle):
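+        # Arrival check: the vehicle is within 5 m of the end of its final navigation lane
+        # (longitudinally) and laterally inside the span of the drivable lanes of that road.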
+ long, lat = vehicle.navigation.final_lane.local_coordinates(vehicle.position)
+ flag = (vehicle.navigation.final_lane.length - 5 < long < vehicle.navigation.final_lane.length + 5) and (
+ vehicle.navigation.get_current_lane_width() / 2 >= lat >=
+ (0.5 - vehicle.navigation.get_current_lane_num()) * vehicle.navigation.get_current_lane_width()
+ )
+ return flag
+
+ def _reset_global_seed(self, force_seed=None):
+ """
+ Current seed is set to force seed if force_seed is not None.
+ Otherwise, current seed is randomly generated.
+ """
+ current_seed = force_seed if force_seed is not None else \
+ get_np_random(self._DEBUG_RANDOM_SEED).randint(self.start_seed, self.start_seed + self.env_num)
+ self.seed(current_seed)
+
+ def get_single_observation(self, _=None):
+ return TopDownMultiChannel(
+ self.config["vehicle_config"],
+ self.config["on_screen"],
+ self.config["rgb_clip"],
+ frame_stack=3,
+ post_stack=10,
+ frame_skip=1,
+ resolution=(84, 84),
+ max_distance=36,
+ )
diff --git a/dizoo/metadrive/env/drive_utils.py b/dizoo/metadrive/env/drive_utils.py
new file mode 100644
index 0000000000..99415c9b93
--- /dev/null
+++ b/dizoo/metadrive/env/drive_utils.py
@@ -0,0 +1,121 @@
+from typing import Any, Dict, List, NoReturn, Optional
+from gym import utils
+from abc import ABC, abstractmethod
+from easydict import EasyDict
+from itertools import product
+import gym
+import copy
+import numpy as np
+import matplotlib.pyplot as plt
+from ding.utils.default_helper import deep_merge_dicts
+
+
+class AAA():
+
+ def __init__(self) -> None:
+ self.x = 0
+
+
+def deep_update(
+ original: dict,
+ new_dict: dict,
+ new_keys_allowed: bool = False,
+ whitelist: Optional[List[str]] = None,
+ override_all_if_type_changes: Optional[List[str]] = None
+):
+ """
+ Overview:
+ Updates original dict with values from new_dict recursively.
+
+ .. note::
+
+ If new key is introduced in new_dict, then if new_keys_allowed is not
+ True, an error will be thrown. Further, for sub-dicts, if the key is
+ in the whitelist, then new subkeys can be introduced.
+
+ Arguments:
+ - original (:obj:`dict`): Dictionary with default values.
+ - new_dict (:obj:`dict`): Dictionary with values to be updated
+ - new_keys_allowed (:obj:`bool`): Whether new keys are allowed.
+ - whitelist (Optional[List[str]]): List of keys that correspond to dict
+ values where new subkeys can be introduced. This is only at the top
+ level.
+ - override_all_if_type_changes(Optional[List[str]]): List of top level
+ keys with value=dict, for which we always simply override the
+ entire value (:obj:`dict`), if the "type" key in that value dict changes.
+ """
+ whitelist = whitelist or []
+ override_all_if_type_changes = override_all_if_type_changes or []
+ for k, value in new_dict.items():
+ if k not in original and not new_keys_allowed:
+ raise RuntimeError("Unknown config parameter `{}`. Base config have: {}.".format(k, original.keys()))
+ # Both original value and new one are dicts.
+ if isinstance(original.get(k), dict) and isinstance(value, dict):
+ # Check old type vs old one. If different, override entire value.
+ if k in override_all_if_type_changes and \
+ "type" in value and "type" in original[k] and \
+ value["type"] != original[k]["type"]:
+ original[k] = value
+ # Whitelisted key -> ok to add new subkeys.
+ elif k in whitelist:
+ deep_update(original[k], value, True)
+ # Non-whitelisted key.
+ else:
+ deep_update(original[k], value, new_keys_allowed)
+ # Original value not a dict OR new value not a dict:
+ # Override entire value.
+ else:
+ original[k] = value
+ return original
+
+
+class BaseDriveEnv(gym.Env, utils.EzPickle):
+ config = dict()
+
+ @abstractmethod
+ def __init__(self, cfg: Dict, **kwargs) -> None:
+ if 'cfg_type' not in cfg:
+ self._cfg = self.__class__.default_config()
+ self._cfg = deep_merge_dicts(self._cfg, cfg)
+ else:
+ self._cfg = cfg
+ utils.EzPickle.__init__(self)
+
+ @abstractmethod
+ def step(self, action: Any) -> Any:
+ """
+ Run one step of the environment and return the observation dict.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset(self, *args, **kwargs) -> Any:
+ """
+ Reset current environment.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def close(self) -> None:
+ """
+ Release all resources in environment and close.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def seed(self, seed: int) -> None:
+ """
+ Set random seed.
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(cls.config)
+ cfg.cfg_type = cls.__name__ + 'Config'
+ return copy.deepcopy(cfg)
+
+ @abstractmethod
+ def __repr__(self) -> str:
+ raise NotImplementedError
diff --git a/dizoo/metadrive/env/drive_wrapper.py b/dizoo/metadrive/env/drive_wrapper.py
new file mode 100644
index 0000000000..5b48826ab8
--- /dev/null
+++ b/dizoo/metadrive/env/drive_wrapper.py
@@ -0,0 +1,145 @@
+from typing import Any, Dict, Optional
+from easydict import EasyDict
+from itertools import product
+from typing import NoReturn, Optional, List
+import matplotlib.pyplot as plt
+import gym
+import copy
+import numpy as np
+from ding.envs.env.base_env import BaseEnvTimestep
+from ding.envs.common.env_element import EnvElementInfo
+from ding.torch_utils.data_helper import to_ndarray
+from ding.utils.default_helper import deep_merge_dicts
+from dizoo.metadrive.env.drive_utils import BaseDriveEnv
+
+
+def draw_multi_channels_top_down_observation(obs, show_time=0.5):
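+    # Visualize the 5-channel top-down (bird-view) observation with matplotlib, one subplot per
+    # channel; the figure closes itself after ``show_time`` seconds. Used by ``DriveEnvWrapper``
+    # when ``show_bird_view`` is enabled (see metadrive_onppo_eval_config.py).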
+ num_channels = obs.shape[-1]
+ assert num_channels == 5
+ channel_names = [
+ "Road and navigation", "Ego now and previous pos", "Neighbor at step t", "Neighbor at step t-1",
+ "Neighbor at step t-2"
+ ]
+ fig, axs = plt.subplots(1, num_channels, figsize=(15, 4), dpi=80)
+ count = 0
+
+ def close_event():
+ plt.close()
+
+ timer = fig.canvas.new_timer(interval=show_time * 1000)
+ timer.add_callback(close_event)
+ for i, name in enumerate(channel_names):
+ count += 1
+ ax = axs[i]
+ ax.imshow(obs[..., i], cmap="bone")
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_title(name)
+ fig.suptitle("Multi-channels Top-down Observation")
+ timer.start()
+ plt.show()
+ plt.close()
+
+
+class DriveEnvWrapper(gym.Wrapper):
+ """
+ Overview:
+ Environment wrapper to make ``gym.Env`` align with DI-engine definitions, so as to use utilities in DI-engine.
+ It changes ``step``, ``reset`` and ``info`` method of ``gym.Env``, while others are straightly delivered.
+
+ Arguments:
+ - env (BaseDriveEnv): The environment to be wrapped.
+ - cfg (Dict): Config dict.
+ """
+ config = dict()
+
+ def __init__(self, env: BaseDriveEnv, cfg: Dict = None, **kwargs) -> None:
+ if cfg is None:
+ self._cfg = self.__class__.default_config()
+ elif 'cfg_type' not in cfg:
+ self._cfg = self.__class__.default_config()
+ self._cfg = deep_merge_dicts(self._cfg, cfg)
+ else:
+ self._cfg = cfg
+ self.env = env
+ if not hasattr(self.env, 'reward_space'):
+ self.reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1, ))
+ if 'show_bird_view' in self._cfg and self._cfg['show_bird_view'] is True:
+ self.show_bird_view = True
+ else:
+ self.show_bird_view = False
+        self.action_space = self.env.action_space
+
+ def reset(self, *args, **kwargs) -> Any:
+ """
+ Overview:
+ Wrapper of ``reset`` method in env. The observations are converted to ``np.ndarray`` and final reward
+ are recorded.
+ Returns:
+ - Any: Observations from environment
+ """
+ obs = self.env.reset(*args, **kwargs)
+ obs = to_ndarray(obs, dtype=np.float32)
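+        # MetaDrive returns the bird-view observation as HWC (84, 84, 5); transpose to CHW
+        # (5, 84, 84) to match ``obs_shape=[5, 84, 84]`` expected by the policy's conv encoder.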
+ if isinstance(obs, np.ndarray) and len(obs.shape) == 3:
+ obs = obs.transpose((2, 0, 1))
+ elif isinstance(obs, dict):
+ vehicle_state = obs['vehicle_state']
+ birdview = obs['birdview'].transpose((2, 0, 1))
+ obs = {'vehicle_state': vehicle_state, 'birdview': birdview}
+ self._final_eval_reward = 0.0
+ self._arrive_dest = False
+ return obs
+
+ def step(self, action: Any = None) -> BaseEnvTimestep:
+ """
+ Overview:
+ Wrapper of ``step`` method in env. This aims to convert the returns of ``gym.Env`` step method into
+ that of ``ding.envs.BaseEnv``, from ``(obs, reward, done, info)`` tuple to a ``BaseEnvTimestep``
+ namedtuple defined in DI-engine. It will also convert actions, observations and reward into
+ ``np.ndarray``, and check legality if action contains control signal.
+ Arguments:
+ - action (Any, optional): Actions sent to env. Defaults to None.
+ Returns:
+ - BaseEnvTimestep: DI-engine format of env step returns.
+ """
+ action = to_ndarray(action)
+ obs, rew, done, info = self.env.step(action)
+ if self.show_bird_view:
+ draw_multi_channels_top_down_observation(obs, show_time=0.5)
+ self._final_eval_reward += rew
+ obs = to_ndarray(obs, dtype=np.float32)
+ if isinstance(obs, np.ndarray) and len(obs.shape) == 3:
+ obs = obs.transpose((2, 0, 1))
+ elif isinstance(obs, dict):
+ vehicle_state = obs['vehicle_state']
+ birdview = obs['birdview'].transpose((2, 0, 1))
+ obs = {'vehicle_state': vehicle_state, 'birdview': birdview}
+ rew = to_ndarray([rew], dtype=np.float32)
+ if done:
+ info['final_eval_reward'] = self._final_eval_reward
+ info['eval_episode_return'] = self._final_eval_reward
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+ self.env = gym.wrappers.Monitor(self.env, self._replay_path, video_callable=lambda episode_id: True, force=True)
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(cls.config)
+ cfg.cfg_type = cls.__name__ + 'Config'
+ return copy.deepcopy(cfg)
+
+ def __repr__(self) -> str:
+ return repr(self.env)
+
+ def render(self):
+ self.env.render()
diff --git a/dizoo/metadrive/metadrive_env.gif b/dizoo/metadrive/metadrive_env.gif
new file mode 100644
index 0000000000..5e78923051
Binary files /dev/null and b/dizoo/metadrive/metadrive_env.gif differ