forked from gao-yuan-hangzhou/pytorch-a2c-ppo-acktr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
enjoy.py
78 lines (59 loc) · 2.54 KB
/
enjoy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import os
import numpy as np
import torch
from envs import VecPyTorch, make_vec_envs
from utils import get_render_func, get_vec_normalize
# Command-line interface for replaying a trained agent.
parser = argparse.ArgumentParser(description='RL')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
# NOTE(review): --log-interval is not read anywhere in this script; it appears
# to be kept only so the same flags as the training script are accepted.
parser.add_argument('--log-interval', type=int, default=10,
                    help='log interval, one log per n updates (default: 10)')
parser.add_argument('--env-name', default='PongNoFrameskip-v4',
                    help='environment to evaluate (default: PongNoFrameskip-v4)')
# The model checkpoint is read from <load-dir>/<env-name>.pt.
parser.add_argument('--load-dir', default='./trained_models/',
                    help='directory to load the trained model from '
                         '(default: ./trained_models/)')
# Passed through to make_vec_envs; presumably must match the training setting
# so observations have the same layout — confirm.
parser.add_argument('--add-timestep', action='store_true', default=False,
                    help='add timestep to observations')
parser.add_argument('--non-det', action='store_true', default=False,
                    help='whether to use a non-deterministic policy')
args = parser.parse_args()
# Actions are chosen deterministically unless --non-det is given.
args.det = not args.non_det
# Build the evaluation environment on the CPU. The seed is offset by 1000 from
# --seed (presumably to differ from the seeds used during training).
# NOTE(review): the literal 1 and the two None arguments are positional
# parameters of make_vec_envs — check its signature for their meaning
# (the 1 looks like the number of parallel envs).
env = make_vec_envs(
    args.env_name,
    args.seed + 1000,
    1,
    None,
    None,
    args.add_timestep,
    device='cpu',
    allow_early_resets=False,
)

# Renderer for this env; may be None, in which case nothing is drawn.
render_func = get_render_func(env)
# Load the trained policy together with the observation-normalization
# statistics (ob_rms) saved alongside it. Evaluation must normalize
# observations with the same statistics that were used in training.
# map_location='cpu' remaps any CUDA tensors in the checkpoint onto the CPU,
# so a model trained on a GPU can be replayed in the device='cpu' env.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# trusted checkpoints. PyTorch 2.6+ defaults to weights_only=True, which
# rejects this pickled (model, ob_rms) tuple; pass weights_only=False there
# if loading fails on a recent release.
actor_critic, ob_rms = torch.load(
    os.path.join(args.load_dir, args.env_name + ".pt"),
    map_location='cpu')

vec_norm = get_vec_normalize(env)
if vec_norm is not None:
    # Put the normalization wrapper in eval mode (presumably stops it from
    # updating its running statistics — confirm in utils/envs), then install
    # the statistics from training.
    vec_norm.eval()
    vec_norm.ob_rms = ob_rms

# Initial recurrent state and mask for the single environment. A mask of 0
# marks an episode boundary, matching how the playback loop refills it.
recurrent_hidden_states = torch.zeros(1, actor_critic.recurrent_hidden_state_size)
masks = torch.zeros(1, 1)
# Render once before the first reset. Presumably this lets PyBullet envs open
# their GUI viewer before the simulation is reset — confirm.
if render_func is not None:
    render_func('human')

obs = env.reset()

# For PyBullet environments, find the body named "torso" so the playback loop
# can keep the camera on it. If several bodies carry that name, the last one
# wins; torsoId stays -1 when none is found.
if 'Bullet' in args.env_name:
    import pybullet as p
    body_names = [p.getBodyInfo(body_index)[0].decode()
                  for body_index in range(p.getNumBodies())]
    torsoId = -1
    for body_index, body_name in enumerate(body_names):
        if body_name == "torso":
            torsoId = body_index
# Loop-invariant settings, computed once instead of on every step.
is_bullet_env = 'Bullet' in args.env_name
# Debug-camera placement used to follow the torso in PyBullet envs
# (arguments of p.resetDebugVisualizerCamera: distance, yaw, pitch).
camera_distance = 5
camera_yaw = 0
camera_pitch = -20

# Replay the policy indefinitely; there is no exit condition (stop with Ctrl-C).
while True:
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        # The value estimate (first element) is not needed for playback.
        _, action, _, recurrent_hidden_states = actor_critic.act(
            obs, recurrent_hidden_states, masks, deterministic=args.det)

    # Observe reward and next obs
    obs, reward, done, _ = env.step(action)

    # 0.0 marks an episode boundary, 1.0 a continuing episode; the mask is
    # fed back into actor_critic.act on the next step.
    masks.fill_(0.0 if done else 1.0)

    # PyBullet: keep the debug camera centred on the torso, if one was found.
    # torsoId only exists for Bullet envs, hence the short-circuit.
    if is_bullet_env and torsoId > -1:
        humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId)
        p.resetDebugVisualizerCamera(
            camera_distance, camera_yaw, camera_pitch, humanPos)

    if render_func is not None:
        render_func('human')