-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
perf(rust, python) Improve rolling min and max for nonulls #9277
Merged
Merged
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,252 +18,164 @@ impl<'a, T: NativeType> RollingAggWindowNoNulls<'a, T> for SortedMinMax<'a, T> { | |
} | ||
} | ||
|
||
/// Finds the minimum of `slice[start..end]` together with its position. | ||
/// | ||
/// Elements are ordered with `compare_fn_nan_min` (defined outside this view). The returned | ||
/// index is relative to `start`, so callers add the range start back to obtain an index into | ||
/// `slice`. Returns `None` when the range is empty. | ||
/// | ||
/// # Safety | ||
/// | ||
/// The range is taken with `get_unchecked`, so the caller must guarantee | ||
/// `start <= end <= slice.len()`. | ||
#[inline] | ||
unsafe fn get_min_and_ix<T>(slice: &[T], start: usize, end: usize) -> Option<(usize, &T)> | ||
where | ||
T: NativeType + IsFloat + PartialOrd, | ||
{ | ||
slice | ||
.get_unchecked(start..end) | ||
.iter() | ||
.enumerate() | ||
// `min_by` returns the first of several equal minima, so walking the range backwards makes | ||
// ties resolve to the largest index. A later index stays inside subsequent windows longer, | ||
// which lets `update` keep the cached minimum while `min_ix >= start`. | ||
.rev() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So we iterate in reverse to ensure we get the latest index? If so, could you add a comment about that in this function. |
||
.min_by(|&a, &b| compare_fn_nan_min(a.1, b.1)) | ||
} | ||
|
||
/// Rolling-minimum state used by the `RollingAggWindowNoNulls` implementation below. | ||
/// | ||
/// Caches the minimum of the most recently evaluated window and where it sits in `slice`. | ||
pub struct MinWindow<'a, T: NativeType + PartialOrd + IsFloat> { | ||
// The full input the windows are taken from. | ||
slice: &'a [T], | ||
// Minimum of the most recently evaluated window. | ||
min: T, | ||
// Index of `min` in `slice` (stored as window start + offset, i.e. not window-relative). | ||
min_ix: usize, | ||
// Start of the most recently evaluated window. | ||
last_start: usize, | ||
// End of the most recently evaluated window (used as the exclusive bound of `start..end`). | ||
last_end: usize, | ||
} | ||
|
||
impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> for MinWindow<'a, T> { | ||
fn new(slice: &'a [T], start: usize, end: usize, _params: DynArgs) -> Self { | ||
let min = *slice[start..end] | ||
let (ix, min) = slice[start..end] | ||
.iter() | ||
.min_by(|a, b| compare_fn_nan_min(*a, *b)) | ||
.unwrap_or(&slice[start]); | ||
.enumerate() | ||
.rev() | ||
.min_by(|&a, &b| compare_fn_nan_min(a.1, b.1)) | ||
.unwrap_or((0, &slice[start])); | ||
Self { | ||
slice, | ||
min, | ||
min: *min, | ||
min_ix: start + ix, | ||
last_start: start, | ||
last_end: end, | ||
} | ||
} | ||
|
||
unsafe fn update(&mut self, start: usize, end: usize) -> T { | ||
// recompute min | ||
if start >= self.last_end | ||
// the window only got smaller | ||
|| end == self.last_end | ||
{ | ||
self.min = *self | ||
.slice | ||
.get_unchecked(start..end) | ||
.iter() | ||
.min_by(|a, b| compare_fn_nan_min(*a, *b)) | ||
.unwrap_or(self.slice.get_unchecked(start)); | ||
self.last_start = start; // Don't care where the last one started | ||
let old_last_end = self.last_end; // But we need this | ||
self.last_end = end; | ||
|
||
self.last_start = start; | ||
self.last_end = end; | ||
let entering_start = std::cmp::max(old_last_end, start); | ||
let entering = get_min_and_ix(self.slice, entering_start, end); | ||
|
||
if entering.is_some_and(|em| compare_fn_nan_min(&self.min, em.1).is_ge()) { | ||
// If the entering min <= the current min return early, since no value in the overlap can be smaller than either. | ||
self.min = *entering.unwrap().1; | ||
self.min_ix = entering_start + entering.unwrap().0; | ||
return self.min; | ||
} else if self.min_ix >= start { | ||
// If the entering min isn't the smallest but the current min is between start and end we can still ignore the overlap | ||
return self.min; | ||
} | ||
|
||
let mut recompute_min = false; | ||
// remove elements that should leave the window | ||
for idx in self.last_start..start { | ||
// safety | ||
// we are in bounds | ||
let leaving_value = self.slice.get_unchecked(idx); | ||
|
||
// if the leaving value is the | ||
// max value, we need to recompute the max. | ||
if matches!( | ||
compare_fn_nan_min(leaving_value, &self.min), | ||
Ordering::Equal | ||
) { | ||
recompute_min = true; | ||
break; | ||
} | ||
} | ||
|
||
let entering_min = self | ||
.slice | ||
.get_unchecked(self.last_end..end) | ||
.iter() | ||
.min_by(|a, b| compare_fn_nan_min(*a, *b)) | ||
.unwrap_or(self.slice.get_unchecked(std::cmp::min( | ||
self.last_start, | ||
self.last_end.saturating_sub(1), | ||
))); | ||
|
||
if recompute_min { | ||
match compare_fn_nan_min(&self.min, entering_min) { | ||
// do nothing | ||
Ordering::Equal => {} | ||
// leaving < entering | ||
Ordering::Less => { | ||
// leaving value could be the smallest, we might need to recompute | ||
|
||
// check the values in between the window we did not yet | ||
// compute to find the max there. We compare that with the `entering_max` | ||
// if any value is equal to equal to `self.max` of previous window we can break | ||
// early | ||
let mut min_in_between = self.slice.get_unchecked(start); | ||
for idx in (start + 1)..self.last_end { | ||
// safety | ||
// we are in bounds | ||
let value = self.slice.get_unchecked(idx); | ||
|
||
if matches!(compare_fn_nan_min(value, min_in_between), Ordering::Less) { | ||
min_in_between = value; | ||
} | ||
|
||
// the min is also in the in between values | ||
if matches!(compare_fn_nan_min(value, &self.min), Ordering::Equal) { | ||
self.last_start = start; | ||
self.last_end = end; | ||
return self.min; | ||
} | ||
} | ||
|
||
if matches!( | ||
compare_fn_nan_min(min_in_between, entering_min), | ||
Ordering::Less | ||
) { | ||
self.min = *min_in_between | ||
} else { | ||
self.min = *entering_min | ||
} | ||
} | ||
// leaving > entering | ||
Ordering::Greater => { | ||
if matches!(compare_fn_nan_min(entering_min, &self.min), Ordering::Less) { | ||
self.min = *entering_min | ||
} | ||
// Otherwise get the min of the overlapping window and the entering min | ||
match (get_min_and_ix(self.slice, start, old_last_end), entering) { | ||
(Some(pm), Some(em)) => { | ||
if compare_fn_nan_min(pm.1, em.1).is_ge() { | ||
self.min = *em.1; | ||
self.min_ix = entering_start + em.0; | ||
} else { | ||
self.min = *pm.1; | ||
self.min_ix = start + pm.0; | ||
} | ||
} | ||
} else if matches!(compare_fn_nan_min(entering_min, &self.min), Ordering::Less) { | ||
self.min = *entering_min | ||
(Some(pm), None) => { | ||
self.min = *pm.1; | ||
self.min_ix = start + pm.0; | ||
} | ||
(None, Some(em)) => { | ||
self.min = *em.1; | ||
self.min_ix = entering_start + em.0; | ||
} | ||
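// We shouldn't reach this, but it means both the overlap and the entering range are empty (an empty window), so the cached min is left unchanged. | ||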
(None, None) => {} | ||
} | ||
|
||
self.last_start = start; | ||
self.last_end = end; | ||
self.min | ||
} | ||
} | ||
|
||
/// Finds the maximum of `slice[start..end]` together with its position. | ||
/// | ||
/// Elements are ordered with `compare_fn_nan_max` (defined outside this view). The returned | ||
/// index is relative to `start`, so callers add the range start back to obtain an index into | ||
/// `slice`. Returns `None` when the range is empty. | ||
/// | ||
/// # Safety | ||
/// | ||
/// The range is taken with `get_unchecked`, so the caller must guarantee | ||
/// `start <= end <= slice.len()`. | ||
#[inline] | ||
unsafe fn get_max_and_ix<T>(slice: &[T], start: usize, end: usize) -> Option<(usize, &T)> | ||
where | ||
T: NativeType + IsFloat + PartialOrd, | ||
{ | ||
slice | ||
.get_unchecked(start..end) | ||
.iter() | ||
.enumerate() | ||
// No `.rev()` is needed here: `max_by` already returns the last of several equal maxima, | ||
// so ties resolve to the largest index when iterating forwards. | ||
.max_by(|&a, &b| compare_fn_nan_max(a.1, b.1)) | ||
} | ||
|
||
/// Rolling-maximum state used by the `RollingAggWindowNoNulls` implementation below. | ||
/// | ||
/// Caches the maximum of the most recently evaluated window and where it sits in `slice`. | ||
pub struct MaxWindow<'a, T: NativeType> { | ||
// The full input the windows are taken from. | ||
slice: &'a [T], | ||
// Maximum of the most recently evaluated window. | ||
max: T, | ||
// Index of `max` in `slice` (stored as window start + offset, i.e. not window-relative). | ||
max_ix: usize, | ||
// Start of the most recently evaluated window. | ||
last_start: usize, | ||
// End of the most recently evaluated window (used as the exclusive bound of `start..end`). | ||
last_end: usize, | ||
} | ||
|
||
impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> for MaxWindow<'a, T> { | ||
fn new(slice: &'a [T], start: usize, end: usize, _params: DynArgs) -> Self { | ||
let max = *slice[start..end] | ||
let (ix, max) = slice[start..end] | ||
.iter() | ||
.max_by(|a, b| compare_fn_nan_max(*a, *b)) | ||
.unwrap_or(&slice[start]); | ||
.enumerate() | ||
.max_by(|&a, &b| compare_fn_nan_max(a.1, b.1)) | ||
.unwrap_or((0, &slice[start])); | ||
Self { | ||
slice, | ||
max, | ||
max: *max, | ||
max_ix: start + ix, | ||
last_start: start, | ||
last_end: end, | ||
} | ||
} | ||
|
||
unsafe fn update(&mut self, start: usize, end: usize) -> T { | ||
// recompute max | ||
if start >= self.last_end | ||
// the window only got smaller | ||
|| end == self.last_end | ||
{ | ||
self.max = *self | ||
.slice | ||
.get_unchecked(start..end) | ||
.iter() | ||
.max_by(|a, b| compare_fn_nan_max(*a, *b)) | ||
.unwrap_or(self.slice.get_unchecked(start)); | ||
self.last_start = start; // Don't care where the last one started | ||
let old_last_end = self.last_end; // But we need this | ||
self.last_end = end; | ||
|
||
self.last_start = start; | ||
self.last_end = end; | ||
let entering_start = std::cmp::max(old_last_end, start); | ||
let entering = get_max_and_ix(self.slice, entering_start, end); | ||
|
||
if entering.is_some_and(|em| compare_fn_nan_max(&self.max, em.1).is_le()) { | ||
// If the entering max >= the current max return early, since no value in the overlap can be larger than either. | ||
self.max = *entering.unwrap().1; | ||
self.max_ix = entering_start + entering.unwrap().0; | ||
return self.max; | ||
} else if self.max_ix >= start { | ||
// If the entering max isn't the largest but the current max is between start and end we can still ignore the overlap | ||
return self.max; | ||
} | ||
|
||
let mut recompute_max = false; | ||
for idx in self.last_start..start { | ||
// safety | ||
// we are in bounds | ||
let leaving_value = self.slice.get_unchecked(idx); | ||
// if the leaving value is the max value, we need to recompute the max. | ||
if matches!( | ||
compare_fn_nan_max(leaving_value, &self.max), | ||
Ordering::Equal | ||
) { | ||
recompute_max = true; | ||
break; | ||
} | ||
} | ||
|
||
let entering_max = self | ||
.slice | ||
.get_unchecked(self.last_end..end) | ||
.iter() | ||
.max_by(|a, b| compare_fn_nan_max(*a, *b)) | ||
.unwrap_or(self.slice.get_unchecked(std::cmp::max( | ||
self.last_start, | ||
self.last_end.saturating_sub(1), | ||
))); | ||
|
||
if recompute_max { | ||
match compare_fn_nan_max(&self.max, entering_max) { | ||
// do nothing | ||
Ordering::Equal => {} | ||
// leaving < entering | ||
Ordering::Less => { | ||
if matches!( | ||
compare_fn_nan_max(entering_max, &self.max), | ||
Ordering::Greater | ||
) { | ||
self.max = *entering_max | ||
} | ||
} | ||
// leaving > entering | ||
Ordering::Greater => { | ||
// leaving value could be the largest, we might need to recompute | ||
|
||
// check the values in between the window we did not yet | ||
// compute to find the max there. We compare that with the `entering_max` | ||
// if any value is equal to equal to `self.max` of previous window we can break | ||
// early | ||
let mut max_in_between = self.slice.get_unchecked(start); | ||
for idx in (start + 1)..self.last_end { | ||
// safety | ||
// we are in bounds | ||
let value = self.slice.get_unchecked(idx); | ||
|
||
if matches!(compare_fn_nan_max(value, max_in_between), Ordering::Greater) { | ||
max_in_between = value; | ||
} | ||
|
||
// the max is also in the in between values | ||
if matches!(compare_fn_nan_max(value, &self.max), Ordering::Equal) { | ||
self.last_start = start; | ||
self.last_end = end; | ||
return self.max; | ||
} | ||
} | ||
|
||
if matches!( | ||
compare_fn_nan_max(max_in_between, entering_max), | ||
Ordering::Greater | ||
) { | ||
self.max = *max_in_between | ||
} else { | ||
self.max = *entering_max | ||
} | ||
// Otherwise get the max of the overlapping window and the entering max | ||
match (get_max_and_ix(self.slice, start, old_last_end), entering) { | ||
(Some(pm), Some(em)) => { | ||
if compare_fn_nan_max(pm.1, em.1).is_le() { | ||
self.max = *em.1; | ||
self.max_ix = entering_start + em.0; | ||
} else { | ||
self.max = *pm.1; | ||
self.max_ix = start + pm.0; | ||
} | ||
} | ||
} else if matches!( | ||
compare_fn_nan_max(entering_max, &self.max), | ||
Ordering::Greater | ||
) { | ||
self.max = *entering_max | ||
(Some(pm), None) => { | ||
self.max = *pm.1; | ||
self.max_ix = start + pm.0; | ||
} | ||
(None, Some(em)) => { | ||
self.max = *em.1; | ||
self.max_ix = entering_start + em.0; | ||
} | ||
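// We shouldn't reach this, but it means both the overlap and the entering range are empty (an empty window), so the cached max is left unchanged. | ||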
(None, None) => {} | ||
} | ||
self.last_start = start; | ||
self.last_end = end; | ||
|
||
self.max | ||
} | ||
} | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can you rename all `ix` to `idx` to make it more consistent with how we name indexes in the code base?