forked from facebookresearch/TCDM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrollout.py
63 lines (48 loc) · 2.13 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import matplotlib.pyplot as plt
import pickle as pkl
import numpy as np
import glob, yaml, os, imageio, cv2, shutil
from tcdm import suite
from stable_baselines3 import PPO
from argparse import ArgumentParser
"""
PLEASE DOWNLOAD AND UNZIP THE PRE-TRAINED AGENTS BEFORE RUNNING THIS
SEE: https://github.com/facebookresearch/DexMan#pre-trained-policies
"""
# Command-line interface: where to find the agent and whether to record video.
parser = ArgumentParser(
    description="Example code for loading pre-trained policies",
)
parser.add_argument(
    '--save_folder',
    default='pretrained_agents/hammer_use1/',
    help="Save folder containing agent checkpoint/config",
)
parser.add_argument(
    '--render',
    action="store_true",  # off unless the flag is supplied
    help="Supply flag to render mp4",
)
def render(writer, physics, AA=2, height=256, width=256):
    """Append one frame of ``physics`` (camera 0) to ``writer``.

    The scene is rendered at ``AA`` times the requested resolution and then
    downsampled to ``(height, width)`` with area interpolation, which acts as
    a simple supersampling anti-alias.

    Args:
        writer: object exposing ``append_data(frame)``; ``None`` makes the
            call a no-op (rendering disabled).
        physics: object exposing ``render(camera_id=..., height=..., width=...)``.
        AA: supersampling factor applied to both dimensions (default 2).
        height: output frame height in pixels.
        width: output frame width in pixels.
    """
    if writer is not None:
        hi_res = physics.render(camera_id=0, height=height * AA, width=width * AA)
        # cv2.resize takes dsize as (width, height)
        frame = cv2.resize(hi_res, (width, height), interpolation=cv2.INTER_AREA)
        writer.append_data(frame)
def rollout(save_folder, writer):
    """Run one deterministic episode of a pre-trained PPO agent.

    Reads ``exp_config.yaml`` and ``checkpoint.zip`` from ``save_folder``,
    builds the configured environment via ``suite.load``, then steps the
    policy until the episode reports ``done``, rendering every frame to
    ``writer`` when one is given.

    Args:
        save_folder: directory containing ``exp_config.yaml`` and
            ``checkpoint.zip``.
        writer: video writer handed to ``render``; ``None`` disables rendering.

    Returns:
        The summed per-step reward of the episode (also printed).
    """
    # get experiment config; the with-block closes the file handle promptly
    with open(os.path.join(save_folder, 'exp_config.yaml'), 'r') as cfg_file:
        config = yaml.safe_load(cfg_file)

    # build environment and load policy
    # env name is split on '-' into the two leading suite.load arguments;
    # assumes exactly one '-' in the name (unpacking raises ValueError otherwise)
    domain_name, task_name = config['env']['name'].split('-')
    env = suite.load(domain_name, task_name, config['env']['task_kwargs'], gym_wrap=True)
    policy = PPO.load(os.path.join(save_folder, 'checkpoint.zip'))

    # rollout the policy and accumulate total reward
    s, done, total_reward = env.reset(), False, 0
    render(writer, env.wrapped.physics)
    while not done:
        # the policy only consumes the 'state' entry of the observation dict
        action, _ = policy.predict(s['state'], deterministic=True)
        s, r, done, _info = env.step(action)
        render(writer, env.wrapped.physics)
        total_reward += r
    print('Total reward:', total_reward)
    return total_reward
if __name__ == "__main__":
    args = parser.parse_args()
    # configure writer only when an mp4 was requested
    if args.render:
        writer = imageio.get_writer('rollout.mp4', fps=25)
        try:
            rollout(args.save_folder, writer)
        finally:
            # release the writer even if the rollout raises part-way through
            writer.close()
    else:
        rollout(args.save_folder, None)