You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I've noticed some divergence between the recorded and playback states that is correlated with the flush_freq parameter in DataCollectionWrapper (and have viewed the relevant issues/PRs [1, 2]).
Below is the script adapted from robosuite/demos/demo_collect_and_playback_data.py, in which I collect demonstrations, playback actions, and compare the L2 distance between recorded and actual states.
In my tests, the assertion fails on the first state from the second state_file path, which contains the next batch of states of length flush_freq. Specifically, this assertion fails on state_*.npz files created when flush_freq <= args.timesteps / 2. In other words, in the example below (with the default of 1000 timesteps), the assert fails if I set flush_freq=500, but not when flush_freq=501.
import argparse
import os
import random
from glob import glob
import numpy as np
import robosuite as suite
from robosuite.wrappers import DataCollectionWrapper
def set_seed(seed: int):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
def collect_random_trajectory(env, timesteps=1000):
    """Drive ``env`` with a seeded random policy.

    The random generators are re-seeded with 0, the environment is reset,
    and then ``timesteps`` actions are applied through ``env.step``. Each
    action is a standard-normal sample of length ``env.action_dim``.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` and
            ``action_dim``.
        timesteps: number of random actions to apply.
    """
    set_seed(0)
    env.reset()
    action_size = env.action_dim
    for _ in range(timesteps):
        env.step(np.random.randn(action_size))
def playback_trajectory(env, ep_dir):
    """Replay a recorded episode and compare it with the logged states.

    The model stored in ``<ep_dir>/model.xml`` is reloaded into ``env``.
    Every ``state_*.npz`` file in ``ep_dir`` is then processed in sorted
    filename order: the simulation is set to the first state of the file,
    and for each recorded action the current simulator state is compared
    with the recorded state at the same index before the action is applied.

    Args:
        env: environment exposing ``edit_model_xml``,
            ``reset_from_xml_string``, ``sim`` and ``step``.
        ep_dir: directory holding ``model.xml`` and the ``state_*.npz`` files.

    Raises:
        AssertionError: if the L2 distance between a recorded state and the
            simulator state is not below 1e-6.
    """
    # Rebuild the simulation from the XML saved alongside the episode.
    with open(os.path.join(ep_dir, "model.xml"), "r") as xml_file:
        model_xml = env.edit_model_xml(xml_file.read())
    env.reset_from_xml_string(model_xml)
    env.sim.reset()

    state_pattern = os.path.join(ep_dir, "state_*.npz")

    # t counts steps across all state files; it only feeds the error message.
    t = 0
    set_seed(0)
    for npz_path in sorted(glob(state_pattern)):
        batch = np.load(npz_path, allow_pickle=True)
        recorded_states = batch["states"]

        # Each file starts from its own first recorded state.
        env.sim.set_state_from_flattened(recorded_states[0])
        env.sim.forward()

        action_infos = batch["action_infos"]
        for idx, info in enumerate(action_infos):
            # Compare before stepping: recorded state idx vs. current sim state.
            expected = recorded_states[idx]
            current = env.sim.get_state().flatten()
            divergence = np.linalg.norm(expected - current)
            assert divergence < 1e-6, f"Divergence: {divergence} at step {t}"

            env.step(info["actions"])
            t += 1
if __name__ == "__main__":
    # Command-line options for the collect-then-playback reproduction.
    parser = argparse.ArgumentParser()
    parser.add_argument("--environment", type=str, default="Lift")
    parser.add_argument(
        "--robots",
        nargs="+",
        type=str,
        default="Panda",
        help="Which robot(s) to use in the env",
    )
    parser.add_argument("--directory", type=str, default="tmp/")
    parser.add_argument("--timesteps", type=int, default=1000)
    # Exposed on the command line so different flush frequencies can be
    # compared without editing the script; the default keeps the old value.
    parser.add_argument(
        "--flush_freq",
        type=int,
        default=200,
        help="Value passed as flush_freq to DataCollectionWrapper",
    )
    args = parser.parse_args()

    # create original environment
    env = suite.make(
        args.environment,
        robots=args.robots,
        ignore_done=True,
        use_camera_obs=False,
        has_renderer=False,
        has_offscreen_renderer=False,
        control_freq=20,
    )
    data_directory = args.directory

    # wrap the environment with data collection wrapper
    env = DataCollectionWrapper(env, data_directory, flush_freq=args.flush_freq)

    # collect some data
    print("Collecting some random data...")
    collect_random_trajectory(env, timesteps=args.timesteps)

    # play back from the episode directory reported by the wrapper
    print("Playing back the data...")
    data_directory = env.ep_directory
    playback_trajectory(env, data_directory)
The text was updated successfully, but these errors were encountered:
jren03
changed the title
"flush_freq in DataCollectionWrapper affecting deterministic episode playback" → "flush_freq in data_collect_wrapper.py affecting deterministic episode playback"
Mar 18, 2024
I've noticed some divergence between the recorded and playback states that is correlated with the flush_freq parameter in DataCollectionWrapper (and have viewed the relevant issues/PRs [1, 2]).
Below is the script adapted from robosuite/demos/demo_collect_and_playback_data.py, in which I collect demonstrations, playback actions, and compare the L2 distance between recorded and actual states.
In my tests, the assertion fails on the first state from the second state_file path, which contains the next batch of states of length flush_freq. Specifically, this assertion fails on state_*.npz files created when flush_freq <= args.timesteps / 2. In other words, in the example below (with the default of 1000 timesteps), the assert fails if I set flush_freq=500, but not when flush_freq=501.
The text was updated successfully, but these errors were encountered: