Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(zjow): polish ppof agent code for opendilab huggingface #730

Merged
merged 4 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 65 additions & 71 deletions ding/bonus/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,38 @@
from ding.policy import PPOFPolicy


def get_instance_config(env: str, algorithm: str) -> EasyDict:
def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
if algorithm == 'PPOF':
cfg = PPOFPolicy.default_config()
if env == 'lunarlander_discrete':
if env_id == 'LunarLander-v2':
cfg.n_sample = 512
cfg.value_norm = 'popart'
cfg.entropy_weight = 1e-3
elif env == 'lunarlander_continuous':
elif env_id == 'LunarLanderContinuous-v2':
cfg.action_space = 'continuous'
cfg.n_sample = 400
elif env == 'bipedalwalker':
elif env_id == 'BipedalWalker-v3':
cfg.learning_rate = 1e-3
cfg.action_space = 'continuous'
cfg.n_sample = 1024
elif env == 'acrobot':
elif env_id == 'acrobot':
cfg.learning_rate = 1e-4
cfg.n_sample = 400
elif env == 'rocket_landing':
elif env_id == 'rocket_landing':
cfg.n_sample = 2048
cfg.adv_norm = False
cfg.model = dict(
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'drone_fly':
elif env_id == 'drone_fly':
cfg.action_space = 'continuous'
cfg.adv_norm = False
cfg.epoch_per_collect = 5
cfg.learning_rate = 5e-5
cfg.n_sample = 640
elif env == 'hybrid_moving':
elif env_id == 'hybrid_moving':
cfg.action_space = 'hybrid'
cfg.n_sample = 3200
cfg.entropy_weight = 0.03
Expand All @@ -50,13 +50,13 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
fixed_sigma_value=0.3,
bound_type='tanh',
)
elif env == 'evogym_carrier':
elif env_id == 'evogym_carrier':
cfg.action_space = 'continuous'
cfg.n_sample = 2048
cfg.batch_size = 256
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-3
elif env == 'mario':
elif env_id == 'mario':
cfg.n_sample = 256
cfg.batch_size = 64
cfg.epoch_per_collect = 2
Expand All @@ -66,14 +66,14 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
critic_head_hidden_size=128,
actor_head_hidden_size=128,
)
elif env == 'di_sheep':
elif env_id == 'di_sheep':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
cfg.adv_norm = False
cfg.entropy_weight = 0.001
elif env == 'procgen_bigfish':
elif env_id == 'procgen_bigfish':
cfg.n_sample = 16384
cfg.batch_size = 16384
cfg.epoch_per_collect = 10
Expand All @@ -83,7 +83,7 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
critic_head_hidden_size=256,
actor_head_hidden_size=256,
)
elif env in ['atari_qbert', 'atari_kangaroo', 'atari_bowling']:
elif env_id in ['KangarooNoFrameskip-v4', 'BowlingNoFrameskip-v4']:
cfg.n_sample = 1024
cfg.batch_size = 128
cfg.epoch_per_collect = 10
Expand All @@ -94,7 +94,7 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
elif env == 'PongNoFrameskip':
elif env_id == 'PongNoFrameskip-v4':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
Expand All @@ -104,7 +104,7 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'SpaceInvadersNoFrameskip':
elif env_id == 'SpaceInvadersNoFrameskip-v4':
cfg.n_sample = 320
cfg.batch_size = 320
cfg.epoch_per_collect = 1
Expand All @@ -116,7 +116,7 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'QbertNoFrameskip':
elif env_id == 'QbertNoFrameskip-v4':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
Expand All @@ -127,13 +127,13 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'minigrid_fourroom':
elif env_id == 'minigrid_fourroom':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.learning_rate = 3e-4
cfg.epoch_per_collect = 10
cfg.entropy_weight = 0.001
elif env == 'metadrive':
elif env_id == 'metadrive':
cfg.learning_rate = 3e-4
cfg.action_space = 'continuous'
cfg.entropy_weight = 0.001
Expand All @@ -146,49 +146,61 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
elif env in ['hopper']:
elif env_id == 'Hopper-v3':
cfg.action_space = "continuous"
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
elif env_id == 'HalfCheetah-v3':
cfg.action_space = "continuous"
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
elif env_id == 'Walker2d-v3':
cfg.action_space = "continuous"
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
else:
raise KeyError("not supported env type: {}".format(env))
raise KeyError("not supported env type: {}".format(env_id))
else:
raise KeyError("not supported algorithm type: {}".format(algorithm))

return cfg


def get_instance_env(env: str) -> BaseEnv:
if env == 'lunarlander_discrete':
def get_instance_env(env_id: str) -> BaseEnv:
if env_id == 'LunarLander-v2':
return DingEnvWrapper(gym.make('LunarLander-v2'))
elif env == 'lunarlander_continuous':
return DingEnvWrapper(gym.make('LunarLander-v2', continuous=True))
elif env == 'bipedalwalker':
elif env_id == 'LunarLanderContinuous-v2':
return DingEnvWrapper(gym.make('LunarLanderContinuous-v2', continuous=True))
elif env_id == 'BipedalWalker-v3':
return DingEnvWrapper(gym.make('BipedalWalker-v3'), cfg={'act_scale': True, 'rew_clip': True})
elif env == 'pendulum':
elif env_id == 'Pendulum-v1':
return DingEnvWrapper(gym.make('Pendulum-v1'), cfg={'act_scale': True})
elif env == 'acrobot':
elif env_id == 'acrobot':
return DingEnvWrapper(gym.make('Acrobot-v1'))
elif env == 'rocket_landing':
elif env_id == 'rocket_landing':
from dizoo.rocket.envs import RocketEnv
cfg = EasyDict({
'task': 'landing',
'max_steps': 800,
})
return RocketEnv(cfg)
elif env == 'drone_fly':
elif env_id == 'drone_fly':
from dizoo.gym_pybullet_drones.envs import GymPybulletDronesEnv
cfg = EasyDict({
'env_id': 'flythrugate-aviary-v0',
'action_type': 'VEL',
})
return GymPybulletDronesEnv(cfg)
elif env == 'hybrid_moving':
elif env_id == 'hybrid_moving':
import gym_hybrid
return DingEnvWrapper(gym.make('Moving-v0'))
elif env == 'evogym_carrier':
elif env_id == 'evogym_carrier':
import evogym.envs
from evogym import sample_robot, WorldObject
path = os.path.join(os.path.dirname(__file__), '../../dizoo/evogym/envs/world_data/carry_bot.json')
Expand All @@ -203,7 +215,7 @@ def get_instance_env(env: str) -> BaseEnv:
]
}
)
elif env == 'mario':
elif env_id == 'mario':
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
return DingEnvWrapper(
Expand All @@ -219,10 +231,10 @@ def get_instance_env(env: str) -> BaseEnv:
]
}
)
elif env == 'di_sheep':
elif env_id == 'di_sheep':
from sheep_env import SheepEnv
return DingEnvWrapper(SheepEnv(level=9))
elif env == 'procgen_bigfish':
elif env_id == 'procgen_bigfish':
return DingEnvWrapper(
gym.make('procgen:procgen-bigfish-v0', start_level=0, num_levels=1),
cfg={
Expand All @@ -234,66 +246,48 @@ def get_instance_env(env: str) -> BaseEnv:
},
seed_api=False,
)
elif env == 'hopper':
elif env_id == 'Hopper-v3':
cfg = EasyDict(
env_id='Hopper-v3',
env_wrapper='mujoco_default',
act_scale=True,
rew_clip=True,
)
return DingEnvWrapper(gym.make('Hopper-v3'), cfg=cfg)
elif env == 'HalfCheetah':
elif env_id == 'HalfCheetah-v3':
cfg = EasyDict(
env_id='HalfCheetah-v3',
env_wrapper='mujoco_default',
act_scale=True,
rew_clip=True,
)
return DingEnvWrapper(gym.make('HalfCheetah-v3'), cfg=cfg)
elif env == 'Walker2d':
elif env_id == 'Walker2d-v3':
cfg = EasyDict(
env_id='Walker2d-v3',
env_wrapper='mujoco_default',
act_scale=True,
rew_clip=True,
)
return DingEnvWrapper(gym.make('Walker2d-v3'), cfg=cfg)
elif env == "SpaceInvadersNoFrameskip":
cfg = EasyDict({
'env_id': "SpaceInvadersNoFrameskip-v4",
'env_wrapper': 'atari_default',
})
return DingEnvWrapper(gym.make("SpaceInvadersNoFrameskip-v4"), cfg=cfg)
elif env == "PongNoFrameskip":
cfg = EasyDict({
'env_id': "PongNoFrameskip-v4",
'env_wrapper': 'atari_default',
})
return DingEnvWrapper(gym.make("PongNoFrameskip-v4"), cfg=cfg)
elif env == "QbertNoFrameskip":
cfg = EasyDict({
'env_id': "QbertNoFrameskip-v4",
'env_wrapper': 'atari_default',
})
return DingEnvWrapper(gym.make("QbertNoFrameskip-v4"), cfg=cfg)
elif env in ['atari_qbert', 'atari_kangaroo', 'atari_bowling', 'atari_breakout', 'atari_spaceinvader',
'atari_gopher']:
from dizoo.atari.envs.atari_env import AtariEnv
atari_env_list = {
'atari_qbert': 'QbertNoFrameskip-v4',
'atari_kangaroo': 'KangarooNoFrameskip-v4',
'atari_bowling': 'BowlingNoFrameskip-v4',
'atari_breakout': 'BreakoutNoFrameskip-v4',
'atari_spaceinvader': 'SpaceInvadersNoFrameskip-v4',
'atari_gopher': 'GopherNoFrameskip-v4'
}

elif env_id in [
'BowlingNoFrameskip-v4',
'BreakoutNoFrameskip-v4',
'GopherNoFrameskip-v4',
'KangarooNoFrameskip-v4',
'PongNoFrameskip-v4',
'QbertNoFrameskip-v4',
'SpaceInvadersNoFrameskip-v4',
]:

cfg = EasyDict({
'env_id': atari_env_list[env],
'env_id': env_id,
'env_wrapper': 'atari_default',
})
ding_env_atari = DingEnvWrapper(gym.make(atari_env_list[env]), cfg=cfg)
ding_env_atari = DingEnvWrapper(gym.make(env_id), cfg=cfg)
return ding_env_atari
elif env == 'minigrid_fourroom':
elif env_id == 'minigrid_fourroom':
import gymnasium
return DingEnvWrapper(
gymnasium.make('MiniGrid-FourRooms-v0'),
Expand All @@ -306,7 +300,7 @@ def get_instance_env(env: str) -> BaseEnv:
]
}
)
elif env == 'metadrive':
elif env_id == 'metadrive':
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
cfg = dict(
Expand All @@ -319,7 +313,7 @@ def get_instance_env(env: str) -> BaseEnv:
cfg = EasyDict(cfg)
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
else:
raise KeyError("not supported env type: {}".format(env))
raise KeyError("not supported env type: {}".format(env_id))


def get_hybrid_shape(action_space) -> EasyDict:
Expand Down
Loading
Loading