From e8bf463ab55b5b35090e9686a5ed70b67220f527 Mon Sep 17 00:00:00 2001 From: amandlek Date: Wed, 13 Jan 2021 22:26:23 -0800 Subject: [PATCH 1/4] repair deterministic playback --- .../scripts/collect_human_demonstrations.py | 24 +++-------------- .../playback_demonstrations_from_hdf5.py | 8 +++--- robosuite/wrappers/data_collection_wrapper.py | 26 +++++++++++++++++-- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/robosuite/scripts/collect_human_demonstrations.py b/robosuite/scripts/collect_human_demonstrations.py index b9d4229b15..0cd6be36c0 100644 --- a/robosuite/scripts/collect_human_demonstrations.py +++ b/robosuite/scripts/collect_human_demonstrations.py @@ -64,23 +64,6 @@ def collect_human_trajectory(env, device, arm, env_configuration): # Run environment step env.step(action) - - if is_first: - is_first = False - - # We grab the initial model xml and state and reload from those so that - # we can support deterministic playback of actions from our demonstrations. - # This is necessary due to rounding issues with the model xml and with - # env.sim.forward(). We also have to do this after the first action is - # applied because the data collector wrapper only starts recording - # after the first action has been played. - initial_mjstate = env.sim.get_state().flatten() - xml_str = env.sim.model.get_xml() - env.reset_from_xml_string(xml_str) - env.sim.reset() - env.sim.set_state_from_flattened(initial_mjstate) - env.sim.forward() - env.render() # Also break if we complete the task @@ -154,10 +137,11 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info): if len(states) == 0: continue - # Delete the first actions and the last state. This is because when the DataCollector wrapper - # recorded the states and actions, the states were recorded AFTER playing that action. + # Delete the last state. This is because when the DataCollector wrapper + # recorded the states and actions, the states were recorded AFTER playing that action, + # so we end up with an extra state at the end. del states[-1] - del actions[0] + assert len(states) == len(actions) num_eps += 1 ep_data_grp = grp.create_group("demo_{}".format(num_eps)) diff --git a/robosuite/scripts/playback_demonstrations_from_hdf5.py b/robosuite/scripts/playback_demonstrations_from_hdf5.py index 6443c4a10a..3dded591fd 100644 --- a/robosuite/scripts/playback_demonstrations_from_hdf5.py +++ b/robosuite/scripts/playback_demonstrations_from_hdf5.py @@ -50,7 +50,7 @@ ignore_done=True, use_camera_obs=False, reward_shaping=True, - control_freq=100, + control_freq=20, ) # list of all demonstrations episodes @@ -72,16 +72,16 @@ env.viewer.set_camera(0) # load the flattened mujoco states - states = f["data/{}/states".format(ep)].value + states = f["data/{}/states".format(ep)][()] if args.use_actions: # load the initial state env.sim.set_state_from_flattened(states[0]) + env.sim.forward() # load the actions and play them back open-loop - joint_torques = f["data/{}/joint_torques".format(ep)].value - actions = np.array(f["data/{}/actions".format(ep)].value) + actions = np.array(f["data/{}/actions".format(ep)][()]) num_actions = actions.shape[0] for j, action in enumerate(actions): diff --git a/robosuite/wrappers/data_collection_wrapper.py b/robosuite/wrappers/data_collection_wrapper.py index 4aa57eebd0..d28b1c5edf 100644 --- a/robosuite/wrappers/data_collection_wrapper.py +++ b/robosuite/wrappers/data_collection_wrapper.py @@ -47,6 +47,10 @@ def __init__(self, env, directory, collect_freq=1, flush_freq=100): # remember whether any environment interaction has occurred self.has_interaction = False + # some variables for remembering the current episode's initial state and model xml + self._current_task_instance_state = None + self._current_task_instance_xml = None + def _start_new_episode(self): """ Bookkeeping to do at the start of each new episode. @@ -60,6 +64,18 @@ def _start_new_episode(self): self.t = 0 self.has_interaction = False + # save the task instance (will be saved on the first env interaction) + self._current_task_instance_xml = self.env.sim.model.get_xml() + self._current_task_instance_state = np.array(self.env.sim.get_state().flatten()) + + # trick for ensuring that we can play MuJoCo demonstrations back + # deterministically by using the recorded actions open loop + self.env.reset_from_xml_string(self._current_task_instance_xml) + self.env.sim.reset() + self.env.sim.set_state_from_flattened(self._current_task_instance_state) + self.env.sim.forward() + + def _on_first_interaction(self): """ Bookkeeping for first timestep of episode. @@ -82,7 +98,12 @@ def _on_first_interaction(self): # save the model xml xml_path = os.path.join(self.ep_directory, "model.xml") - save_sim_model(sim=self.sim, fname=xml_path) + with open(xml_path, "w") as f: + f.write(self._current_task_instance_xml) + + # save initial state and action + assert len(self.states) == 0 + self.states.append(self._current_task_instance_state) def _flush(self): """ @@ -155,5 +176,6 @@ def close(self): """ Override close method in order to flush left over data """ - self._start_new_episode() + if self.has_interaction: + self._flush() self.env.close() From 109b4633c72f1b080a984067a0aa9b08914e9f9b Mon Sep 17 00:00:00 2001 From: amandlek Date: Thu, 14 Jan 2021 12:06:31 -0800 Subject: [PATCH 2/4] Update demo_sampler_wrapper.py --- robosuite/wrappers/demo_sampler_wrapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/robosuite/wrappers/demo_sampler_wrapper.py b/robosuite/wrappers/demo_sampler_wrapper.py index 36cc89d085..5ac534abef 100644 --- a/robosuite/wrappers/demo_sampler_wrapper.py +++ b/robosuite/wrappers/demo_sampler_wrapper.py @@ -212,7 +212,7 @@ def _uniform_sample(self): ep_ind = random.choice(self.demo_list) # select a flattened mujoco state uniformly from this episode - states = self.demo_file["data/{}/states".format(ep_ind)].value + states = self.demo_file["data/{}/states".format(ep_ind)][()] state = random.choice(states) if self.need_xml: @@ -239,7 +239,7 @@ def _reverse_sample_open_loop(self): ep_ind = random.choice(self.demo_list) # sample uniformly in a window that grows backwards from the end of the demos - states = self.demo_file["data/{}/states".format(ep_ind)].value + states = self.demo_file["data/{}/states".format(ep_ind)][()] eps_len = states.shape[0] index = np.random.randint(max(eps_len - self.open_loop_window_size, 0), eps_len) state = states[index] @@ -276,7 +276,7 @@ def _forward_sample_open_loop(self): ep_ind = random.choice(self.demo_list) # sample uniformly in a window that grows forwards from the beginning of the demos - states = self.demo_file["data/{}/states".format(ep_ind)].value + states = self.demo_file["data/{}/states".format(ep_ind)][()] eps_len = states.shape[0] index = np.random.randint(0, min(self.open_loop_window_size, eps_len)) state = states[index] From 4e64c45acf30c4b29f153aa55d8830c991590272 Mon Sep 17 00:00:00 2001 From: amandlek Date: Thu, 14 Jan 2021 12:22:56 -0800 Subject: [PATCH 3/4] change assert to warning for playback actions --- robosuite/scripts/playback_demonstrations_from_hdf5.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/robosuite/scripts/playback_demonstrations_from_hdf5.py b/robosuite/scripts/playback_demonstrations_from_hdf5.py index 3dded591fd..1c4f3dfb70 100644 --- a/robosuite/scripts/playback_demonstrations_from_hdf5.py +++ b/robosuite/scripts/playback_demonstrations_from_hdf5.py @@ -91,7 +91,8 @@ if j < num_actions - 1: # ensure that the actions deterministically lead to the same recorded states state_playback = env.sim.get_state().flatten() - assert(np.all(np.equal(states[j + 1], state_playback))) + if not np.all(np.equal(states[j + 1], state_playback)): + print("warning! playback diverged for ep {} at step {}".format(ep, j)) else: From d5d6c29d640e1ea27f0378a42efa56ffe584a881 Mon Sep 17 00:00:00 2001 From: Yuke Zhu Date: Thu, 14 Jan 2021 16:37:10 -0600 Subject: [PATCH 4/4] update warning message --- robosuite/scripts/playback_demonstrations_from_hdf5.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/robosuite/scripts/playback_demonstrations_from_hdf5.py b/robosuite/scripts/playback_demonstrations_from_hdf5.py index 1c4f3dfb70..82d2b3230c 100644 --- a/robosuite/scripts/playback_demonstrations_from_hdf5.py +++ b/robosuite/scripts/playback_demonstrations_from_hdf5.py @@ -92,7 +92,8 @@ # ensure that the actions deterministically lead to the same recorded states state_playback = env.sim.get_state().flatten() if not np.all(np.equal(states[j + 1], state_playback)): - print("warning! playback diverged for ep {} at step {}".format(ep, j)) + err = np.linalg.norm(states[j + 1] - state_playback) + print(f"[warning] playback diverged by {err:.2f} for ep {ep} at step {j}") else: