diff --git a/README.md b/README.md index aea65f743b..34b25c7339 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo` | 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)
环境指南 | | 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)
环境指南 | | 34 |[box2d/car_racing](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)
![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/carracing/car_racing.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/carracing/envs)
环境指南 | +| 35 |[metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)
环境指南 | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space diff --git a/dizoo/metadrive/__init__.py b/dizoo/metadrive/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dizoo/metadrive/config/__init__.py b/dizoo/metadrive/config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dizoo/metadrive/config/metadrive_onppo_config.py b/dizoo/metadrive/config/metadrive_onppo_config.py new file mode 100644 index 0000000000..7a891da636 --- /dev/null +++ b/dizoo/metadrive/config/metadrive_onppo_config.py @@ -0,0 +1,111 @@ +from easydict import EasyDict +from functools import partial +from tensorboardX import SummaryWriter +import metadrive +import gym +from ding.envs import BaseEnvManager, SyncSubprocessEnvManager +from ding.config import compile_config +from ding.model.template import QAC, VAC +from ding.policy import PPOPolicy +from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner +from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv +from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper + +metadrive_basic_config = dict( + exp_name='metadrive_onppo_seed0', + env=dict( + metadrive=dict( + use_render=False, + traffic_density=0.10, # Density of vehicles occupying the roads, range in [0,1] + map='XSOS', # Int or string: an easy way to fill map_config + horizon=4000, # Max step number + driving_reward=1.0, # Reward to encourage agent to move forward. + speed_reward=0.1, # Reward to encourage agent to drive at a high speed + use_lateral_reward=False, # reward for lane keeping + out_of_road_penalty=40.0, # Penalty to discourage driving out of road + crash_vehicle_penalty=40.0, # Penalty to discourage collision + decision_repeat=20, # Reciprocal of decision frequency + out_of_route_done=True, # Game over if driving out of road + ), + manager=dict( + shared_memory=False, + max_retry=2, + context='spawn', + ), + n_evaluator_episode=16, + stop_value=255, + collector_env_num=8, + evaluator_env_num=8, + ), + policy=dict( + cuda=True, + action_space='continuous', + model=dict( + obs_shape=[5, 84, 84], + action_shape=2, + action_space='continuous', + bound_type='tanh', + encoder_hidden_size_list=[128, 128, 64], + ), + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + entropy_weight=0.001, + value_weight=0.5, + clip_ratio=0.02, + adv_norm=False, + value_norm=True, + grad_clip_value=10, + ), + collect=dict(n_sample=3000, ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) +main_config = EasyDict(metadrive_basic_config) + + +def wrapped_env(env_cfg, wrapper_cfg=None): + return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg) + + +def main(cfg): + cfg = compile_config( + cfg, SyncSubprocessEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator + ) + collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num + collector_env = SyncSubprocessEnvManager( + env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(collector_env_num)], + cfg=cfg.env.manager, + ) + evaluator_env = SyncSubprocessEnvManager( + env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(evaluator_env_num)], + cfg=cfg.env.manager, + ) + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name)) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = 
SampleSerialCollector( + cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name + ) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + learner.call_hook('before_run') + while True: + if evaluator.should_eval(learner.train_iter): + stop, rate = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + # Sampling data from environments + new_data = collector.collect(cfg.policy.collect.n_sample, train_iter=learner.train_iter) + learner.train(new_data, collector.envstep) + learner.call_hook('after_run') + collector.close() + evaluator.close() + learner.close() + + +if __name__ == '__main__': + main(main_config) diff --git a/dizoo/metadrive/config/metadrive_onppo_eval_config.py b/dizoo/metadrive/config/metadrive_onppo_eval_config.py new file mode 100644 index 0000000000..c9dab89ed2 --- /dev/null +++ b/dizoo/metadrive/config/metadrive_onppo_eval_config.py @@ -0,0 +1,96 @@ +from easydict import EasyDict +from functools import partial +from tensorboardX import SummaryWriter +import torch +from ding.envs import BaseEnvManager, SyncSubprocessEnvManager +from ding.config import compile_config +from ding.model.template import VAC +from ding.policy import PPOPolicy +from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner +from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv +from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper + +# Path to the trained model checkpoint; if None, the policy is initialized from scratch +model_dir = None +metadrive_basic_config = dict( + exp_name='metadrive_onppo_eval_seed0', + env=dict( + metadrive=dict( + use_render=True, + traffic_density=0.10, # Density of vehicles occupying the roads, range in [0,1] + map='XSOS', # Int or string: an easy way to fill map_config + horizon=4000, # Max step number + driving_reward=1.0, # Reward to encourage agent to move forward.
+ speed_reward=0.10, # Reward to encourage agent to drive at a high speed + use_lateral_reward=False, # reward for lane keeping + out_of_road_penalty=40.0, # Penalty to discourage driving out of road + crash_vehicle_penalty=40.0, # Penalty to discourage collision + decision_repeat=20, # Reciprocal of decision frequency + out_of_route_done=True, # Game over if driving out of road + show_bird_view=False, # Only used to evaluate, whether to draw five channels of bird-view image + ), + manager=dict( + shared_memory=False, + max_retry=2, + context='spawn', + ), + n_evaluator_episode=16, + stop_value=255, + collector_env_num=1, + evaluator_env_num=1, + ), + policy=dict( + cuda=True, + action_space='continuous', + model=dict( + obs_shape=[5, 84, 84], + action_shape=2, + action_space='continuous', + bound_type='tanh', + encoder_hidden_size_list=[128, 128, 64], + ), + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + entropy_weight=0.001, + value_weight=0.5, + clip_ratio=0.02, + adv_norm=False, + value_norm=True, + grad_clip_value=10, + ), + collect=dict(n_sample=1000, ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) +main_config = EasyDict(metadrive_basic_config) + + +def wrapped_env(env_cfg, wrapper_cfg=None): + return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg) + + +def main(cfg): + cfg = compile_config(cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator) + evaluator_env_num = cfg.env.evaluator_env_num + show_bird_view = cfg.env.metadrive.show_bird_view + wrapper_cfg = {'show_bird_view': show_bird_view} + evaluator_env = BaseEnvManager( + env_fn=[partial(wrapped_env, cfg.env.metadrive, wrapper_cfg) for _ in range(evaluator_env_num)], + cfg=cfg.env.manager, + ) + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + if model_dir is not None: + policy._load_state_dict_collect(torch.load(model_dir, map_location='cpu')) + tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name)) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + stop, rate = evaluator.eval() + evaluator.close() + + +if __name__ == '__main__': + main(main_config) diff --git a/dizoo/metadrive/env/__init__.py b/dizoo/metadrive/env/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dizoo/metadrive/env/drive_env.py b/dizoo/metadrive/env/drive_env.py new file mode 100644 index 0000000000..b14676adfe --- /dev/null +++ b/dizoo/metadrive/env/drive_env.py @@ -0,0 +1,341 @@ +import copy +import numpy as np +from ditk import logging +from typing import Union, Dict, AnyStr, Tuple, Optional +from gym.envs.registration import register +from metadrive.manager.traffic_manager import TrafficMode +from metadrive.obs.top_down_obs_multi_channel import TopDownMultiChannel +from metadrive.constants import RENDER_MODE_NONE, DEFAULT_AGENT, REPLAY_DONE +from metadrive.envs.base_env import BaseEnv +from metadrive.component.map.base_map import BaseMap +from metadrive.component.map.pg_map import parse_map_config, MapGenerateMethod +from metadrive.manager.traffic_manager import TrafficMode +from metadrive.component.pgblock.first_block import FirstPGBlock +from metadrive.constants import DEFAULT_AGENT, TerminationState +from metadrive.component.vehicle.base_vehicle import BaseVehicle +from metadrive.utils import Config, merge_dicts, get_np_random, clip +from metadrive.envs.base_env import BASE_DEFAULT_CONFIG +from 
metadrive.obs.top_down_obs_multi_channel import TopDownMultiChannel +from metadrive.component.road_network import Road +from metadrive.component.algorithm.blocks_prob_dist import PGBlockDistConfig + +METADRIVE_DEFAULT_CONFIG = dict( + # ===== Generalization ===== + start_seed=0, + environment_num=10, + decision_repeat=20, + block_dist_config=PGBlockDistConfig, + + # ===== Map Config ===== + map=3, # int or string: an easy way to fill map_config + random_lane_width=False, + random_lane_num=False, + map_config={ + BaseMap.GENERATE_TYPE: MapGenerateMethod.BIG_BLOCK_NUM, + BaseMap.GENERATE_CONFIG: None, # it can be a file path / block num / block ID sequence + BaseMap.LANE_WIDTH: 3.5, + BaseMap.LANE_NUM: 3, + "exit_length": 50, + }, + + # ===== Traffic ===== + traffic_density=0.1, + need_inverse_traffic=False, + traffic_mode=TrafficMode.Trigger, # "Respawn", "Trigger" + random_traffic=False, # Traffic is randomized at default. + traffic_vehicle_config=dict( + show_navi_mark=False, + show_dest_mark=False, + enable_reverse=False, + show_lidar=False, + show_lane_line_detector=False, + show_side_detector=False, + ), + + # ===== Object ===== + accident_prob=0., # accident may happen on each block with this probability, except multi-exits block + + # ===== Others ===== + use_AI_protector=False, + save_level=0.5, + is_multi_agent=False, + vehicle_config=dict(spawn_lane_index=(FirstPGBlock.NODE_1, FirstPGBlock.NODE_2, 0)), + + # ===== Agent ===== + random_spawn_lane_index=True, + target_vehicle_configs={ + DEFAULT_AGENT: dict( + use_special_color=True, + spawn_lane_index=(FirstPGBlock.NODE_1, FirstPGBlock.NODE_2, 0), + ) + }, + + # ===== Reward Scheme ===== + # See: https://github.com/decisionforce/metadrive/issues/283 + success_reward=10.0, + out_of_road_penalty=5.0, + crash_vehicle_penalty=5.0, + crash_object_penalty=5.0, + driving_reward=1.0, + speed_reward=0.1, + use_lateral_reward=False, + + # ===== Cost Scheme ===== + crash_vehicle_cost=1.0, + crash_object_cost=1.0, + out_of_road_cost=1.0, + + # ===== Termination Scheme ===== + out_of_route_done=False, + on_screen=False, + show_bird_view=False, +) + + +class MetaDrivePPOOriginEnv(BaseEnv): + + @classmethod + def default_config(cls) -> "Config": + config = super(MetaDrivePPOOriginEnv, cls).default_config() + config.update(METADRIVE_DEFAULT_CONFIG) + config.register_type("map", str, int) + config["map_config"].register_type("config", None) + return config + + def __init__(self, config: dict = None): + self.default_config_copy = Config(self.default_config(), unchangeable=True) + super(MetaDrivePPOOriginEnv, self).__init__(config) + self.start_seed = self.config["start_seed"] + self.env_num = self.config["environment_num"] + + def _merge_extra_config(self, config: Union[dict, "Config"]) -> "Config": + config = self.default_config().update(config, allow_add_new_key=False) + if config["vehicle_config"]["lidar"]["distance"] > 50: + config["max_distance"] = config["vehicle_config"]["lidar"]["distance"] + return config + + def _post_process_config(self, config): + config = super(MetaDrivePPOOriginEnv, self)._post_process_config(config) + if not config["rgb_clip"]: + logging.warning( + "You have set rgb_clip = False, which means the observation will be uint8 values in [0, 255]. " + "Please make sure you have parsed them later before feeding them to network!" 
+ ) + config["map_config"] = parse_map_config( + easy_map_config=config["map"], new_map_config=config["map_config"], default_config=self.default_config_copy + ) + config["vehicle_config"]["rgb_clip"] = config["rgb_clip"] + config["vehicle_config"]["random_agent_model"] = config["random_agent_model"] + if config.get("gaussian_noise", 0) > 0: + assert config["vehicle_config"]["lidar"]["gaussian_noise"] == 0, "You already provide config!" + assert config["vehicle_config"]["side_detector"]["gaussian_noise"] == 0, "You already provide config!" + assert config["vehicle_config"]["lane_line_detector"]["gaussian_noise"] == 0, "You already provide config!" + config["vehicle_config"]["lidar"]["gaussian_noise"] = config["gaussian_noise"] + config["vehicle_config"]["side_detector"]["gaussian_noise"] = config["gaussian_noise"] + config["vehicle_config"]["lane_line_detector"]["gaussian_noise"] = config["gaussian_noise"] + if config.get("dropout_prob", 0) > 0: + assert config["vehicle_config"]["lidar"]["dropout_prob"] == 0, "You already provide config!" + assert config["vehicle_config"]["side_detector"]["dropout_prob"] == 0, "You already provide config!" + assert config["vehicle_config"]["lane_line_detector"]["dropout_prob"] == 0, "You already provide config!" + config["vehicle_config"]["lidar"]["dropout_prob"] = config["dropout_prob"] + config["vehicle_config"]["side_detector"]["dropout_prob"] = config["dropout_prob"] + config["vehicle_config"]["lane_line_detector"]["dropout_prob"] = config["dropout_prob"] + target_v_config = copy.deepcopy(config["vehicle_config"]) + if not config["is_multi_agent"]: + target_v_config.update(config["target_vehicle_configs"][DEFAULT_AGENT]) + config["target_vehicle_configs"][DEFAULT_AGENT] = target_v_config + return config + + def step(self, actions: Union[np.ndarray, Dict[AnyStr, np.ndarray]]): + actions = self._preprocess_actions(actions) + engine_info = self._step_simulator(actions) + o, r, d, i = self._get_step_return(actions, engine_info=engine_info) + return o, r, d, i + + def _get_observations(self): + return {DEFAULT_AGENT: self.get_single_observation(self.config["vehicle_config"])} + + def cost_function(self, vehicle_id: str): + vehicle = self.vehicles[vehicle_id] + step_info = dict() + step_info["cost"] = 0 + if self._is_out_of_road(vehicle): + step_info["cost"] = self.config["out_of_road_cost"] + elif vehicle.crash_vehicle: + step_info["cost"] = self.config["crash_vehicle_cost"] + elif vehicle.crash_object: + step_info["cost"] = self.config["crash_object_cost"] + return step_info['cost'], step_info + + def _is_out_of_road(self, vehicle): + ret = vehicle.on_yellow_continuous_line or vehicle.on_white_continuous_line or \ + (not vehicle.on_lane) or vehicle.crash_sidewalk + if self.config["out_of_route_done"]: + ret = ret or vehicle.out_of_route + return ret + + def done_function(self, vehicle_id: str): + vehicle = self.vehicles[vehicle_id] + done = False + done_info = { + TerminationState.CRASH_VEHICLE: False, + TerminationState.CRASH_OBJECT: False, + TerminationState.CRASH_BUILDING: False, + TerminationState.OUT_OF_ROAD: False, + TerminationState.SUCCESS: False, + TerminationState.MAX_STEP: False, + TerminationState.ENV_SEED: self.current_seed, + } + if self._is_arrive_destination(vehicle): + done = True + logging.info("Episode ended! Reason: arrive_dest.") + done_info[TerminationState.SUCCESS] = True + if self._is_out_of_road(vehicle): + done = True + logging.info("Episode ended! 
Reason: out_of_road.") + done_info[TerminationState.OUT_OF_ROAD] = True + if vehicle.crash_vehicle: + done = True + logging.info("Episode ended! Reason: crash vehicle ") + done_info[TerminationState.CRASH_VEHICLE] = True + if vehicle.crash_object: + done = True + done_info[TerminationState.CRASH_OBJECT] = True + logging.info("Episode ended! Reason: crash object ") + if vehicle.crash_building: + done = True + done_info[TerminationState.CRASH_BUILDING] = True + logging.info("Episode ended! Reason: crash building ") + if self.config["max_step_per_agent"] is not None and \ + self.episode_lengths[vehicle_id] >= self.config["max_step_per_agent"]: + done = True + done_info[TerminationState.MAX_STEP] = True + logging.info("Episode ended! Reason: max step ") + + if self.config["horizon"] is not None and \ + self.episode_lengths[vehicle_id] >= self.config["horizon"] and not self.is_multi_agent: + # single agent horizon has the same meaning as max_step_per_agent + done = True + done_info[TerminationState.MAX_STEP] = True + logging.info("Episode ended! Reason: max step ") + + done_info[TerminationState.CRASH] = ( + done_info[TerminationState.CRASH_VEHICLE] or done_info[TerminationState.CRASH_OBJECT] + or done_info[TerminationState.CRASH_BUILDING] + ) + return done, done_info + + def reward_function(self, vehicle_id: str): + """ + Override this func to get a new reward function + :param vehicle_id: id of BaseVehicle + :return: reward + """ + vehicle = self.vehicles[vehicle_id] + step_info = dict() + + # Reward for moving forward in current lane + if vehicle.lane in vehicle.navigation.current_ref_lanes: + current_lane = vehicle.lane + positive_road = 1 + else: + current_lane = vehicle.navigation.current_ref_lanes[0] + current_road = vehicle.navigation.current_road + positive_road = 1 if not current_road.is_negative_road() else -1 + long_last, _ = current_lane.local_coordinates(vehicle.last_position) + long_now, lateral_now = current_lane.local_coordinates(vehicle.position) + + # reward for lane keeping, without it vehicle can learn to overtake but fail to keep in lane + if self.config["use_lateral_reward"]: + lateral_factor = clip(1 - 2 * abs(lateral_now) / vehicle.navigation.get_current_lane_width(), 0.0, 1.0) + else: + lateral_factor = 1.0 + + reward = 0.0 + reward += self.config["driving_reward"] * (long_now - long_last) * lateral_factor * positive_road + reward += self.config["speed_reward"] * (vehicle.speed / vehicle.max_speed) * positive_road + + step_info["step_reward"] = reward + + if self._is_arrive_destination(vehicle): + reward = +self.config["success_reward"] + elif self._is_out_of_road(vehicle): + reward = -self.config["out_of_road_penalty"] + elif vehicle.crash_vehicle: + reward = -self.config["crash_vehicle_penalty"] + elif vehicle.crash_object: + reward = -self.config["crash_object_penalty"] + return reward, step_info + + def _get_reset_return(self): + ret = {} + self.engine.after_step() + for v_id, v in self.vehicles.items(): + self.observations[v_id].reset(self, v) + ret[v_id] = self.observations[v_id].observe(v) + return ret if self.is_multi_agent else self._wrap_as_single_agent(ret) + + def switch_to_third_person_view(self) -> (str, BaseVehicle): + if self.main_camera is None: + return + self.main_camera.reset() + if self.config["prefer_track_agent"] is not None and self.config["prefer_track_agent"] in self.vehicles.keys(): + new_v = self.vehicles[self.config["prefer_track_agent"]] + current_track_vehicle = new_v + else: + if self.main_camera.is_bird_view_camera(): + 
current_track_vehicle = self.current_track_vehicle + else: + vehicles = list(self.engine.agents.values()) + if len(vehicles) <= 1: + return + if self.current_track_vehicle in vehicles: + vehicles.remove(self.current_track_vehicle) + new_v = get_np_random().choice(vehicles) + current_track_vehicle = new_v + self.main_camera.track(current_track_vehicle) + return + + def switch_to_top_down_view(self): + self.main_camera.stop_track() + + def setup_engine(self): + super(MetaDrivePPOOriginEnv, self).setup_engine() + self.engine.accept("b", self.switch_to_top_down_view) + self.engine.accept("q", self.switch_to_third_person_view) + from metadrive.manager.traffic_manager import TrafficManager + from metadrive.manager.map_manager import MapManager + self.engine.register_manager("map_manager", MapManager()) + self.engine.register_manager("traffic_manager", TrafficManager()) + + def _is_arrive_destination(self, vehicle): + long, lat = vehicle.navigation.final_lane.local_coordinates(vehicle.position) + flag = (vehicle.navigation.final_lane.length - 5 < long < vehicle.navigation.final_lane.length + 5) and ( + vehicle.navigation.get_current_lane_width() / 2 >= lat >= + (0.5 - vehicle.navigation.get_current_lane_num()) * vehicle.navigation.get_current_lane_width() + ) + return flag + + def _reset_global_seed(self, force_seed=None): + """ + Current seed is set to force seed if force_seed is not None. + Otherwise, current seed is randomly generated. + """ + current_seed = force_seed if force_seed is not None else \ + get_np_random(self._DEBUG_RANDOM_SEED).randint(self.start_seed, self.start_seed + self.env_num) + self.seed(current_seed) + + def _get_observations(self): + return {DEFAULT_AGENT: self.get_single_observation(self.config["vehicle_config"])} + + def get_single_observation(self, _=None): + return TopDownMultiChannel( + self.config["vehicle_config"], + self.config["on_screen"], + self.config["rgb_clip"], + frame_stack=3, + post_stack=10, + frame_skip=1, + resolution=(84, 84), + max_distance=36, + ) diff --git a/dizoo/metadrive/env/drive_utils.py b/dizoo/metadrive/env/drive_utils.py new file mode 100644 index 0000000000..99415c9b93 --- /dev/null +++ b/dizoo/metadrive/env/drive_utils.py @@ -0,0 +1,121 @@ +from typing import NoReturn, Optional, List +from gym import utils +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional +from easydict import EasyDict +from itertools import product +import gym +import copy +import numpy as np +import matplotlib.pyplot as plt +from ding.utils.default_helper import deep_merge_dicts + + +class AAA(): + + def __init__(self) -> None: + self.x = 0 + + +def deep_update( + original: dict, + new_dict: dict, + new_keys_allowed: bool = False, + whitelist: Optional[List[str]] = None, + override_all_if_type_changes: Optional[List[str]] = None +): + """ + Overview: + Updates original dict with values from new_dict recursively. + + .. note:: + + If new key is introduced in new_dict, then if new_keys_allowed is not + True, an error will be thrown. Further, for sub-dicts, if the key is + in the whitelist, then new subkeys can be introduced. + + Arguments: + - original (:obj:`dict`): Dictionary with default values. + - new_dict (:obj:`dict`): Dictionary with values to be updated + - new_keys_allowed (:obj:`bool`): Whether new keys are allowed. + - whitelist (Optional[List[str]]): List of keys that correspond to dict + values where new subkeys can be introduced. This is only at the top + level. 
+ - override_all_if_type_changes(Optional[List[str]]): List of top level + keys with value=dict, for which we always simply override the + entire value (:obj:`dict`), if the "type" key in that value dict changes. + """ + whitelist = whitelist or [] + override_all_if_type_changes = override_all_if_type_changes or [] + for k, value in new_dict.items(): + if k not in original and not new_keys_allowed: + raise RuntimeError("Unknown config parameter `{}`. Base config has: {}.".format(k, original.keys())) + # Both original value and new one are dicts. + if isinstance(original.get(k), dict) and isinstance(value, dict): + # Check old type vs new one. If different, override entire value. + if k in override_all_if_type_changes and \ + "type" in value and "type" in original[k] and \ + value["type"] != original[k]["type"]: + original[k] = value + # Whitelisted key -> ok to add new subkeys. + elif k in whitelist: + deep_update(original[k], value, True) + # Non-whitelisted key. + else: + deep_update(original[k], value, new_keys_allowed) + # Original value not a dict OR new value not a dict: + # Override entire value. + else: + original[k] = value + return original + + +class BaseDriveEnv(gym.Env, utils.EzPickle): + config = dict() + + @abstractmethod + def __init__(self, cfg: Dict, **kwargs) -> None: + if 'cfg_type' not in cfg: + self._cfg = self.__class__.default_config() + self._cfg = deep_merge_dicts(self._cfg, cfg) + else: + self._cfg = cfg + utils.EzPickle.__init__(self) + + @abstractmethod + def step(self, action: Any) -> Any: + """ + Run one step of the environment and return the observation dict. + """ + raise NotImplementedError + + @abstractmethod + def reset(self, *args, **kwargs) -> Any: + """ + Reset current environment. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """ + Release all resources in environment and close. + """ + raise NotImplementedError + + @abstractmethod + def seed(self, seed: int) -> None: + """ + Set random seed.
+ """ + raise NotImplementedError + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(cls.config) + cfg.cfg_type = cls.__name__ + 'Config' + return copy.deepcopy(cfg) + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError diff --git a/dizoo/metadrive/env/drive_wrapper.py b/dizoo/metadrive/env/drive_wrapper.py new file mode 100644 index 0000000000..5b48826ab8 --- /dev/null +++ b/dizoo/metadrive/env/drive_wrapper.py @@ -0,0 +1,145 @@ +from typing import Any, Dict, Optional +from easydict import EasyDict +from itertools import product +from typing import NoReturn, Optional, List +import matplotlib.pyplot as plt +import gym +import copy +import numpy as np +from ding.envs.env.base_env import BaseEnvTimestep +from ding.envs.common.env_element import EnvElementInfo +from ding.torch_utils.data_helper import to_ndarray +from ding.utils.default_helper import deep_merge_dicts +from dizoo.metadrive.env.drive_utils import BaseDriveEnv + + +def draw_multi_channels_top_down_observation(obs, show_time=0.5): + num_channels = obs.shape[-1] + assert num_channels == 5 + channel_names = [ + "Road and navigation", "Ego now and previous pos", "Neighbor at step t", "Neighbor at step t-1", + "Neighbor at step t-2" + ] + fig, axs = plt.subplots(1, num_channels, figsize=(15, 4), dpi=80) + count = 0 + + def close_event(): + plt.close() + + timer = fig.canvas.new_timer(interval=show_time * 1000) + timer.add_callback(close_event) + for i, name in enumerate(channel_names): + count += 1 + ax = axs[i] + ax.imshow(obs[..., i], cmap="bone") + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(name) + fig.suptitle("Multi-channels Top-down Observation") + timer.start() + plt.show() + plt.close() + + +class DriveEnvWrapper(gym.Wrapper): + """ + Overview: + Environment wrapper to make ``gym.Env`` align with DI-engine definitions, so as to use utilities in DI-engine. + It changes ``step``, ``reset`` and ``info`` method of ``gym.Env``, while others are straightly delivered. + + Arguments: + - env (BaseDriveEnv): The environment to be wrapped. + - cfg (Dict): Config dict. + """ + config = dict() + + def __init__(self, env: BaseDriveEnv, cfg: Dict = None, **kwargs) -> None: + if cfg is None: + self._cfg = self.__class__.default_config() + elif 'cfg_type' not in cfg: + self._cfg = self.__class__.default_config() + self._cfg = deep_merge_dicts(self._cfg, cfg) + else: + self._cfg = cfg + self.env = env + if not hasattr(self.env, 'reward_space'): + self.reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1, )) + if 'show_bird_view' in self._cfg and self._cfg['show_bird_view'] is True: + self.show_bird_view = True + else: + self.show_bird_view = False + self.action_space = self.env.action_space + self.env = env + + def reset(self, *args, **kwargs) -> Any: + """ + Overview: + Wrapper of ``reset`` method in env. The observations are converted to ``np.ndarray`` and final reward + are recorded. 
+ Returns: + - Any: Observations from environment + """ + obs = self.env.reset(*args, **kwargs) + obs = to_ndarray(obs, dtype=np.float32) + if isinstance(obs, np.ndarray) and len(obs.shape) == 3: + obs = obs.transpose((2, 0, 1)) + elif isinstance(obs, dict): + vehicle_state = obs['vehicle_state'] + birdview = obs['birdview'].transpose((2, 0, 1)) + obs = {'vehicle_state': vehicle_state, 'birdview': birdview} + self._final_eval_reward = 0.0 + self._arrive_dest = False + return obs + + def step(self, action: Any = None) -> BaseEnvTimestep: + """ + Overview: + Wrapper of ``step`` method in env. This aims to convert the returns of ``gym.Env`` step method into + that of ``ding.envs.BaseEnv``, from ``(obs, reward, done, info)`` tuple to a ``BaseEnvTimestep`` + namedtuple defined in DI-engine. It will also convert actions, observations and reward into + ``np.ndarray``, and check legality if action contains control signal. + Arguments: + - action (Any, optional): Actions sent to env. Defaults to None. + Returns: + - BaseEnvTimestep: DI-engine format of env step returns. + """ + action = to_ndarray(action) + obs, rew, done, info = self.env.step(action) + if self.show_bird_view: + draw_multi_channels_top_down_observation(obs, show_time=0.5) + self._final_eval_reward += rew + obs = to_ndarray(obs, dtype=np.float32) + if isinstance(obs, np.ndarray) and len(obs.shape) == 3: + obs = obs.transpose((2, 0, 1)) + elif isinstance(obs, dict): + vehicle_state = obs['vehicle_state'] + birdview = obs['birdview'].transpose((2, 0, 1)) + obs = {'vehicle_state': vehicle_state, 'birdview': birdview} + rew = to_ndarray([rew], dtype=np.float32) + if done: + info['final_eval_reward'] = self._final_eval_reward + info['eval_episode_return'] = self._final_eval_reward + return BaseEnvTimestep(obs, rew, done, info) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + if replay_path is None: + replay_path = './video' + self._replay_path = replay_path + self.env = gym.wrappers.Monitor(self.env, self._replay_path, video_callable=lambda episode_id: True, force=True) + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(cls.config) + cfg.cfg_type = cls.__name__ + 'Config' + return copy.deepcopy(cfg) + + def __repr__(self) -> str: + return repr(self.env) + + def render(self): + self.env.render() diff --git a/dizoo/metadrive/metadrive_env.gif b/dizoo/metadrive/metadrive_env.gif new file mode 100644 index 0000000000..5e78923051 Binary files /dev/null and b/dizoo/metadrive/metadrive_env.gif differ
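Usage note (not part of the diff above): the two new config files double as entry points. The snippet below is a minimal launch sketch; it assumes MetaDrive and DI-engine are installed and that the repository root is on `PYTHONPATH` so the new `dizoo.metadrive` package resolves. Evaluating a saved checkpoint works the same way through `metadrive_onppo_eval_config.py` after setting `model_dir` in that file.

```python
# Minimal sketch: launch PPO training with the MetaDrive config added in this diff.
# Equivalent to running `python dizoo/metadrive/config/metadrive_onppo_config.py`.
from dizoo.metadrive.config.metadrive_onppo_config import main, main_config

if __name__ == '__main__':
    # Builds the SyncSubprocessEnvManager collector/evaluator workers, the VAC model and
    # PPOPolicy, then runs the collect/train/evaluate loop until `stop_value` is reached.
    main(main_config)
```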