feature(cxy): add cliffwalking env #677

Merged: 4 commits, Jun 21, 2023
11 changes: 6 additions & 5 deletions README.md
@@ -107,13 +107,13 @@ Have fun with exploration and exploitation.
 - [Installation](#installation)
 - [Quick Start](#quick-start)
 - [Feature](#feature)
-- [↳ Algorithm Versatility](#algorithm-versatility)
-- [↳ Environment Versatility](#environment-versatility)
-- [↳ General Data Container: TreeTensor](#general-data-container-treetensor)
+- [Algorithm Versatility](#algorithm-versatility)
+- [Environment Versatility](#environment-versatility)
+- [General Data Container: TreeTensor](#general-data-container-treetensor)
 - [Feedback and Contribution](#feedback-and-contribution)
 - [Supporters](#supporters)
-- [↳ Stargazers](#-stargazers)
-- [↳ Forkers](#-forkers)
+- [ Stargazers](#-stargazers)
+- [ Forkers](#-forkers)
 - [Citation](#citation)
 - [License](#license)

@@ -299,6 +299,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
 | 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)<br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/acrobot_zh.html) |
 | 34 |[box2d/car_racing](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) <br> ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/carracing/car_racing.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/carracing/envs)<br>环境指南 |
 | 35 |[metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)<br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/metadrive_zh.html) |
+| 36 |[cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)<br> 环境指南 |
 
 ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space

Empty file added dizoo/cliffwalking/__init__.py
Binary file added dizoo/cliffwalking/cliff_walking.gif
60 changes: 60 additions & 0 deletions dizoo/cliffwalking/config/cliffwalking_dqn_config.py
@@ -0,0 +1,60 @@
from easydict import EasyDict

cliffwalking_dqn_config = dict(
    exp_name='cliffwalking_dqn_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=8,
        stop_value=-13,  # the optimal value of cliffwalking env
        max_episode_steps=300,
    ),
    policy=dict(
        cuda=True,
        load_path="./cliffwalking_dqn_seed0/ckpt/ckpt_best.pth.tar",
        model=dict(
            obs_shape=48,
            action_shape=4,
            encoder_hidden_size_list=[512, 64],
        ),
        discount_factor=0.99,
        nstep=1,
        learn=dict(
            update_per_collect=10,
            batch_size=128,
            learning_rate=0.001,
            target_update_freq=100,
        ),
        collect=dict(
            n_sample=64,
            unroll_len=1,
        ),
        other=dict(
            eps=dict(
                type='linear',
                start=0.95,
                end=0.25,
                decay=50000,
            ),
            replay_buffer=dict(replay_buffer_size=100000, )
        ),
    ),
)
cliffwalking_dqn_config = EasyDict(cliffwalking_dqn_config)
main_config = cliffwalking_dqn_config

cliffwalking_dqn_create_config = dict(
    env=dict(
        type='cliffwalking',
        import_names=['dizoo.cliffwalking.envs.cliffwalking_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='dqn'),
)
cliffwalking_dqn_create_config = EasyDict(cliffwalking_dqn_create_config)
create_config = cliffwalking_dqn_create_config

if __name__ == "__main__":
    # or you can enter `ding -m serial -c cliffwalking_dqn_config.py -s 0`
    from ding.entry import serial_pipeline
    serial_pipeline([main_config, create_config], seed=0)
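
Editor's note: the `eps` block above configures a linear epsilon-greedy schedule, and `stop_value=-13` is the optimal CliffWalking return (the shortest safe path takes 13 moves at -1 reward each). A minimal sketch of how such a schedule behaves, assuming plain linear interpolation from `start` to `end` over `decay` environment steps; the `linear_eps` helper is illustrative only, not DI-engine's implementation:

# Illustrative sketch of the linear epsilon schedule configured above.
def linear_eps(step: int, start: float = 0.95, end: float = 0.25, decay: int = 50000) -> float:
    frac = min(step / decay, 1.0)        # progress through the decay window, clipped at 1
    return start + (end - start) * frac  # 0.95 at step 0, 0.25 from step 50000 onward

assert abs(linear_eps(0) - 0.95) < 1e-9
assert abs(linear_eps(25000) - 0.60) < 1e-9
assert abs(linear_eps(50000) - 0.25) < 1e-9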
39 changes: 39 additions & 0 deletions dizoo/cliffwalking/entry/cliffwalking_dqn_deploy.py
@@ -0,0 +1,39 @@
import gym
import torch
from easydict import EasyDict

from ding.config import compile_config
from ding.envs import DingEnvWrapper
from ding.model import DQN
from ding.policy import DQNPolicy, single_env_forward_wrapper
from dizoo.cliffwalking.config.cliffwalking_dqn_config import create_config, main_config
from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv


def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str):
    main_config.exp_name = 'cliffwalking_dqn_seed0_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    # Pass the compiled env config, as in the training entry.
    env = CliffWalkingEnv(cfg.env)
    env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video')
    model = DQN(**cfg.policy.model)
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    policy = DQNPolicy(cfg.policy, model=model).eval_mode
    forward_fn = single_env_forward_wrapper(policy.forward)
    obs = env.reset()
    returns = 0.
    while True:
        action = forward_fn(obs)
        obs, rew, done, info = env.step(action)
        returns += rew
        if done:
            break
    print(f'Deploy is finished, final episode return is: {returns}')


if __name__ == "__main__":
    main(
        main_config=main_config,
        create_config=create_config,
        ckpt_path='./cliffwalking_dqn_seed0/ckpt/ckpt_best.pth.tar'
    )
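
Editor's note: `single_env_forward_wrapper` adapts the batch-oriented `policy.forward` of eval mode to a single environment. Conceptually it behaves like the rough sketch below; this illustrates the idea only, it is not the library implementation, and the helper name is hypothetical:

def single_env_forward_sketch(forward_fn):
    # Pack the single observation as a batch of one (keyed by env id 0),
    # run the policy forward, then unpack the single action.
    def _forward(obs):
        output = forward_fn({0: obs})
        return output[0]['action'].item()
    return _forward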
50 changes: 50 additions & 0 deletions dizoo/cliffwalking/entry/cliffwalking_dqn_main.py
@@ -0,0 +1,50 @@
import gym
from ditk import logging

from ding.config import compile_config
from ding.data import DequeBuffer
from ding.envs import BaseEnvManagerV2, DingEnvWrapper
from ding.framework import ding_init, task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import CkptSaver, OffPolicyLearner, StepCollector, data_pusher, eps_greedy_handler, \
    interaction_evaluator, online_logger
from ding.model import DQN
from ding.policy import DQNPolicy
from ding.utils import set_pkg_seed
from dizoo.cliffwalking.config.cliffwalking_dqn_config import create_config, main_config
from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv


def main():
    filename = '{}/log.txt'.format(main_config.exp_name)
    logging.getLogger(with_files=[filename]).setLevel(logging.INFO)

    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)

    collector_env = BaseEnvManagerV2(
        env_fn=[lambda: CliffWalkingEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
    )
    evaluator_env = BaseEnvManagerV2(
        env_fn=[lambda: CliffWalkingEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
    )

    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

    model = DQN(**cfg.policy.model)
    buffer = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
    policy = DQNPolicy(cfg.policy, model=model)

    with task.start(async_mode=False, ctx=OnlineRLContext()):
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(eps_greedy_handler(cfg))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(data_pusher(cfg, buffer))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer))
        task.use(online_logger(train_show_freq=10))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
        task.run()


if __name__ == '__main__':
    main()
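
Editor's note: the middleware pipeline above and the classic entry noted in the config file launch the same experiment; the classic form is:

# Classic entry point, as noted in the config file itself:
from ding.entry import serial_pipeline
from dizoo.cliffwalking.config.cliffwalking_dqn_config import main_config, create_config

serial_pipeline([main_config, create_config], seed=0)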
1 change: 1 addition & 0 deletions dizoo/cliffwalking/envs/__init__.py
@@ -0,0 +1 @@
from .cliffwalking_env import CliffWalkingEnv
106 changes: 106 additions & 0 deletions dizoo/cliffwalking/envs/cliffwalking_env.py
@@ -0,0 +1,106 @@
import copy
from typing import List, Union, Optional

import gym
import numpy as np
from easydict import EasyDict

from ding.envs.env.base_env import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY


@ENV_REGISTRY.register('cliffwalking')
class CliffWalkingEnv(BaseEnv):

    def __init__(self, cfg: dict) -> None:
        self._cfg = EasyDict(
            env_id='CliffWalking',
            render_mode='rgb_array',
            max_episode_steps=300,  # default max trajectory length to truncate possibly infinite episodes
        )
        self._cfg.update(cfg)
        self._init_flag = False
        self._replay_path = None
        # The underlying gym env is created lazily in ``reset``, so the spaces are defined here
        # directly: 48 one-hot grid states, 4 discrete actions (up/right/down/left), and a reward
        # of -1 per step (-100 for stepping into the cliff).
        self._observation_space = gym.spaces.Box(low=0, high=1, shape=(48, ), dtype=np.float32)
        self._action_space = gym.spaces.Discrete(4)
        self._reward_space = gym.spaces.Box(low=-100, high=-1, shape=(1, ), dtype=np.float32)

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = gym.make(
                self._cfg.env_id, render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps
            )
            if self._replay_path is not None:
                # Wrap once at creation; wrapping on every reset would nest RecordVideo wrappers.
                self._env = gym.wrappers.RecordVideo(
                    self._env,
                    video_folder=self._replay_path,
                    episode_trigger=lambda episode_id: True,
                    name_prefix='cliffwalking-{}'.format(id(self))
                )
            self._init_flag = True
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            dy_seed = self._seed + 100 * np.random.randint(1, 1000)
            self._env.seed(dy_seed)
        elif hasattr(self, '_seed'):
            self._env.seed(self._seed)
        obs = self._env.reset()
        obs_encode = self._encode_obs(obs)
        self._eval_episode_return = 0.
        return obs_encode

    def close(self) -> None:
        try:
            self._env.close()
            del self._env
        except Exception:
            pass

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(seed)

    def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
        if isinstance(action, np.ndarray) and action.shape == (1, ):
            action = action.squeeze()  # 0-dim array
        obs, reward, done, info = self._env.step(action)
        obs_encode = self._encode_obs(obs)
        self._eval_episode_return += reward
        reward = to_ndarray([reward], dtype=np.float32)
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(obs_encode, reward, done, info)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path

    def random_action(self) -> np.ndarray:
        random_action = self.action_space.sample()
        if isinstance(random_action, int):
            random_action = to_ndarray([random_action], dtype=np.int64)
        return random_action

    def _encode_obs(self, obs: Union[int, np.ndarray]) -> np.ndarray:
        # Encode the integer state index (row * 12 + col on the 4x12 grid) as a one-hot vector.
        onehot = np.zeros(48, dtype=np.float32)
        onehot[int(obs)] = 1
        return onehot

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine CliffWalking Env"
34 changes: 34 additions & 0 deletions dizoo/cliffwalking/envs/test_cliffwalking_env.py
@@ -0,0 +1,34 @@
import numpy as np
import pytest

from dizoo.cliffwalking.envs import CliffWalkingEnv


@pytest.mark.envtest
class TestCliffWalkingEnv:

    def test_naive(self):
        env = CliffWalkingEnv({})
        env.seed(314, dynamic_seed=False)
        assert env._seed == 314
        obs = env.reset()
        assert obs.shape == (48, )
        for _ in range(5):
            env.reset()
        np.random.seed(314)
        print('=' * 60)
        for i in range(10):
            # Both ``env.random_action()`` and sampling from ``env.action_space`` directly
            # generate legal random actions.
            if i < 5:
                random_action = np.array([env.action_space.sample()])
            else:
                random_action = env.random_action()
            timestep = env.step(random_action)
            print(timestep)
            assert isinstance(timestep.obs, np.ndarray)
            assert isinstance(timestep.done, bool)
            assert timestep.obs.shape == (48, )
            assert timestep.reward.shape == (1, )
            assert timestep.reward >= env.reward_space.low
            assert timestep.reward <= env.reward_space.high
        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
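
Editor's note: since the class carries the `envtest` marker, this test can be run in isolation with `pytest -m envtest dizoo/cliffwalking/envs/test_cliffwalking_env.py`.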