Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(rust, python) Improve rolling min and max for nonulls #9277

Merged
merged 5 commits into from
Jun 9, 2023
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 103 additions & 191 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,252 +18,164 @@ impl<'a, T: NativeType> RollingAggWindowNoNulls<'a, T> for SortedMinMax<'a, T> {
}
}

/// Finds the minimum of `slice[start..end]` together with its position.
///
/// The returned index is relative to `start` (it comes from `enumerate` over the
/// sub-slice). Returns `None` when the range is empty. Values are ordered with
/// `compare_fn_nan_min`.
///
/// # Safety
/// `start..end` must be a valid range within `slice`: the sub-slice is taken with
/// `get_unchecked`, so no bounds check is performed.
#[inline]
unsafe fn get_min_and_ix<T>(slice: &[T], start: usize, end: usize) -> Option<(usize, &T)>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you rename all ix to idx to make it more consistent with how we name indexes in the code base?

where
T: NativeType + IsFloat + PartialOrd,
{
slice
.get_unchecked(start..end)
.iter()
.enumerate()
// `min_by` keeps the first of several equal minima, so iterate in reverse
// to make ties resolve to the latest (highest) index.
.rev()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we iterate in reverse to ensure we get the latest index? If so, could you add a comment about that in this function.

.min_by(|&a, &b| compare_fn_nan_min(a.1, b.1))
}

/// Rolling-minimum state: caches the minimum of the most recently evaluated
/// window of `slice`, where that minimum sits, and the window's bounds.
pub struct MinWindow<'a, T: NativeType + PartialOrd + IsFloat> {
/// The full input the windows are taken from.
slice: &'a [T],
/// Minimum of the most recently evaluated window.
min: T,
/// Position of `min` as an index into `slice` (range start plus offset).
min_ix: usize,
/// Start bound of the most recently evaluated window.
last_start: usize,
/// End bound of the most recently evaluated window.
last_end: usize,
}

impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> for MinWindow<'a, T> {
fn new(slice: &'a [T], start: usize, end: usize, _params: DynArgs) -> Self {
let min = *slice[start..end]
let (ix, min) = slice[start..end]
.iter()
.min_by(|a, b| compare_fn_nan_min(*a, *b))
.unwrap_or(&slice[start]);
.enumerate()
.rev()
.min_by(|&a, &b| compare_fn_nan_min(a.1, b.1))
.unwrap_or((0, &slice[start]));
Self {
slice,
min,
min: *min,
min_ix: start + ix,
last_start: start,
last_end: end,
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
// recompute min
if start >= self.last_end
// the window only got smaller
|| end == self.last_end
{
self.min = *self
.slice
.get_unchecked(start..end)
.iter()
.min_by(|a, b| compare_fn_nan_min(*a, *b))
.unwrap_or(self.slice.get_unchecked(start));
self.last_start = start; // Don't care where the last one started
let old_last_end = self.last_end; // But we need this
self.last_end = end;

self.last_start = start;
self.last_end = end;
let entering_start = std::cmp::max(old_last_end, start);
let entering = get_min_and_ix(self.slice, entering_start, end);

if entering.is_some_and(|em| compare_fn_nan_min(&self.min, em.1).is_ge()) {
// If the entering min <= the current min return early, since no value in the overlap can be smaller than either.
self.min = *entering.unwrap().1;
self.min_ix = entering_start + entering.unwrap().0;
return self.min;
} else if self.min_ix >= start {
// If the entering min isn't the smallest but the current min is between start and end we can still ignore the overlap
return self.min;
}

let mut recompute_min = false;
// remove elements that should leave the window
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);

// if the leaving value is the
// min value, we need to recompute the min.
if matches!(
compare_fn_nan_min(leaving_value, &self.min),
Ordering::Equal
) {
recompute_min = true;
break;
}
}

let entering_min = self
.slice
.get_unchecked(self.last_end..end)
.iter()
.min_by(|a, b| compare_fn_nan_min(*a, *b))
.unwrap_or(self.slice.get_unchecked(std::cmp::min(
self.last_start,
self.last_end.saturating_sub(1),
)));

if recompute_min {
match compare_fn_nan_min(&self.min, entering_min) {
// do nothing
Ordering::Equal => {}
// leaving < entering
Ordering::Less => {
// leaving value could be the smallest, we might need to recompute

// check the values in between the window we did not yet
// compute to find the min there. We compare that with the `entering_min`
// if any value is equal to `self.min` of previous window we can break
// early
let mut min_in_between = self.slice.get_unchecked(start);
for idx in (start + 1)..self.last_end {
// safety
// we are in bounds
let value = self.slice.get_unchecked(idx);

if matches!(compare_fn_nan_min(value, min_in_between), Ordering::Less) {
min_in_between = value;
}

// the min is also in the in between values
if matches!(compare_fn_nan_min(value, &self.min), Ordering::Equal) {
self.last_start = start;
self.last_end = end;
return self.min;
}
}

if matches!(
compare_fn_nan_min(min_in_between, entering_min),
Ordering::Less
) {
self.min = *min_in_between
} else {
self.min = *entering_min
}
}
// leaving > entering
Ordering::Greater => {
if matches!(compare_fn_nan_min(entering_min, &self.min), Ordering::Less) {
self.min = *entering_min
}
// Otherwise get the min of the overlapping window and the entering min
match (get_min_and_ix(self.slice, start, old_last_end), entering) {
(Some(pm), Some(em)) => {
if compare_fn_nan_min(pm.1, em.1).is_ge() {
self.min = *em.1;
self.min_ix = entering_start + em.0;
} else {
self.min = *pm.1;
self.min_ix = start + pm.0;
}
}
} else if matches!(compare_fn_nan_min(entering_min, &self.min), Ordering::Less) {
self.min = *entering_min
(Some(pm), None) => {
self.min = *pm.1;
self.min_ix = start + pm.0;
}
(None, Some(em)) => {
self.min = *em.1;
self.min_ix = entering_start + em.0;
}
// We shouldn't reach this, but it means both the overlap and the entering range are empty, so the current min is kept
(None, None) => {}
}

self.last_start = start;
self.last_end = end;
self.min
}
}

/// Finds the maximum of `slice[start..end]` together with its position.
///
/// The returned index is relative to `start` (it comes from `enumerate` over the
/// sub-slice). Returns `None` when the range is empty. Values are ordered with
/// `compare_fn_nan_max`.
///
/// # Safety
/// `start..end` must be a valid range within `slice`: the sub-slice is taken with
/// `get_unchecked`, so no bounds check is performed.
#[inline]
unsafe fn get_max_and_ix<T>(slice: &[T], start: usize, end: usize) -> Option<(usize, &T)>
where
T: NativeType + IsFloat + PartialOrd,
{
slice
.get_unchecked(start..end)
.iter()
.enumerate()
// `max_by` already keeps the last of several equal maxima, so ties resolve
// to the latest (highest) index without reversing the iterator.
.max_by(|&a, &b| compare_fn_nan_max(a.1, b.1))
}

/// Rolling-maximum state: caches the maximum of the most recently evaluated
/// window of `slice`, where that maximum sits, and the window's bounds.
pub struct MaxWindow<'a, T: NativeType> {
/// The full input the windows are taken from.
slice: &'a [T],
/// Maximum of the most recently evaluated window.
max: T,
/// Position of `max` as an index into `slice` (range start plus offset).
max_ix: usize,
/// Start bound of the most recently evaluated window.
last_start: usize,
/// End bound of the most recently evaluated window.
last_end: usize,
}

impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> for MaxWindow<'a, T> {
fn new(slice: &'a [T], start: usize, end: usize, _params: DynArgs) -> Self {
let max = *slice[start..end]
let (ix, max) = slice[start..end]
.iter()
.max_by(|a, b| compare_fn_nan_max(*a, *b))
.unwrap_or(&slice[start]);
.enumerate()
.max_by(|&a, &b| compare_fn_nan_max(a.1, b.1))
.unwrap_or((0, &slice[start]));
Self {
slice,
max,
max: *max,
max_ix: start + ix,
last_start: start,
last_end: end,
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
// recompute max
if start >= self.last_end
// the window only got smaller
|| end == self.last_end
{
self.max = *self
.slice
.get_unchecked(start..end)
.iter()
.max_by(|a, b| compare_fn_nan_max(*a, *b))
.unwrap_or(self.slice.get_unchecked(start));
self.last_start = start; // Don't care where the last one started
let old_last_end = self.last_end; // But we need this
self.last_end = end;

self.last_start = start;
self.last_end = end;
let entering_start = std::cmp::max(old_last_end, start);
let entering = get_max_and_ix(self.slice, entering_start, end);

if entering.is_some_and(|em| compare_fn_nan_max(&self.max, em.1).is_le()) {
// If the entering max >= the current max return early, since no value in the overlap can be larger than either.
self.max = *entering.unwrap().1;
self.max_ix = entering_start + entering.unwrap().0;
return self.max;
} else if self.max_ix >= start {
// If the entering max isn't the largest but the current max is between start and end we can still ignore the overlap
return self.max;
}

let mut recompute_max = false;
for idx in self.last_start..start {
// safety
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);
// if the leaving value is the max value, we need to recompute the max.
if matches!(
compare_fn_nan_max(leaving_value, &self.max),
Ordering::Equal
) {
recompute_max = true;
break;
}
}

let entering_max = self
.slice
.get_unchecked(self.last_end..end)
.iter()
.max_by(|a, b| compare_fn_nan_max(*a, *b))
.unwrap_or(self.slice.get_unchecked(std::cmp::max(
self.last_start,
self.last_end.saturating_sub(1),
)));

if recompute_max {
match compare_fn_nan_max(&self.max, entering_max) {
// do nothing
Ordering::Equal => {}
// leaving < entering
Ordering::Less => {
if matches!(
compare_fn_nan_max(entering_max, &self.max),
Ordering::Greater
) {
self.max = *entering_max
}
}
// leaving > entering
Ordering::Greater => {
// leaving value could be the largest, we might need to recompute

// check the values in between the window we did not yet
// compute to find the max there. We compare that with the `entering_max`
// if any value is equal to `self.max` of previous window we can break
// early
let mut max_in_between = self.slice.get_unchecked(start);
for idx in (start + 1)..self.last_end {
// safety
// we are in bounds
let value = self.slice.get_unchecked(idx);

if matches!(compare_fn_nan_max(value, max_in_between), Ordering::Greater) {
max_in_between = value;
}

// the max is also in the in between values
if matches!(compare_fn_nan_max(value, &self.max), Ordering::Equal) {
self.last_start = start;
self.last_end = end;
return self.max;
}
}

if matches!(
compare_fn_nan_max(max_in_between, entering_max),
Ordering::Greater
) {
self.max = *max_in_between
} else {
self.max = *entering_max
}
// Otherwise get the max of the overlapping window and the entering max
match (get_max_and_ix(self.slice, start, old_last_end), entering) {
(Some(pm), Some(em)) => {
if compare_fn_nan_max(pm.1, em.1).is_le() {
self.max = *em.1;
self.max_ix = entering_start + em.0;
} else {
self.max = *pm.1;
self.max_ix = start + pm.0;
}
}
} else if matches!(
compare_fn_nan_max(entering_max, &self.max),
Ordering::Greater
) {
self.max = *entering_max
(Some(pm), None) => {
self.max = *pm.1;
self.max_ix = start + pm.0;
}
(None, Some(em)) => {
self.max = *em.1;
self.max_ix = entering_start + em.0;
}
// We shouldn't reach this, but it means both the overlap and the entering range are empty, so the current max is kept
(None, None) => {}
}
self.last_start = start;
self.last_end = end;

self.max
}
}
Expand Down