-
Notifications
You must be signed in to change notification settings - Fork 57
/
calvin_env_wrapper.py
110 lines (95 loc) · 4.51 KB
/
calvin_env_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import logging
import os
from typing import Any, Dict, Tuple, Union
from calvin_agent.datasets.utils.episode_utils import process_depth, process_rgb, process_state
import gym
import numpy as np
import torch
from calvin_env.envs.play_table_env import get_env
from calvin_env.utils.utils import EglDeviceNotFoundError, get_egl_device_id
# Module-level logger, named after this module; used for setup info, warnings and errors below.
logger = logging.getLogger(__name__)
class CalvinEnvWrapper(gym.Wrapper):
    """Gym wrapper that feeds torch actions to a play-table env and returns torch observations.

    The wrapped environment is created with ``get_env`` from the dataset
    loader's dataset directory and observation space.  Raw observations
    coming out of the environment are passed through ``process_state``,
    ``process_rgb`` and ``process_depth`` with the dataset loader's
    transforms, moved to ``device`` and given a leading dimension of size 1.
    """

    def __init__(self, dataset_loader, device, show_gui=False, **kwargs):
        """Create the underlying environment and cache the dataset loader's settings.

        Args:
            dataset_loader: Object providing ``abs_datasets_dir``,
                ``observation_space``, ``transforms`` and ``proprio_state``.
            device: ``torch.device`` on which returned observations are placed;
                also used to pick the EGL rendering device.
            show_gui: Forwarded to ``get_env``.
            **kwargs: Forwarded unchanged to ``get_env``.
        """
        # The EGL device is selected before the environment is built
        # (presumably the renderer reads EGL_VISIBLE_DEVICES at creation time -- TODO confirm).
        self.set_egl_device(device)
        env = get_env(
            dataset_loader.abs_datasets_dir,
            show_gui=show_gui,
            obs_space=dataset_loader.observation_space,
            **kwargs,
        )
        super().__init__(env)
        self.observation_space_keys = dataset_loader.observation_space
        self.transforms = dataset_loader.transforms
        self.proprio_state = dataset_loader.proprio_state
        self.device = device
        # Relative actions are passed to the env as a flat 7-vector, absolute ones are split (see step()).
        self.relative_actions = "rel_actions" in self.observation_space_keys["actions"]
        logger.info("Initialized PlayTableEnv for device %s", self.device)

    @staticmethod
    def set_egl_device(device):
        """Set the ``EGL_VISIBLE_DEVICES`` environment variable to match ``device``.

        The variable is overwritten even when it is already set (a warning is
        logged in that case).  If no matching EGL device is found, a warning
        is logged and EGL device 0 is used.

        Args:
            device: ``torch.device``.  Non-CUDA devices, and CUDA devices
                without an explicit index, are treated as CUDA id 0.
        """
        if "EGL_VISIBLE_DEVICES" in os.environ:
            logger.warning("Environment variable EGL_VISIBLE_DEVICES is already set. Is this intended?")
        # torch.device("cuda") carries index None; fall back to 0 instead of
        # handing None to the EGL lookup.
        if device.type == "cuda" and device.index is not None:
            cuda_id = device.index
        else:
            cuda_id = 0
        try:
            egl_id = get_egl_device_id(cuda_id)
        except EglDeviceNotFoundError:
            logger.warning(
                "Couldn't find correct EGL device. Setting EGL_VISIBLE_DEVICES=0. "
                "When using DDP with many GPUs this can lead to OOM errors. "
                "Did you install PyBullet correctly? Please refer to calvin env README"
            )
            egl_id = 0
        os.environ["EGL_VISIBLE_DEVICES"] = str(egl_id)
        logger.info("EGL_DEVICE_ID %s <==> CUDA_DEVICE_ID %s", egl_id, cuda_id)

    def transform_observation(self, obs: Dict[str, Any]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Convert a raw environment observation into tensors on ``self.device``.

        Args:
            obs: Raw observation dict; must contain ``"rgb_obs"``,
                ``"depth_obs"`` and ``"robot_obs"`` (a numpy array).

        Returns:
            The merged outputs of ``process_rgb``, ``process_state`` and
            ``process_depth`` -- with ``"robot_obs"`` and every entry of
            ``"rgb_obs"`` / ``"depth_obs"`` moved to ``self.device`` and
            unsqueezed at dim 0 -- plus ``"robot_obs_raw"``, the unprocessed
            ``obs["robot_obs"]`` as a tensor on ``self.device``.
        """
        state_obs = process_state(obs, self.observation_space_keys, self.transforms, self.proprio_state)
        rgb_obs = process_rgb(obs["rgb_obs"], self.observation_space_keys, self.transforms)
        depth_obs = process_depth(obs["depth_obs"], self.observation_space_keys, self.transforms)

        # Move everything to the target device and add a leading dimension of size 1.
        state_obs["robot_obs"] = state_obs["robot_obs"].to(self.device).unsqueeze(0)
        rgb_obs["rgb_obs"] = {
            name: img.to(self.device).unsqueeze(0) for name, img in rgb_obs["rgb_obs"].items()
        }
        depth_obs["depth_obs"] = {
            name: img.to(self.device).unsqueeze(0) for name, img in depth_obs["depth_obs"].items()
        }

        obs_dict: Dict = {
            **rgb_obs,
            **state_obs,
            **depth_obs,
            "robot_obs_raw": torch.from_numpy(obs["robot_obs"]).to(self.device),
        }
        return obs_dict

    def step(
        self, action_tensor: torch.Tensor
    ) -> Tuple[Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]], int, bool, Dict]:
        """Apply ``action_tensor`` to the wrapped environment.

        With relative actions the squeezed tensor must have 7 entries and is
        passed on as a flat numpy array.  Otherwise the last dimension must be
        7 or 8 and the squeezed array is split into three parts at indices
        ``[3, 6]`` or ``[3, 7]`` respectively.  In both cases the last entry is
        replaced by ``1`` if it is positive and ``-1`` otherwise.

        Args:
            action_tensor: Action to execute; squeezed, detached and copied to
                the CPU before being handed to the environment.

        Returns:
            ``(obs, reward, done, info)`` as returned by the wrapped
            environment, with ``obs`` run through ``transform_observation``.

        Raises:
            NotImplementedError: Absolute actions whose last dimension is
                neither 7 nor 8.
        """
        if self.relative_actions:
            action = action_tensor.squeeze().cpu().detach().numpy()
            assert len(action) == 7
        else:
            action_dim = action_tensor.shape[-1]
            if action_dim == 7:
                slice_ids = [3, 6]
            elif action_dim == 8:
                slice_ids = [3, 7]
            else:
                # The accepted sizes are 7 and 8; the message now says so.
                msg = "actions are required to have length 7 (for euler angles) or 8 (for quaternions)"
                logger.error(msg)
                raise NotImplementedError(msg)
            action = np.split(action_tensor.squeeze().cpu().detach().numpy(), slice_ids)
        # Binarise the last entry to +1 / -1 (assumed to be the gripper command -- TODO confirm).
        action[-1] = 1 if action[-1] > 0 else -1
        o, r, d, i = self.env.step(action)
        obs = self.transform_observation(o)
        return obs, r, d, i

    def reset(
        self,
        reset_info: Union[Dict[str, Any], None] = None,
        batch_idx: int = 0,
        seq_idx: int = 0,
        scene_obs: Any = None,
        robot_obs: Any = None,
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Reset the wrapped environment and return the transformed observation.

        Args:
            reset_info: Optional dict with ``"robot_obs"`` and ``"scene_obs"``
                entries indexable by ``[batch_idx, seq_idx]``.  Takes
                precedence over ``scene_obs`` / ``robot_obs`` when given.
            batch_idx: First index into the ``reset_info`` entries.
            seq_idx: Second index into the ``reset_info`` entries.
            scene_obs: Scene state forwarded to the environment's ``reset``
                when ``reset_info`` is ``None``.
            robot_obs: Robot state forwarded to the environment's ``reset``
                when ``reset_info`` is ``None``.

        Returns:
            The observation after the reset, passed through
            ``transform_observation``.
        """
        if reset_info is not None:
            obs = self.env.reset(
                robot_obs=reset_info["robot_obs"][batch_idx, seq_idx],
                scene_obs=reset_info["scene_obs"][batch_idx, seq_idx],
            )
        elif scene_obs is not None or robot_obs is not None:
            obs = self.env.reset(scene_obs=scene_obs, robot_obs=robot_obs)
        else:
            obs = self.env.reset()
        return self.transform_observation(obs)

    def get_info(self):
        """Return ``get_info()`` of the wrapped environment unchanged."""
        return self.env.get_info()

    def get_obs(self):
        """Return the wrapped environment's current observation, passed through ``transform_observation``."""
        obs = self.env.get_obs()
        return self.transform_observation(obs)