Skip to content

Commit

Permalink
fix(nyz): fix cql example entry wrong config bug
Browse files: browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed May 29, 2023
1 parent 5804402 commit ebae45b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
12 changes: 5 additions & 7 deletions dizoo/atari/entry/pong_cql_main.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
import torch
from copy import deepcopy

from dizoo.atari.config.serial.pong.pong_qrdqn_generation_data_config import main_config, create_config
from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline


[Expand All] @@ -15,22 +14,21 @@ def train_cql(args):


def eval_ckpt(args):
from dizoo.atari.config.serial.pong.pong_qrdqn_generation_data_config import main_config, create_config
main_config.exp_name = 'pong'
main_config.policy.learn.learner.load_path = './pong/ckpt/ckpt_best.pth.tar'
main_config.policy.learn.learner.hook.load_ckpt_before_run = './pong/ckpt/ckpt_best.pth.tar'
config = deepcopy([main_config, create_config])
eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
eval(config, seed=args.seed, load_path='./pong/ckpt/ckpt_best.pth.tar')


def generate(args):
from dizoo.atari.config.serial.pong.pong_qrdqn_generation_data_config import main_config, create_config
main_config.exp_name = 'pong'
main_config.policy.learn.learner.load_path = './pong/ckpt/ckpt_best.pth.tar'
main_config.policy.collect.save_path = './pong/expert.pkl'
config = deepcopy([main_config, create_config])
state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
state_dict = torch.load('./pong/ckpt/ckpt_best.pth.tar', map_location='cpu')
collect_demo_data(
config,
collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
collect_count=int(1e5),
seed=args.seed,
expert_data_path=main_config.policy.collect.save_path,
state_dict=state_dict
Expand Down
14 changes: 7 additions & 7 deletions dizoo/classic_control/cartpole/entry/cartpole_cql_main.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
import torch
from copy import deepcopy

from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import main_config, create_config
from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline


[Expand All] @@ -15,23 +14,23 @@ def train_cql(args):


def eval_ckpt(args):
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import main_config, create_config
main_config, create_config = deepcopy(main_config), deepcopy(create_config)
main_config.exp_name = 'cartpole'
main_config.policy.learn.learner.load_path = './cartpole/ckpt/ckpt_best.pth.tar'
main_config.policy.learn.learner.hook.load_ckpt_before_run = './cartpole/ckpt/ckpt_best.pth.tar'
config = deepcopy([main_config, create_config])
eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
eval(config, seed=args.seed, load_path='./cartpole/ckpt/ckpt_best.pth.tar')


def generate(args):
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import main_config, create_config
main_config.exp_name = 'cartpole'
main_config.policy.learn.learner.load_path = './cartpole/ckpt/ckpt_best.pth.tar'
main_config.policy.collect.save_path = './cartpole/expert.pkl'
main_config.policy.collect.data_type = 'hdf5'
config = deepcopy([main_config, create_config])
state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
state_dict = torch.load('./cartpole/ckpt/ckpt_best.pth.tar', map_location='cpu')
collect_demo_data(
config,
collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
collect_count=10000,
seed=args.seed,
expert_data_path=main_config.policy.collect.save_path,
state_dict=state_dict
[Expand All] @@ -40,6 +39,7 @@ def generate(args):

def train_expert(args):
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import main_config, create_config
main_config, create_config = deepcopy(main_config), deepcopy(create_config)
main_config.exp_name = 'cartpole'
config = deepcopy([main_config, create_config])
serial_pipeline(config, seed=args.seed)
Expand Down

0 comments on commit ebae45b

Please sign in to comment.