main.py (forked from fshamshirdar/pytorch-rdpg)
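"""Entry point for pytorch-rdpg: parses hyperparameters and trains or
evaluates an RDPG (Recurrent Deterministic Policy Gradient) agent on an
OpenAI Gym environment."""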
import argparse
from copy import deepcopy

import numpy as np
import torch
import gym

from normalized_env import NormalizedEnv
from evaluator import Evaluator
from rdpg import RDPG
from util import *
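# NOTE: undo_logger_setup() exists only in older gym releases (it was later
# removed); drop this call when running against a recent gym version.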
gym.undo_logger_setup()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch RDPG')

    parser.add_argument('--mode', default='train', type=str, help='run mode: train/test')
    parser.add_argument('--env', default='Pendulum-v0', type=str, help='OpenAI Gym environment name')
    parser.add_argument('--hidden1', default=400, type=int, help='units in the first fully connected layer')
    parser.add_argument('--hidden2', default=300, type=int, help='units in the second fully connected layer')
    parser.add_argument('--rate', default=0.001, type=float, help='critic network learning rate')
    parser.add_argument('--prate', default=0.0001, type=float, help='policy (actor) network learning rate')
    parser.add_argument('--warmup', default=10000, type=int, help='steps spent filling the replay memory before training starts')
    parser.add_argument('--discount', default=0.99, type=float, help='discount factor gamma')
    parser.add_argument('--bsize', default=64, type=int, help='minibatch size')
    parser.add_argument('--rmsize', default=6000000, type=int, help='replay memory size')
    parser.add_argument('--window_length', default=1, type=int, help='number of recent observations stacked into one state')
    parser.add_argument('--tau', default=0.001, type=float, help='soft-update coefficient for the target networks')
    parser.add_argument('--ou_theta', default=0.15, type=float, help='Ornstein-Uhlenbeck noise theta')
    parser.add_argument('--ou_sigma', default=0.2, type=float, help='Ornstein-Uhlenbeck noise sigma')
    parser.add_argument('--ou_mu', default=0.0, type=float, help='Ornstein-Uhlenbeck noise mu')
    parser.add_argument('--validate_episodes', default=20, type=int, help='number of episodes per validation run')
    parser.add_argument('--max_episode_length', default=500, type=int, help='maximum steps per episode')
    parser.add_argument('--trajectory_length', default=5, type=int, help='length of the sampled trajectories used for the recurrent update')
    parser.add_argument('--validate_steps', default=2000, type=int, help='run a validation experiment every this many steps')
    parser.add_argument('--debug', dest='debug', action='store_true', help='print debug output')
    parser.add_argument('--init_w', default=0.003, type=float, help='initial weight range for the output layers')
    parser.add_argument('--train_iter', default=20000000, type=int, help='total number of training iterations')
    parser.add_argument('--epsilon', default=50000, type=int, help='steps over which the exploration noise decays linearly')
    parser.add_argument('--seed', default=-1, type=int, help='random seed (ignored if <= 0)')
    parser.add_argument('--checkpoint', default="checkpoints", type=str, help='checkpoint directory')

    args = parser.parse_args()
    # env = NormalizedEnv(gym.make(args.env))  # optional wrapper that rescales actions to the env's bounds
    env = gym.make(args.env)

    if args.seed > 0:
        # seed numpy, torch, and the environment for reproducibility
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        env.seed(args.seed)

    nb_states = env.observation_space.shape[0]
    nb_actions = env.action_space.shape[0]

    rdpg = RDPG(env, nb_states, nb_actions, args)

    if args.mode == 'train':
        rdpg.train(args.train_iter, args.checkpoint, args.debug)
    elif args.mode == 'test':
        rdpg.test(args.validate_episodes, args.checkpoint, visualize=True, debug=args.debug)
    else:
        raise RuntimeError('undefined mode {}'.format(args.mode))
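# Example invocations (a sketch; assumes an older gym release that still
# provides Pendulum-v0 and the env.seed() API used above):
#   python main.py --mode train --env Pendulum-v0 --debug
#   python main.py --mode test --env Pendulum-v0 --validate_episodes 20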