-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathmain_inference.py
62 lines (54 loc) · 2.12 KB
/
main_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import numpy as np
import tensorflow as tf
from simpler_env.evaluation.argparse import get_args
from simpler_env.evaluation.maniskill2_evaluator import maniskill2_evaluator
from simpler_env.policies.octo.octo_server_model import OctoServerInference
from simpler_env.policies.rt1.rt1_model import RT1Inference
# Octo is an optional dependency: if it cannot be imported, report the reason
# and carry on so that policies which do not need OctoInference remain usable.
try:
    from simpler_env.policies.octo.octo_model import OctoInference
except ImportError as import_error:
    print("Octo is not correctly imported.", import_error, sep="\n")
def _limit_tf_gpu_memory(memory_limit):
    """Cap TensorFlow's memory use on the first GPU it reports.

    Does nothing when TensorFlow lists no GPU devices.

    Args:
        memory_limit: Value forwarded unchanged as ``memory_limit`` to
            ``tf.config.LogicalDeviceConfiguration`` (the TensorFlow docs
            describe this as megabytes -- confirm against the CLI default).
    """
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        # prevent a single tf process from taking up all the GPU memory;
        # only gpus[0] is configured, any additional GPUs are left untouched
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=memory_limit)],
        )


def _build_policy_model(args):
    """Create the policy model selected by ``args.policy_model``.

    Update this function if you are using a new policy model.

    Side effect: for Octo policies, when ``args.ckpt_path`` is ``None`` or the
    literal string ``"None"``, it is overwritten in place with
    ``args.policy_model`` so later consumers of ``args`` see the resolved value.

    Args:
        args: Parsed command-line namespace. Reads ``policy_model``,
            ``ckpt_path``, ``policy_setup``, ``action_scale`` and, for the
            non-server Octo policy, ``octo_init_rng``.

    Returns:
        An ``RT1Inference``, ``OctoServerInference`` or ``OctoInference``
        instance.

    Raises:
        ValueError: ``policy_model`` is ``"rt1"`` but no checkpoint path is set.
        ImportError: A non-server Octo policy was requested but the optional
            ``OctoInference`` import failed when this module was loaded.
        NotImplementedError: ``policy_model`` is not a supported policy.
    """
    if args.policy_model == "rt1":
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if args.ckpt_path is None:
            raise ValueError(
                "args.ckpt_path must be set when policy_model is 'rt1'"
            )
        return RT1Inference(
            saved_model_path=args.ckpt_path,
            policy_setup=args.policy_setup,
            action_scale=args.action_scale,
        )

    if "octo" in args.policy_model:
        # Without an explicit checkpoint, the policy name doubles as the
        # model type handed to the Octo wrappers.
        if args.ckpt_path is None or args.ckpt_path == "None":
            args.ckpt_path = args.policy_model
        if "server" in args.policy_model:
            return OctoServerInference(
                model_type=args.ckpt_path,
                policy_setup=args.policy_setup,
                action_scale=args.action_scale,
            )
        # OctoInference comes from a guarded import at the top of the file and
        # is undefined when that import failed; surface that clearly instead
        # of an unexplained NameError.
        try:
            octo_inference_cls = OctoInference
        except NameError as err:
            raise ImportError(
                "OctoInference is unavailable because the Octo import failed; "
                f"cannot create policy model {args.policy_model!r}"
            ) from err
        return octo_inference_cls(
            model_type=args.ckpt_path,
            policy_setup=args.policy_setup,
            init_rng=args.octo_init_rng,
            action_scale=args.action_scale,
        )

    raise NotImplementedError(
        f"Unsupported policy model: {args.policy_model!r}"
    )


if __name__ == "__main__":
    args = get_args()

    # blank out DISPLAY for this process and its children
    os.environ["DISPLAY"] = ""
    # prevent a single jax process from taking up all the GPU memory
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    _limit_tf_gpu_memory(args.tf_memory_limit)

    # policy model creation
    model = _build_policy_model(args)

    # run real-to-sim evaluation
    success_arr = maniskill2_evaluator(model, args)
    print(args)
    print(" " * 10, "Average success", np.mean(success_arr))