diff --git a/ding/framework/middleware/tests/test_advantage_estimator.py b/ding/framework/middleware/tests/test_advantage_estimator.py
index 431b0199a0..66ad45e77d 100644
--- a/ding/framework/middleware/tests/test_advantage_estimator.py
+++ b/ding/framework/middleware/tests/test_advantage_estimator.py
@@ -33,7 +33,7 @@ def get_attribute(self, name: str) -> Any:
 
 
 def call_gae_estimator(batch_size: int = 32, trajectory_end_idx_size: int = 5, buffer: Optional[Buffer] = None):
-    cfg = EasyDict({'policy': {'collect': {'discount_factor': 0.9, 'gae_lambda': 0.95}}})
+    cfg = EasyDict({'policy': {'collect': {'discount_factor': 0.9, 'gae_lambda': 0.95}, 'cuda': False}})
     ctx = OnlineRLContext()
 
     assert trajectory_end_idx_size <= batch_size
diff --git a/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py b/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py
index a6b37549af..76669b057e 100644
--- a/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py
+++ b/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py
@@ -7,7 +7,7 @@
         env_id='dmc2gym-v0',
         domain_name="cartpole",
         task_name="swingup",
-        frame_skip=2,
+        frame_skip=4,
         warp_frame=True,
         scale=True,
         clip_rewards=False,
@@ -17,9 +17,6 @@
         collector_env_num=8,
         evaluator_env_num=8,
         n_evaluator_episode=8,
-        # collector_env_num=1,
-        # evaluator_env_num=1,
-        # n_evaluator_episode=1,
         stop_value=1e6,
         manager=dict(shared_memory=False, ),
     ),
@@ -27,36 +24,19 @@
         model_type='pixel',
         cuda=True,
         random_collect_size=10000,
-        # random_collect_size=10,
         model=dict(
             obs_shape=(3, 84, 84),
             action_shape=1,
             twin_critic=True,
-            encoder_hidden_size_list=[32, 32, 50],
+            encoder_hidden_size_list=[32, 32, 32],
             actor_head_hidden_size=1024,
             critic_head_hidden_size=1024,
-
-            # different option about whether to share_conv_encoder in two Q networks
-            # and whether to use embed_action
-
-            share_conv_encoder=False,
-            embed_action=False,
-
-            # share_conv_encoder=True,
-            # embed_action=False,
-
-            # share_conv_encoder=False,
-            # embed_action=True,
-
-            # share_conv_encoder=True,
-            # embed_action=True,
-            embed_action_density=0.1,
+            share_encoder=True,
         ),
         learn=dict(
            ignore_done=True,
             update_per_collect=1,
             batch_size=128,
-            # batch_size=4,  # debug
             learning_rate_q=1e-3,
             learning_rate_policy=1e-3,
             learning_rate_alpha=3e-4,
@@ -70,7 +50,6 @@
             n_sample=1,
             unroll_len=1,
         ),
-        command=dict(),
         eval=dict(),
         other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
     ),
@@ -85,7 +64,6 @@
         import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
     ),
     env_manager=dict(type='subprocess'),
-    # env_manager=dict(type='base'),  # debug
     policy=dict(
         type='sac',
         import_names=['ding.policy.sac'],
@@ -94,23 +72,3 @@
 )
 dmc2gym_sac_create_config = EasyDict(dmc2gym_sac_create_config)
 create_config = dmc2gym_sac_create_config
-
-# if __name__ == "__main__":
-#     # or you can enter `ding -m serial -c dmc2gym_sac_pixel_config.py -s 0`
-#     from ding.entry import serial_pipeline
-#     serial_pipeline([main_config, create_config], seed=0)
-
-
-if __name__ == "__main__":
-    import copy
-    import argparse
-    from ding.entry import serial_pipeline
-
-    for seed in [0, 1, 2]:
-        parser = argparse.ArgumentParser()
-        parser.add_argument('--seed', '-s', type=int, default=seed)
-        args = parser.parse_args()
-
-        main_config.exp_name = 'dmc2gym_sac_pixel_scet-eat01-detach' + 'seed' + f'{args.seed}'
-        serial_pipeline([copy.deepcopy(main_config), copy.deepcopy(create_config)], seed=args.seed,
-                        max_env_step=int(3e6))
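With the ad-hoc `__main__` blocks removed, the config above is launched through the standard entry. A minimal sketch of that launch, based on the `serial_pipeline` call the removed comment pointed at (the `ding -m serial -c dmc2gym_sac_pixel_config.py -s 0` CLI is the equivalent); the `max_env_step` value is taken from the removed block and is illustrative, not part of this diff:

```python
# Minimal launch sketch for the cleaned-up pixel config; mirrors the
# removed commented-out `__main__` block. `max_env_step` is illustrative.
from ding.entry import serial_pipeline
from dizoo.dmc2gym.config.dmc2gym_sac_pixel_config import main_config, create_config

if __name__ == "__main__":
    serial_pipeline([main_config, create_config], seed=0, max_env_step=int(3e6))
```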
diff --git a/dizoo/dmc2gym/config/dmc2gym_sac_state_config.py b/dizoo/dmc2gym/config/dmc2gym_sac_state_config.py
index 53337e7749..840629423b 100644
--- a/dizoo/dmc2gym/config/dmc2gym_sac_state_config.py
+++ b/dizoo/dmc2gym/config/dmc2gym_sac_state_config.py
@@ -20,7 +20,6 @@
         model_type='state',
         cuda=True,
         random_collect_size=10000,
-        load_path="/root/dmc2gym_cartpole_swingup_state_sac_eval/ckpt/ckpt_best.pth.tar",
         model=dict(
             obs_shape=5,
             action_shape=1,
@@ -46,7 +45,6 @@
             n_sample=1,
             unroll_len=1,
         ),
-        command=dict(),
         eval=dict(),
         other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
     ),
@@ -69,33 +67,3 @@
 )
 dmc2gym_sac_create_config = EasyDict(dmc2gym_sac_create_config)
 create_config = dmc2gym_sac_create_config
-
-# if __name__ == "__main__":
-#     # or you can enter `ding -m serial -c dmc2gym_sac_state_config.py -s 0`
-#     from ding.entry import serial_pipeline
-#     serial_pipeline([main_config, create_config], seed=0)
-
-# if __name__ == "__main__":
-#     # or you can enter `ding -m serial -c dmc2gym_sac_config.py -s 0`
-#     from ding.entry import serial_pipeline
-#     serial_pipeline([main_config, create_config], seed=0)
-
-def train(args):
-    main_config.exp_name = 'dmc2gym_sac_state_old_check/' + 'seed' + f'{args.seed}' + '_5M'
-    serial_pipeline([copy.deepcopy(main_config), copy.deepcopy(create_config)], seed=args.seed,
-                    max_env_step=int(5e6))
-
-
-if __name__ == "__main__":
-    import copy
-    import argparse
-    from ding.entry import serial_pipeline
-
-    for seed in [0, 1, 2]:
-        parser = argparse.ArgumentParser()
-        parser.add_argument('--seed', '-s', type=int, default=seed)
-        args = parser.parse_args()
-
-        main_config.exp_name = 'dmc2gym_sac_state' + 'seed' + f'{args.seed}'
-        serial_pipeline([copy.deepcopy(main_config), copy.deepcopy(create_config)], seed=args.seed,
-                        max_env_step=int(3e6))
diff --git a/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py b/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py
index 011db11326..60a83921ef 100644
--- a/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py
+++ b/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py
@@ -1,5 +1,8 @@
+from tensorboardX import SummaryWriter
 from ditk import logging
-from ding.model.template.qac import QACPixel
+import os
+import numpy as np
+from ding.model.template.qac import QAC
 from ding.policy import SACPolicy
 from ding.envs import BaseEnvManagerV2
 from ding.data import DequeBuffer
@@ -11,9 +14,6 @@
 from ding.utils import set_pkg_seed
 from dizoo.dmc2gym.envs.dmc2gym_env import DMC2GymEnv
 from dizoo.dmc2gym.config.dmc2gym_sac_pixel_config import main_config, create_config
-import numpy as np
-from tensorboardX import SummaryWriter
-import os
 
 def main():
     logging.getLogger().setLevel(logging.INFO)
@@ -35,7 +35,8 @@ def main():
 
     set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
 
-    model = QACPixel(**cfg.policy.model)
+    model = QAC(**cfg.policy.model)
+    logging.info(model)
     buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
     policy = SACPolicy(cfg.policy, model=model)
@@ -72,8 +73,8 @@ def _add_train_scalar(ctx):
         task.use(data_pusher(cfg, buffer_))
         task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
         task.use(_add_train_scalar)
-        task.use(CkptSaver(cfg, policy, train_freq=100))
-        task.use(termination_checker(max_env_step=int(5000000)))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e5)))
+        task.use(termination_checker(max_env_step=int(5e6)))
         task.run()
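The entry above now builds the unified `QAC` model in place of `QACPixel`, driven entirely by the fields added in the pixel config diff. A hedged sketch of that construction, assuming `QAC` accepts these fields as keyword arguments (as the entry's `QAC(**cfg.policy.model)` call implies); `action_space` is an assumption here, normally merged in from the SAC policy's default model config by `compile_config` rather than written in this file:

```python
# Sketch: constructing the unified QAC model with the new config fields.
# Field values come from the pixel config diff above; `action_space` is an
# assumed SAC default merged in at compile time, not part of this diff.
from easydict import EasyDict
from ding.model.template.qac import QAC

model_cfg = EasyDict(
    obs_shape=(3, 84, 84),  # warped 84x84 RGB observation
    action_shape=1,  # cartpole-swingup has a 1-dim continuous action
    action_space='reparameterization',  # assumed SAC default
    twin_critic=True,
    encoder_hidden_size_list=[32, 32, 32],
    actor_head_hidden_size=1024,
    critic_head_hidden_size=1024,
    share_encoder=True,  # actor and critics share one conv encoder
)
model = QAC(**model_cfg)
```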
diff --git a/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py b/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py
index 45c3a46410..6bc7036352 100644
--- a/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py
+++ b/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py
@@ -72,8 +72,8 @@ def _add_train_scalar(ctx):
         task.use(data_pusher(cfg, buffer_))
         task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
         task.use(_add_train_scalar)
-        task.use(CkptSaver(cfg, policy, train_freq=100))
-        task.use(termination_checker(max_env_step=int(5000000)))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e5)))
+        task.use(termination_checker(max_env_step=int(5e6)))
         task.run()
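Both entry files receive the same two retunings: `CkptSaver` switches to the `(policy, save_dir)` argument order and checkpoints every `1e5` train iterations instead of every 100 (far less I/O on a 5M-step run), and the termination threshold is rewritten as `5e6` (same value as `5000000`, just easier to read). A self-contained sketch of this shared pipeline tail, assuming a compiled `cfg`, an initialized `policy`, and the usual `ding.framework` middleware import paths; the `run_tail` wrapper is hypothetical:

```python
# Sketch of the retuned pipeline tail shared by both entry files.
# Assumes `cfg` has been through compile_config and `policy` is built;
# `run_tail` is a hypothetical wrapper for illustration only.
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import CkptSaver, termination_checker

def run_tail(cfg, policy):
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # Save a checkpoint every 1e5 train iterations (was every 100).
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e5)))
        # Stop the run after 5M environment steps.
        task.use(termination_checker(max_env_step=int(5e6)))
        task.run()
```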