diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 869ea5cdae3..7bc37c3e591 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -982,22 +982,20 @@ def _find_start_stop_traj( # faster end = trajectory[:-1] != trajectory[1:] - end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0) + if not at_capacity: + end = torch.cat([end, torch.ones_like(end[:1])], 0) + else: + end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0) length = trajectory.shape[0] else: - # TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True - # We presume that not done at the end means that the traj spans across end and beginning of storage length = end.shape[0] + if not at_capacity: + end = end.clone() + end[length - 1] = True + ndim = end.ndim - if not at_capacity: - end = torch.index_fill( - end, - index=torch.tensor(-1, device=end.device, dtype=torch.long), - dim=0, - value=1, - ) - else: + if at_capacity: # we must have at least one end by traj to individuate trajectories # so if no end can be found we set it manually if cursor is not None: @@ -1019,7 +1017,6 @@ def _find_start_stop_traj( mask = ~end.any(0, True) mask = torch.cat([torch.zeros_like(end[:-1]), mask]) end = torch.masked_fill(mask, end, 1) - ndim = end.ndim if ndim == 0: raise RuntimeError( "Expected the end-of-trajectory signal to be at least 1-dimensional." 
@@ -1109,57 +1106,63 @@ def _tensor_slices_from_startend(self, seq_length, start, storage_length): result[:, 0] = result[:, 0] % storage_length return result + @torch.no_grad() def _get_stop_and_length(self, storage, fallback=True): if self.cache_values and "stop-and-length" in self._cache: return self._cache.get("stop-and-length") + current_storage = storage[:] if self._fetch_traj: # We first try with the traj_key + + if isinstance(storage, TensorStorage): + key = self._used_traj_key + else: + key = self.traj_key try: - if isinstance(storage, TensorStorage): - trajectory = storage[:][self._used_traj_key] - else: - try: - trajectory = storage[:][self.traj_key] - except Exception: - raise RuntimeError( - "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." - ) - vals = self._find_start_stop_traj( - trajectory=trajectory, - at_capacity=storage._is_full, - cursor=getattr(storage, "_last_cursor", None), + trajectory = current_storage.get(key, default=None) + except Exception: + # eg, ListStorage + raise RuntimeError( + "Could not get a tensordict out of the storage, " + "which is required for SliceSampler to compute the trajectories." ) - if self.cache_values: - self._cache["stop-and-length"] = vals - return vals - except KeyError: + if trajectory is None: if fallback: self._fetch_traj = False return self._get_stop_and_length(storage, fallback=False) - raise - + raise KeyError(f"Couldn't find key={key} in storage.") + vals = self._find_start_stop_traj( + trajectory=trajectory, + at_capacity=storage._is_full, + cursor=getattr(storage, "_last_cursor", None), + ) + if self.cache_values: + self._cache["stop-and-length"] = vals + return vals else: try: - try: - done = storage[:][self.end_key] - except Exception: - raise RuntimeError( - "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." 
- ) - vals = self._find_start_stop_traj( - end=done.squeeze()[: len(storage)], - at_capacity=storage._is_full, - cursor=getattr(storage, "_last_cursor", None), + done = current_storage.get(self.end_key, None) + except Exception: + # eg, ListStorage + raise RuntimeError( + "Could not get a tensordict out of the storage, " + "which is required for SliceSampler to compute the trajectories." ) - if self.cache_values: - self._cache["stop-and-length"] = vals - return vals - except KeyError: + if done is None: if fallback: self._fetch_traj = True return self._get_stop_and_length(storage, fallback=False) - raise + raise KeyError(f"Couldn't find key={self.end_key} in storage.") + + vals = self._find_start_stop_traj( + end=done.squeeze(), + at_capacity=storage._is_full, + cursor=getattr(storage, "_last_cursor", None), + ) + if self.cache_values: + self._cache["stop-and-length"] = vals + return vals def _adjusted_batch_size(self, batch_size): if self.num_slices is not None: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index b914c52b338..3c949adaf50 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -901,9 +901,7 @@ def max_size_along_dim0(data_shape): if is_tensor_collection(data): out = data.to(self.device) - out = out.expand(max_size_along_dim0(data.shape)) - out = out.clone() - out = out.zero_() + out = torch.empty_like(out.expand(max_size_along_dim0(data.shape))) else: # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype out = tree_map(