[Feature] Pass replay buffers to SyncDataCollector #2384

Merged · 14 commits · Aug 13, 2024
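The feature under test: SyncDataCollector (and its multiprocess counterpart) can now be handed a ReplayBuffer at construction time; the collector then writes each batch straight into the buffer, and its iterator yields None. A minimal sketch of that usage, distilled from the tests below (import paths and the "CartPole-v1" name are assumptions, not part of this diff):

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy

env = GymEnv("CartPole-v1")  # assumed versioned name
# for a batched env (e.g. SerialEnv(8)) the tests use LazyTensorStorage(256, ndim=2)
rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    replay_buffer=rb,  # new in this PR: the collector fills rb directly
    total_frames=256,
    frames_per_batch=16,
)
for batch in collector:
    assert batch is None  # data lands in rb, not in the yielded value
    sample = rb.sample()  # the buffer is immediately sampleable
collector.shutdown()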
84 changes: 82 additions & 2 deletions test/test_collector.py
@@ -2585,8 +2585,15 @@ def test_unique_traj_sync(self, cat_results):
buffer.extend(d)
assert c._use_buffers
traj_ids = buffer[:].get(("collector", "traj_ids"))
# check that we have as many trajs as expected (no skip)
assert traj_ids.unique().numel() == traj_ids.max() + 1
# Ideally, we would like (sorted_traj.values == sorted_traj.indices).all()
# to hold, but in practice one env can reach the end of the rollout, reset
# (which we don't want to prevent) and increment the global traj count
# while the others have not finished yet. In that case, that traj number
# will never appear in the collected data.
# sorted_traj = traj_ids.unique().sort()
# assert (sorted_traj.values == sorted_traj.indices).all()
# assert traj_ids.unique().numel() == traj_ids.max() + 1

# check that trajs are not overlapping
if stack_results:
sets = [
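To make the relaxed check concrete, here is a hypothetical traj_ids tensor in which one env resets exactly at the end of the rollout and claims id 2, so that id never appears in the batch (illustration only, not part of the test):

import torch

traj_ids = torch.tensor([0, 0, 1, 1, 3, 3])  # id 2 was consumed by the reset
assert traj_ids.unique().numel() == 3  # three trajectories were collected
assert traj_ids.max() + 1 == 4  # so the old equality check would have failed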
@@ -2751,6 +2758,79 @@ def test_async(self, use_buffers):
del collector


class TestCollectorRB:
@pytest.mark.skipif(not _has_gym, reason="requires gym.")
def test_collector_rb_sync(self):
env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
env.set_seed(0)
rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
)
torch.manual_seed(0)

for c in collector:
assert c is None
rb.sample()
rbdata0 = rb[:].clone()
collector.shutdown()
if not env.is_closed:
env.close()
del collector, env

env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
env.set_seed(0)
rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
collector = SyncDataCollector(
env, RandomPolicy(env.action_spec), total_frames=256, frames_per_batch=16
)
torch.manual_seed(0)

for i, c in enumerate(collector):
rb.extend(c)
torch.testing.assert_close(
rbdata0[:, : (i + 1) * 2]["observation"], rb[:]["observation"]
)
assert c is not None
rb.sample()

rbdata1 = rb[:].clone()
collector.shutdown()
if not env.is_closed:
env.close()
del collector, env
assert assert_allclose_td(rbdata0, rbdata1)

@pytest.mark.skipif(not _has_gym, reason="requires gym.")
def test_collector_rb_multisync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()

collector = MultiSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
)
torch.manual_seed(0)
pred_len = 0
for c in collector:
pred_len += 16
assert c is None
assert len(rb) == pred_len
collector.shutdown()
assert len(rb) == 256


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
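One detail in test_collector_rb_multisync deserves a note: the buffer is seeded with a single random step and then emptied before the collector is built. A plausible reading (this rationale is an assumption, not stated in the diff) is that LazyTensorStorage only allocates its tensors on first write, so materializing the storage in the parent process lets it be shared with the worker processes, while empty() resets the cursor without releasing the allocation:

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))  # first write: materializes the storage
rb.empty()  # drop the dummy entry, keep the allocated (shareable) tensors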
12 changes: 4 additions & 8 deletions test/test_rb.py
@@ -2064,13 +2064,16 @@ def exec_multiproc_rb(
init=True,
writer_type=TensorDictRoundRobinWriter,
sampler_type=RandomSampler,
device=None,
):
rb = TensorDictReplayBuffer(
storage=storage_type(21), writer=writer_type(), sampler=sampler_type()
)
if init:
td = TensorDict(
{"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10]
{"a": torch.zeros(10), "next": {"reward": torch.ones(10)}},
[10],
device=device,
)
rb.extend(td)
q0 = mp.Queue(1)
@@ -2098,13 +2101,6 @@ def test_error_list(self):
with pytest.raises(RuntimeError, match="Cannot share a storage of type"):
self.exec_multiproc_rb(storage_type=ListStorage)

def test_error_nonshared(self):
# non shared tensor storage cannot be shared
with pytest.raises(
RuntimeError, match="The storage must be place in shared memory"
):
self.exec_multiproc_rb(storage_type=LazyTensorStorage)

def test_error_maxwriter(self):
# TensorDictMaxValueWriter cannot be shared
with pytest.raises(RuntimeError, match="cannot be shared between processes"):
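The test_rb.py changes mirror the collector feature: exec_multiproc_rb gains a device argument so the seed TensorDict can be built on a given device, and test_error_nonshared is removed, which suggests a LazyTensorStorage no longer needs to be explicitly placed in shared memory before multiprocess use. A hypothetical parametrization the new argument enables (illustration only):

self.exec_multiproc_rb(storage_type=LazyTensorStorage, device="cpu")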