[Performance] Faster PrioritizedSliceSampler._padded_indices #2433
Conversation
@implement_for("torch", None, "2.4")
On my machine, both of the previous implementations had roughly the same performance
    )
    pad = nt.to_padded_tensor(-1).flip(-1).flip(0)
    return pad
    for pad_row, group_start, group_end, pad_len in zip(
It would be nice to be able to get rid of this for loop, but I don't see a good way to do it. I think indexing different-sized ranges under each row of a tensor is only possible if you list out all the indices, like for torch.index_select. If torch.nn.utils.rnn.pad_sequence supported left-padding, we could just use that, but it doesn't.
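As a point of comparison, left padding can be emulated on top of pad_sequence by flipping each sequence, right-padding, and flipping the result back; left_pad below is a hypothetical helper, not part of the PR, and the extra flips are exactly the overhead this PR is trying to avoid.

import torch
from torch.nn.utils.rnn import pad_sequence

def left_pad(sequences, padding_value=-1):
    # Flip each sequence, right-pad, then flip the padded rows back
    # so the padding ends up on the left.
    flipped = [seq.flip(0) for seq in sequences]
    padded = pad_sequence(flipped, batch_first=True, padding_value=padding_value)
    return padded.flip(-1)

groups = [torch.arange(0, 5), torch.arange(5, 8), torch.arange(8, 12)]
left_pad(groups)
# tensor([[ 0,  1,  2,  3,  4],
#         [-1, -1,  5,  6,  7],
#         [-1,  8,  9, 10, 11]])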
FWIW, I did try getting rid of the for loop by listing out all the indices and using index_copy_, like so:
shapes = shapes.flatten()
num_groups = shapes.shape[0]
max_group_len = shapes.max()
# number of -1 padding slots needed to the left of each group
pad_lengths = max_group_len - shapes
# cumulative padding inserted up to and including each group
pad_before_groups = pad_lengths.cumsum(0)
pad_before_indices = torch.repeat_interleave(pad_before_groups, shapes)
# destination index of every element of `arange` in the flat padded buffer
indices = torch.arange(arange.shape[0], dtype=arange.dtype, device=arange.device) + pad_before_indices
p = torch.full(
    (num_groups * max_group_len,),
    -1,
    dtype=arange.dtype,
    device=arange.device,
)
p.index_copy_(0, indices, arange)
pad = p.reshape((num_groups, max_group_len))
It worked, but it gave very similar performance to the old implementations.
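For concreteness, an assumed set of inputs (not taken from the PR or the benchmark script) on which the snippet above reproduces the padded layout from the docstring:

import torch

shapes = torch.tensor([5, 3, 4])   # assumed group sizes
arange = torch.arange(12)          # 0..11, one entry per element across all groups
# Running the index_copy_ snippet above on these inputs gives
# tensor([[ 0,  1,  2,  3,  4],
#         [-1, -1,  5,  6,  7],
#         [-1,  8,  9, 10, 11]])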
IIRC, torch.nn.utils.rnn.pad_sequence left padding was mentioned during a review not long ago. cc @mikaylagawarecki
def _padded_indices(self, shapes, arange) -> torch.Tensor:
    # this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g.
    # tensor([[ 0,  1,  2,  3,  4],
    #         [-1, -1,  5,  6,  7],
    #         [-1,  8,  9, 10, 11]])
    # where the -1 items on the left are padded values
    st, off = torch._nested_compute_contiguous_strides_offsets(shapes.flip(0))
    nt = torch._nested_view_from_buffer(
        arange.flip(0).contiguous(), shapes.flip(0), st, off
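For orientation, here is roughly what that trick amounts to in terms of the public nested-tensor API (an assumed illustration with made-up group sizes, not code from the PR): the groups are reversed and each one flipped, right-padded with -1, and the padded result flipped back on both axes.

import torch

arange = torch.arange(12)                       # assumed example input
groups = arange.split([5, 3, 4])                # assumed group sizes
nt = torch.nested.nested_tensor([g.flip(0) for g in reversed(groups)])
pad = nt.to_padded_tensor(-1).flip(-1).flip(0)  # right-pad the reversed, flipped groups, then flip back
# pad == tensor([[ 0,  1,  2,  3,  4],
#                [-1, -1,  5,  6,  7],
#                [-1,  8,  9, 10, 11]])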
It seems that the main thing that the new implementation improves upon is that it doesn't need to do these torch.flip operations, and the time it saves on those seems to outweigh the overhead of the for loop.
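For reference, a rough sketch of what such a loop-based left padding could look like, inferred from the visible loop header (the function name and loop body are assumptions, not the actual PR code):

import torch

def padded_indices_loop(shapes: torch.Tensor, arange: torch.Tensor) -> torch.Tensor:
    # Allocate the output filled with -1, then copy each group's slice of
    # `arange` into the right-hand side of its row.
    shapes = shapes.flatten()
    num_groups = shapes.shape[0]
    max_group_len = int(shapes.max())
    pad = torch.full(
        (num_groups, max_group_len), -1, dtype=arange.dtype, device=arange.device
    )
    group_ends = shapes.cumsum(0)
    group_starts = group_ends - shapes
    pad_lens = max_group_len - shapes
    for pad_row, group_start, group_end, pad_len in zip(
        pad, group_starts.tolist(), group_ends.tolist(), pad_lens.tolist()
    ):
        pad_row[pad_len:] = arange[group_start:group_end]
    return pad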
Something is interesting about the cProfile results I got. Before this PR, […] I would guess that the reason is that […]

One other thing, I tried to measure the cuda performance, but changing the device in my script #2431 (comment) to […]
LGTM thanks for taking care of this!
Description
Speeds up PrioritizedSliceSampler._padded_indices by about 2x. Running the performance script given in #2431 (comment), my machine gives the following:

This is a speedup of (22.235 / 7.804) = 2.8x. Although sometimes, the runtime per sample call reaches up to 11.8 ms, a speedup of ~1.8x.

Motivation and Context
close #2431
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!