Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

repair deterministic playback #174

Merged
merged 4 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 4 additions & 20 deletions robosuite/scripts/collect_human_demonstrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,6 @@ def collect_human_trajectory(env, device, arm, env_configuration):

# Run environment step
env.step(action)

if is_first:
is_first = False

# We grab the initial model xml and state and reload from those so that
# we can support deterministic playback of actions from our demonstrations.
# This is necessary due to rounding issues with the model xml and with
# env.sim.forward(). We also have to do this after the first action is
# applied because the data collector wrapper only starts recording
# after the first action has been played.
initial_mjstate = env.sim.get_state().flatten()
xml_str = env.sim.model.get_xml()
env.reset_from_xml_string(xml_str)
env.sim.reset()
env.sim.set_state_from_flattened(initial_mjstate)
env.sim.forward()

env.render()

# Also break if we complete the task
Expand Down Expand Up @@ -154,10 +137,11 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info):
if len(states) == 0:
continue

# Delete the first actions and the last state. This is because when the DataCollector wrapper
# recorded the states and actions, the states were recorded AFTER playing that action.
# Delete the last state. This is because when the DataCollector wrapper
# recorded the states and actions, the states were recorded AFTER playing that action,
# so we end up with an extra state at the end.
del states[-1]
del actions[0]
assert len(states) == len(actions)

num_eps += 1
ep_data_grp = grp.create_group("demo_{}".format(num_eps))
Expand Down
8 changes: 4 additions & 4 deletions robosuite/scripts/playback_demonstrations_from_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
ignore_done=True,
use_camera_obs=False,
reward_shaping=True,
control_freq=100,
control_freq=20,
)

# list of all demonstrations episodes
Expand All @@ -72,16 +72,16 @@
env.viewer.set_camera(0)

# load the flattened mujoco states
states = f["data/{}/states".format(ep)].value
states = f["data/{}/states".format(ep)][()]

if args.use_actions:

# load the initial state
env.sim.set_state_from_flattened(states[0])
env.sim.forward()

# load the actions and play them back open-loop
joint_torques = f["data/{}/joint_torques".format(ep)].value
actions = np.array(f["data/{}/actions".format(ep)].value)
actions = np.array(f["data/{}/actions".format(ep)][()])
num_actions = actions.shape[0]

for j, action in enumerate(actions):
Expand Down
26 changes: 24 additions & 2 deletions robosuite/wrappers/data_collection_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def __init__(self, env, directory, collect_freq=1, flush_freq=100):
# remember whether any environment interaction has occurred
self.has_interaction = False

# some variables for remembering the current episode's initial state and model xml
self._current_task_instance_state = None
self._current_task_instance_xml = None

def _start_new_episode(self):
"""
Bookkeeping to do at the start of each new episode.
Expand All @@ -60,6 +64,18 @@ def _start_new_episode(self):
self.t = 0
self.has_interaction = False

# cache the task instance now (it is recorded on the first env interaction)
self._current_task_instance_xml = self.env.sim.model.get_xml()
self._current_task_instance_state = np.array(self.env.sim.get_state().flatten())

# trick for ensuring that we can play MuJoCo demonstrations back
# deterministically by using the recorded actions open loop
self.env.reset_from_xml_string(self._current_task_instance_xml)
self.env.sim.reset()
self.env.sim.set_state_from_flattened(self._current_task_instance_state)
self.env.sim.forward()


def _on_first_interaction(self):
"""
Bookkeeping for first timestep of episode.
Expand All @@ -82,7 +98,12 @@ def _on_first_interaction(self):

# save the model xml
xml_path = os.path.join(self.ep_directory, "model.xml")
save_sim_model(sim=self.sim, fname=xml_path)
with open(xml_path, "w") as f:
f.write(self._current_task_instance_xml)

# record the episode's initial state as the first entry in the state list
assert len(self.states) == 0
self.states.append(self._current_task_instance_state)

def _flush(self):
"""
Expand Down Expand Up @@ -155,5 +176,6 @@ def close(self):
"""
Override close method in order to flush left over data
"""
self._start_new_episode()
if self.has_interaction:
self._flush()
self.env.close()