flush_freq in data_collect_wrapper.py affecting deterministic episode playback #462

Open
jren03 opened this issue Mar 18, 2024 · 0 comments
jren03 commented Mar 18, 2024

I've noticed divergence between the recorded and played-back states that correlates with the flush_freq parameter in DataCollectionWrapper (I have also looked at the relevant issues/PRs [1, 2]).

Below is the script, adapted from robosuite/demos/demo_collect_and_playback_data.py, in which I collect demonstrations, play back the recorded actions, and compare the L2 distance between the recorded and actual states.

In my tests, the assertion fails on the first state of the second state file, which holds the next batch of flush_freq states. Specifically, it fails on state_*.npz files created when 2 * flush_freq <= args.timesteps. In other words, with the default --timesteps of 1000 used in the example below, the assert fails if I set flush_freq=500, but not when flush_freq=501.
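Before the full script, here is a rough sketch of the arithmetic behind the 500 vs. 501 threshold. It assumes the wrapper writes one state_*.npz file each time the step count reaches a multiple of flush_freq. That is an assumption about DataCollectionWrapper, not something taken from its source, and `flushed_batches` is a hypothetical helper rather than part of robosuite.

```python
def flushed_batches(timesteps: int, flush_freq: int) -> int:
    """Count full batches written, ASSUMING one flush every `flush_freq` steps.

    Hypothetical helper for illustration; not a robosuite function.
    """
    if flush_freq <= 0:
        raise ValueError("flush_freq must be positive")
    return timesteps // flush_freq


if __name__ == "__main__":
    # Compare the value used in the script with the two threshold values.
    for flush_freq in (200, 500, 501):
        print(flush_freq, flushed_batches(1000, flush_freq))
```

If that assumption holds, flush_freq=501 leaves only one full batch on disk after 1000 steps, so there may be no second state file for the assertion to check at all, which would be consistent with the threshold I see.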

import argparse
import os
import random
from glob import glob

import numpy as np

import robosuite as suite
from robosuite.wrappers import DataCollectionWrapper


def set_seed(seed: int):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)


def collect_random_trajectory(env, timesteps=1000):
    """Run a random policy to collect trajectories.
    """
    set_seed(0)
    env.reset()
    dof = env.action_dim

    for t in range(timesteps):
        action = np.random.randn(dof)
        env.step(action)


def playback_trajectory(env, ep_dir):
    """Playback data from an episode.
    """

    # first reload the model from the xml
    xml_path = os.path.join(ep_dir, "model.xml")
    with open(xml_path, "r") as f:
        xml = env.edit_model_xml(f.read())
        env.reset_from_xml_string(xml)
    env.sim.reset()

    state_paths = os.path.join(ep_dir, "state_*.npz")

    # read states back, replay the recorded actions, and compare against the recorded states
    t = 0
    set_seed(0)
    for state_file in sorted(glob(state_paths)):
        dic = np.load(state_file, allow_pickle=True)
        states = dic["states"]
        init_state = states[0]
        env.sim.set_state_from_flattened(init_state)
        env.sim.forward()

        actions = dic["action_infos"]
        for idx, act_info in enumerate(actions):
            recorded_state = states[idx]
            actual_state = env.sim.get_state().flatten()
            divergence = np.linalg.norm(recorded_state - actual_state)

            assert divergence < 1e-6, f"Divergence: {divergence} at step {t}"

            act = act_info["actions"]
            env.step(act)
            t += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--environment", type=str, default="Lift")
    parser.add_argument(
        "--robots",
        nargs="+",
        type=str,
        default="Panda",
        help="Which robot(s) to use in the env",
    )
    parser.add_argument("--directory", type=str, default="tmp/")
    parser.add_argument("--timesteps", type=int, default=1000)
    args = parser.parse_args()

    # create original environment
    env = suite.make(
        args.environment,
        robots=args.robots,
        ignore_done=True,
        use_camera_obs=False,
        has_renderer=False,
        has_offscreen_renderer=False,
        control_freq=20,
    )
    data_directory = args.directory

    # wrap the environment with data collection wrapper
    env = DataCollectionWrapper(env, data_directory, flush_freq=200)

    # collect some data
    print("Collecting some random data...")
    collect_random_trajectory(env, timesteps=args.timesteps)

    print("Playing back the data...")
    data_directory = env.ep_directory
    playback_trajectory(env, data_directory)
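
For debugging, it may be more convenient to report where the recorded and replayed states first disagree instead of stopping on an assert. The following is a minimal sketch of that comparison using only NumPy. `first_divergence` is a hypothetical helper, and it assumes both inputs are equal-length sequences of flattened state vectors (for example, the `states` array from a state_*.npz file and a list of states gathered during playback).

```python
import numpy as np


def first_divergence(recorded, replayed, tol=1e-6):
    """Return (index, L2 distance) for the first pair at least `tol` apart.

    Returns None if every recorded/replayed pair is closer than `tol`.
    Hypothetical helper for illustration; not a robosuite function.
    """
    for idx, (rec, rep) in enumerate(zip(recorded, replayed)):
        dist = float(np.linalg.norm(np.asarray(rec) - np.asarray(rep)))
        if dist >= tol:
            return idx, dist
    return None


if __name__ == "__main__":
    # Synthetic data: three 2-D "states" where only the last one differs.
    recorded = np.zeros((3, 2))
    replayed = recorded.copy()
    replayed[2, 0] = 1e-3
    print(first_divergence(recorded, replayed))
```

Collecting the replayed states into a list inside the playback loop and passing them to this helper once per state file would show whether the mismatch is confined to index 0 of the second file or continues afterwards.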
@jren03 jren03 changed the title flush_freq in DataCollectionWrapper affecting deterministic episode playback flush_freq in data_collect_wrapper.py affecting deterministic episode playback Mar 18, 2024