
repair deterministic playback (#174)
* repair deterministic playback

Co-authored-by: Yuke Zhu <[email protected]>
amandlek and yukezhu authored Jan 15, 2021
1 parent 779a3b7 commit a18f1a7
Showing 4 changed files with 38 additions and 30 deletions.
24 changes: 4 additions & 20 deletions robosuite/scripts/collect_human_demonstrations.py
@@ -64,23 +64,6 @@ def collect_human_trajectory(env, device, arm, env_configuration):

         # Run environment step
         env.step(action)
-
-        if is_first:
-            is_first = False
-
-            # We grab the initial model xml and state and reload from those so that
-            # we can support deterministic playback of actions from our demonstrations.
-            # This is necessary due to rounding issues with the model xml and with
-            # env.sim.forward(). We also have to do this after the first action is
-            # applied because the data collector wrapper only starts recording
-            # after the first action has been played.
-            initial_mjstate = env.sim.get_state().flatten()
-            xml_str = env.sim.model.get_xml()
-            env.reset_from_xml_string(xml_str)
-            env.sim.reset()
-            env.sim.set_state_from_flattened(initial_mjstate)
-            env.sim.forward()
-
         env.render()

         # Also break if we complete the task
@@ -154,10 +137,11 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info):
         if len(states) == 0:
             continue

-        # Delete the first actions and the last state. This is because when the DataCollector wrapper
-        # recorded the states and actions, the states were recorded AFTER playing that action.
+        # Delete the last state. This is because when the DataCollector wrapper
+        # recorded the states and actions, the states were recorded AFTER playing that action,
+        # so we end up with an extra state at the end.
         del states[-1]
-        del actions[0]
+        assert len(states) == len(actions)

         num_eps += 1
         ep_data_grp = grp.create_group("demo_{}".format(num_eps))
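Worth spelling out the invariant behind the `del states[-1]` change: the DataCollectionWrapper records each state after the corresponding action is applied, and (after this commit) it also stores the initial state before any action, so an episode with T actions yields T + 1 states. Dropping the trailing state leaves states[i] as exactly the state from which actions[i] was taken, which is what open-loop playback needs. A toy sketch of that bookkeeping with stand-in random data, not the wrapper's real recording loop:

    import numpy as np

    # Toy stand-in for one recorded episode with T actions.
    T = 5
    rng = np.random.default_rng(0)

    states = [rng.standard_normal(32)]           # initial state, saved before any action
    actions = []
    for _ in range(T):
        actions.append(rng.standard_normal(8))   # the action that was applied...
        states.append(rng.standard_normal(32))   # ...and the state recorded after it

    # T actions produced T + 1 states; drop the trailing one so that
    # states[i] is the state from which actions[i] was taken.
    del states[-1]
    assert len(states) == len(actions)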
12 changes: 7 additions & 5 deletions robosuite/scripts/playback_demonstrations_from_hdf5.py
@@ -50,7 +50,7 @@
         ignore_done=True,
         use_camera_obs=False,
         reward_shaping=True,
-        control_freq=100,
+        control_freq=20,
     )

     # list of all demonstrations episodes
@@ -72,16 +72,16 @@
     env.viewer.set_camera(0)

     # load the flattened mujoco states
-    states = f["data/{}/states".format(ep)].value
+    states = f["data/{}/states".format(ep)][()]

     if args.use_actions:

         # load the initial state
         env.sim.set_state_from_flattened(states[0])
         env.sim.forward()

         # load the actions and play them back open-loop
-        joint_torques = f["data/{}/joint_torques".format(ep)].value
-        actions = np.array(f["data/{}/actions".format(ep)].value)
+        actions = np.array(f["data/{}/actions".format(ep)][()])
         num_actions = actions.shape[0]
@@ -91,7 +91,9 @@
             if j < num_actions - 1:
                 # ensure that the actions deterministically lead to the same recorded states
                 state_playback = env.sim.get_state().flatten()
-                assert(np.all(np.equal(states[j + 1], state_playback)))
+                if not np.all(np.equal(states[j + 1], state_playback)):
+                    err = np.linalg.norm(states[j + 1] - state_playback)
+                    print(f"[warning] playback diverged by {err:.2f} for ep {ep} at step {j}")

     else:

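The recurring change from f[...].value to f[...][()] across these scripts is an h5py API migration rather than part of the playback fix itself: Dataset.value was deprecated for years and removed in h5py 3.0, while indexing with an empty tuple reads the full dataset into memory on both old and new versions. A self-contained example, with a hypothetical file name and dataset shape:

    import h5py
    import numpy as np

    with h5py.File("demo.hdf5", "w") as f:
        # intermediate groups ("data", "demo_1") are created implicitly
        f.create_dataset("data/demo_1/states", data=np.zeros((10, 32)))

    with h5py.File("demo.hdf5", "r") as f:
        states = f["data/demo_1/states"][()]  # replaces the removed .value attribute
        assert states.shape == (10, 32)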
26 changes: 24 additions & 2 deletions robosuite/wrappers/data_collection_wrapper.py
@@ -47,6 +47,10 @@ def __init__(self, env, directory, collect_freq=1, flush_freq=100):
         # remember whether any environment interaction has occurred
         self.has_interaction = False

+        # some variables for remembering the current episode's initial state and model xml
+        self._current_task_instance_state = None
+        self._current_task_instance_xml = None
+
     def _start_new_episode(self):
         """
         Bookkeeping to do at the start of each new episode.
@@ -60,6 +64,18 @@ def _start_new_episode(self):
         self.t = 0
         self.has_interaction = False

+        # save the task instance (will be saved on the first env interaction)
+        self._current_task_instance_xml = self.env.sim.model.get_xml()
+        self._current_task_instance_state = np.array(self.env.sim.get_state().flatten())
+
+        # trick for ensuring that we can play MuJoCo demonstrations back
+        # deterministically by using the recorded actions open loop
+        self.env.reset_from_xml_string(self._current_task_instance_xml)
+        self.env.sim.reset()
+        self.env.sim.set_state_from_flattened(self._current_task_instance_state)
+        self.env.sim.forward()
+
+
     def _on_first_interaction(self):
         """
         Bookkeeping for first timestep of episode.
@@ -82,7 +98,12 @@ def _on_first_interaction(self):

         # save the model xml
         xml_path = os.path.join(self.ep_directory, "model.xml")
-        save_sim_model(sim=self.sim, fname=xml_path)
+        with open(xml_path, "w") as f:
+            f.write(self._current_task_instance_xml)
+
+        # save initial state and action
+        assert len(self.states) == 0
+        self.states.append(self._current_task_instance_state)

     def _flush(self):
         """
@@ -155,5 +176,6 @@ def close(self):
"""
Override close method in order to flush left over data
"""
self._start_new_episode()
if self.has_interaction:
self._flush()
self.env.close()
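The core of the fix is the reload round trip that _start_new_episode now performs: serialize the model to XML, rebuild the environment from that exact string, then restore the flattened MuJoCo state. Playback later rebuilds from the same saved XML and state, so both sides step forward from a bit-identical model, sidestepping the XML-serialization and env.sim.forward() rounding issues named in the deleted comment. A condensed sketch of the wrapper logic, assuming a robosuite environment that supports reset_from_xml_string:

    import numpy as np

    def begin_deterministic_episode(env):
        """Reload env from its own serialized XML and state so that recorded
        actions can later be replayed open loop (sketch of _start_new_episode)."""
        xml = env.sim.model.get_xml()                    # serialize the current model
        state = np.array(env.sim.get_state().flatten())  # flattened MuJoCo state

        env.reset_from_xml_string(xml)                   # rebuild from the exact string
        env.sim.reset()
        env.sim.set_state_from_flattened(state)          # restore the saved state
        env.sim.forward()                                # recompute derived quantities
        return xml, state

This relocation is also why collect_human_demonstrations.py could drop its is_first block above: the wrapper performs the round trip itself, before any action is recorded.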
6 changes: 3 additions & 3 deletions robosuite/wrappers/demo_sampler_wrapper.py
@@ -212,7 +212,7 @@ def _uniform_sample(self):
         ep_ind = random.choice(self.demo_list)

         # select a flattened mujoco state uniformly from this episode
-        states = self.demo_file["data/{}/states".format(ep_ind)].value
+        states = self.demo_file["data/{}/states".format(ep_ind)][()]
         state = random.choice(states)

         if self.need_xml:
@@ -239,7 +239,7 @@ def _reverse_sample_open_loop(self):
         ep_ind = random.choice(self.demo_list)

         # sample uniformly in a window that grows backwards from the end of the demos
-        states = self.demo_file["data/{}/states".format(ep_ind)].value
+        states = self.demo_file["data/{}/states".format(ep_ind)][()]
         eps_len = states.shape[0]
         index = np.random.randint(max(eps_len - self.open_loop_window_size, 0), eps_len)
         state = states[index]
Expand Down Expand Up @@ -276,7 +276,7 @@ def _forward_sample_open_loop(self):
         ep_ind = random.choice(self.demo_list)

         # sample uniformly in a window that grows forwards from the beginning of the demos
-        states = self.demo_file["data/{}/states".format(ep_ind)].value
+        states = self.demo_file["data/{}/states".format(ep_ind)][()]
         eps_len = states.shape[0]
         index = np.random.randint(0, min(self.open_loop_window_size, eps_len))
         state = states[index]
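All three samplers in this file share the same read pattern: load the episode's flattened states with [()], then draw an index, with the two open-loop variants restricting the draw to a window at one end of the episode. A minimal standalone sketch of the reverse (end-anchored) window, mirroring the index arithmetic in _reverse_sample_open_loop:

    import numpy as np

    def sample_reverse_open_loop(states, window_size):
        """Pick a state uniformly from a window that grows backwards
        from the end of the episode."""
        eps_len = states.shape[0]
        index = np.random.randint(max(eps_len - window_size, 0), eps_len)
        return states[index]

    # toy usage: a 100-step episode and a 10-step terminal window
    demo_states = np.arange(100).reshape(100, 1)
    state = sample_reverse_open_loop(demo_states, window_size=10)
    assert 90 <= int(state[0]) < 100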
