
fix demo collection bug #840

Merged 1 commit on Aug 20, 2024
26 changes: 16 additions & 10 deletions omnigibson/envs/data_wrapper.py
@@ -79,6 +79,10 @@ def step(self, action):
                 - bool: truncated, i.e. whether this episode ended due to a time limit etc.
                 - dict: info, i.e. dictionary with any useful information
         """
+        # Make sure actions are always flattened numpy arrays
+        if isinstance(action, dict):
+            action = np.concatenate([act for act in action.values()])
+
         next_obs, reward, terminated, truncated, info = self.env.step(action)
         self.step_count += 1
 
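For context, a minimal sketch of what the new dict handling in step() does, assuming a composite action space where each controller name maps to its own numpy array; the names and sizes below are invented for illustration:

import numpy as np

# Hypothetical dict-style action, e.g. one entry per controller
action = {
    "base": np.array([0.1, 0.0]),
    "arm": np.array([0.0, 0.2, -0.1]),
}

# Same flattening as the added lines above: concatenate values in insertion order
if isinstance(action, dict):
    action = np.concatenate([act for act in action.values()])

print(action.shape)  # (5,)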
@@ -133,21 +137,22 @@ def observation_spec(self):
         """
         return self.env.observation_spec()
 
-    def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
+    def process_traj_to_hdf5(self, traj_data, traj_grp_name, nested_keys=("obs",)):
         """
         Processes trajectory data @traj_data and stores them as a new group under @traj_grp_name.
 
         Args:
             traj_data (list of dict): Trajectory data, where each entry is a keyword-mapped set of data for a single
                 sim step
             traj_grp_name (str): Name of the trajectory group to store
-            obs_key (str): Name of key corresponding to observation data in @traj_data. This specific data is
-                assumed to be its own keyword-mapped dictionary of observations, and will be parsed differently from
-                the rest of the data
+            nested_keys (list of str): Name of key(s) corresponding to nested data in @traj_data. This specific data
+                is assumed to be its own keyword-mapped dictionary of numpy array values, and will be parsed
+                differently from the rest of the data
 
         Returns:
             hdf5.Group: Generated hdf5 group storing the recorded trajectory data
         """
+        nested_keys = set(nested_keys)
         data_grp = self.hdf5_file.require_group("data")
         traj_grp = data_grp.create_group(traj_grp_name)
         traj_grp.attrs["num_samples"] = len(traj_data)
@@ -156,11 +161,12 @@ def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
         # We need to do this because we're not guaranteed to have a full set of keys at every trajectory step; e.g.
         # if the first step only has state or observations but no actions
         data = defaultdict(list)
-        data[obs_key] = defaultdict(list)
+        for key in nested_keys:
+            data[key] = defaultdict(list)
 
         for step_data in traj_data:
             for k, v in step_data.items():
-                if k == obs_key:
+                if k in nested_keys:
                     for mod, step_mod_data in v.items():
                         data[k][mod].append(step_mod_data)
                 else:
@@ -172,7 +178,7 @@ def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
                 continue
 
             # Create datasets for all keys with valid data
-            if k == obs_key:
+            if k in nested_keys:
                 obs_grp = traj_grp.create_group(k)
                 for mod, traj_mod_data in dat.items():
                     obs_grp.create_dataset(mod, data=np.stack(traj_mod_data, axis=0))
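To make the nested_keys handling above concrete, here is a small self-contained sketch (not the wrapper itself) of the HDF5 layout it produces, written with h5py; the modality names "rgb" and "proprio" are made up for the example:

from collections import defaultdict

import h5py
import numpy as np

# Two fake trajectory steps; "obs" is a nested dict entry, "action" is a flat array
traj_data = [
    {"obs": {"rgb": np.zeros((4, 4, 3)), "proprio": np.zeros(7)}, "action": np.zeros(5)},
    {"obs": {"rgb": np.ones((4, 4, 3)), "proprio": np.ones(7)}, "action": np.ones(5)},
]
nested_keys = {"obs"}

with h5py.File("demo_sketch.hdf5", "w") as f:
    traj_grp = f.require_group("data").create_group("demo_0")
    data = defaultdict(list)
    for key in nested_keys:
        data[key] = defaultdict(list)
    for step_data in traj_data:
        for k, v in step_data.items():
            if k in nested_keys:
                # Nested entries are split per modality
                for mod, step_mod_data in v.items():
                    data[k][mod].append(step_mod_data)
            else:
                data[k].append(v)
    for k, dat in data.items():
        if k in nested_keys:
            grp = traj_grp.create_group(k)
            for mod, traj_mod_data in dat.items():
                grp.create_dataset(mod, data=np.stack(traj_mod_data, axis=0))
        else:
            traj_grp.create_dataset(k, data=np.stack(dat, axis=0))

# Resulting layout:
#   data/demo_0/obs/rgb      -> shape (2, 4, 4, 3)
#   data/demo_0/obs/proprio  -> shape (2, 7)
#   data/demo_0/action       -> shape (2, 5)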
@@ -189,7 +195,7 @@ def flush_current_traj(self):
         success = self.env.task.success or not self.only_successes
         if success and self.hdf5_file is not None:
             traj_grp_name = f"demo_{self.traj_count}"
-            traj_grp = self.process_traj_to_hdf5(self.current_traj_history, traj_grp_name, obs_key="obs")
+            traj_grp = self.process_traj_to_hdf5(self.current_traj_history, traj_grp_name, nested_keys=["obs"])
             self.traj_count += 1
         else:
             # Remove this demo
@@ -345,7 +351,7 @@ def _parse_step_data(self, action, obs, reward, terminated, truncated, info):
 
         return step_data
 
-    def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
+    def process_traj_to_hdf5(self, traj_data, traj_grp_name, nested_keys=("obs",)):
         # First pad all state values to be the same max (uniform) size
         for step_data in traj_data:
             state = step_data["state"]
@@ -354,7 +360,7 @@ def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
             step_data["state"] = padded_state
 
         # Call super
-        traj_grp = super().process_traj_to_hdf5(traj_data, traj_grp_name, obs_key)
+        traj_grp = super().process_traj_to_hdf5(traj_data, traj_grp_name, nested_keys)
 
         # Add in transition info
         self.add_metadata(group=traj_grp, name="transitions", data=self.current_transitions)
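The overridden method above pads every step's flattened simulator state to a common length before the parent method stacks it into a dataset, since the state vector can change size between steps. A simplified sketch of that idea, with an inline zero-padding helper standing in for whatever utility the real code uses:

import numpy as np

def pad_to_length(state, length):
    # Zero-pad a 1D state vector up to the target length (illustrative helper only)
    padded = np.zeros(length, dtype=state.dtype)
    padded[: len(state)] = state
    return padded

# Fake per-step states of different lengths
traj_data = [{"state": np.arange(3.0)}, {"state": np.arange(5.0)}]

max_len = max(len(step["state"]) for step in traj_data)
for step_data in traj_data:
    step_data["state"] = pad_to_length(step_data["state"], max_len)

# Every "state" now has shape (5,), so np.stack can build one uniform dataset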