Skip to content

Commit

Permalink
[Performance] Faster PrioritizedSliceSampler._padded_indices (#2433)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler authored Sep 12, 2024
1 parent fb9cc2c commit 361b763
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from torchrl._extension import EXTENSION_WARNING

from torchrl._utils import _replace_last, implement_for, logger
from torchrl._utils import _replace_last, logger
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _is_int, unravel_index

Expand Down Expand Up @@ -1842,28 +1842,31 @@ def mark_update(
) -> None:
return PrioritizedSampler.mark_update(self, index, storage=storage)

@implement_for("torch", "2.4")
def _padded_indices(self, shapes, arange) -> torch.Tensor:
# this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g.
# tensor([[ 0, 1, 2, 3, 4],
# [-1, -1, 5, 6, 7],
# [-1, 8, 9, 10, 11]])
# where the -1 items on the left are padded values
st, off = torch._nested_compute_contiguous_strides_offsets(shapes.flip(0))
nt = torch._nested_view_from_buffer(
arange.flip(0).contiguous(), shapes.flip(0), st, off
num_groups = shapes.shape[0]
max_group_len = shapes.max()
pad_lengths = max_group_len - shapes

# Get all the start and end indices within arange for each group
group_ends = shapes.cumsum(0)
group_starts = torch.empty_like(group_ends)
group_starts[0] = 0
group_starts[1:] = group_ends[:-1]
pad = torch.empty(
(num_groups, max_group_len), dtype=arange.dtype, device=arange.device
)
pad = nt.to_padded_tensor(-1).flip(-1).flip(0)
return pad
for pad_row, group_start, group_end, pad_len in zip(
pad, group_starts, group_ends, pad_lengths
):
pad_row[:pad_len] = -1
pad_row[pad_len:] = arange[group_start:group_end]

@implement_for("torch", None, "2.4")
def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811
arange = arange.flip(0).split(shapes.flip(0).squeeze().unbind())
return (
torch.nn.utils.rnn.pad_sequence(arange, batch_first=True, padding_value=-1)
.flip(-1)
.flip(0)
)
return pad

def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
preceding_stop_idx = self._cache.get("preceding_stop_idx")
Expand Down

1 comment on commit 361b763

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 361b763 Previous: fb9cc2c Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 2094.0166244139073 iter/sec (stddev: 0.00456845021246647) 4283.956237897829 iter/sec (stddev: 0.000022894003609446883) 2.05

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.