perf(rust, python) Improve rolling min and max for nonulls #9277

Merged: 5 commits into pola-rs:main on Jun 9, 2023

Conversation

@magarick (Contributor) commented Jun 7, 2023

This PR speeds up rolling min and max for the case when there are no nulls. It also makes the code clearer and more concise (in my opinion anyway). I've verified that it produces the same results and is faster on some sample data.

The data was generated as

import numpy as np
import polars as pl
from datetime import datetime

np.random.seed(100)
ds = pl.date_range(datetime(2013, 1, 1), datetime(2023, 1, 1), '1h', eager = True)
vs = pl.Series(np.random.standard_normal(200 * len(ds)))
xs = pl.Series(np.random.randint(0, 50, len(ds)).repeat(200))

dat = pl.DataFrame(dict(d = ds.to_numpy().repeat(200), v = vs, x = xs))
dat = dat.set_sorted('d')

The test ran plain rolling windows of lengths 5 and 7500 on each of the two series, as well as a groupby_dynamic that computed the max and min of each series over every window, with the every and period parameters listed below.

Timings for 100 repetitions (in seconds). New is first, current is second:

(x, 5, min): 8.73, 13.35
(x, 5, max): 6.55, 12.77
(x, 7500, min): 8.90, 15.07
(x, 7500, max): 7.35, 15.73
(v, 5, min): 18.66, 27.17
(v, 5, max): 23.29, 27.42
(v, 7500, min): 18.39, 22.17
(v, 7500, max): 18.06, 22.13
(every = 1h, period = 2h): 15.27, 19.15
(every = 1d, period = 5d): 18.85, 25.66
(every = 1h, period = 5d): 128.04, 192.90

@magarick magarick requested a review from ritchie46 as a code owner June 7, 2023 09:09
@github-actions github-actions bot added performance Performance issues or improvements python Related to Python Polars rust Related to Rust Polars labels Jun 7, 2023
@ritchie46 (Member) commented Jun 7, 2023

I haven't reviewed yet, but the current functions are carefully constructed to prevent very bad behavior on sorted data. Random data is not the only thing we should look at.

We should also look at sorted data in both directions and at semi-sorted data.

Could you show the code you used for benchmarking? It should not be done via groupby_dynamic but directly with rolling_min/max, as that is what we are measuring.

@magarick (Contributor, Author) commented Jun 7, 2023

I didn't touch the special-case branches that handle sorted data. I think I can handle them all at once and also improve performance on mostly-sorted data, but I figured one thing at a time.

Here are some pathological cases:
New:

>>> s1 = pl.Series(list(range(1000)) * 1000)
>>> s2 = s1.reverse()
>>> timeit('s1.rolling_min(200)', globals = globals(), number = 100)
8.955661284038797
>>> timeit('s1.rolling_min(2000)', globals = globals(), number = 100)
0.3911876070778817
>>> timeit('s1.rolling_max(200)', globals = globals(), number = 100)
0.2773919829633087
>>> timeit('s1.rolling_max(2000)', globals = globals(), number = 100)
0.27548131812363863
>>> timeit('s2.rolling_min(200)', globals = globals(), number = 100)
0.39969513611868024
>>> timeit('s2.rolling_min(2000)', globals = globals(), number = 100)
0.3881555439438671
>>> timeit('s2.rolling_max(200)', globals = globals(), number = 100)
10.08373599499464
>>> timeit('s2.rolling_max(2000)', globals = globals(), number = 100)
0.2590978641528636

>>> s3 = pl.Series(list(range(int(1e6))))
>>> s3 = s3.set_at_idx(s3.sample(1000, seed = 1), 0)
>>> timeit('s3.rolling_min(200)', globals = globals(), number = 100)
10.237212847918272
>>> timeit('s3.rolling_min(2000)', globals = globals(), number = 100)
13.379444011021405
>>> timeit('s3.rolling_min(20000)', globals = globals(), number = 100)
0.4033352010883391
>>> timeit('s3.rolling_max(200)', globals = globals(), number = 100)
0.2527862370479852
>>> timeit('s3.rolling_max(2000)', globals = globals(), number = 100)
0.27149013499729335
>>> timeit('s3.rolling_max(20000)', globals = globals(), number = 100)
0.26448553893715143

Current:

>>> s1 = pl.Series(list(range(1000)) * 1000)
>>> s2 = s1.reverse()
>>> timeit('s1.rolling_min(200)', globals = globals(), number = 100)
16.829105872893706
>>> timeit('s1.rolling_min(2000)', globals = globals(), number = 100)
0.5299010709859431
>>> timeit('s1.rolling_max(200)', globals = globals(), number = 100)
0.603124001994729
>>> timeit('s1.rolling_max(2000)', globals = globals(), number = 100)
0.6017518460284919
>>> timeit('s2.rolling_min(200)', globals = globals(), number = 100)
0.5548417470417917
>>> timeit('s2.rolling_min(2000)', globals = globals(), number = 100)
0.5275844470597804
>>> timeit('s2.rolling_max(200)', globals = globals(), number = 100)
11.248067330103368
>>> timeit('s2.rolling_max(2000)', globals = globals(), number = 100)
0.6214083470404148

>>> s3 = pl.Series(list(range(int(1e6))))
>>> s3 = s3.set_at_idx(s3.sample(1000, seed = 1), 0)
>>> timeit('s3.rolling_min(200)', globals = globals(), number = 100)
17.742616538889706
>>> timeit('s3.rolling_min(2000)', globals = globals(), number = 100)
30.45354696107097
>>> timeit('s3.rolling_min(20000)', globals = globals(), number = 100)
1.1630480010062456
>>> timeit('s3.rolling_max(200)', globals = globals(), number = 100)
0.7935329640749842
>>> timeit('s3.rolling_max(2000)', globals = globals(), number = 100)
0.8272605510428548
>>> timeit('s3.rolling_max(20000)', globals = globals(), number = 100)
0.8058401739690453

Here's the testing script. It assumes you've created files with the correct results using the current version.

import polars as pl
from datetime import datetime
import numpy as np
from timeit import timeit
import sys   
from os import path

cmp_files = sys.argv[1]

np.random.seed(100)
ds = pl.date_range(datetime(2013, 1, 1), datetime(2023, 1, 1), '1h', eager = True)
vs = pl.Series(np.random.standard_normal(200 * len(ds)))
xs = pl.Series(np.random.randint(0, 50, len(ds)).repeat(200))

dat = pl.DataFrame(dict(d = ds.to_numpy().repeat(200), v = vs, x = xs))
dat = dat.set_sorted('d')

cmp_rolls_dt = pl.read_parquet(path.join(cmp_files, 'plain_rolls.parquet'))
rolls_dt = dat.select([
    pl.col('v').rolling_min(5).alias('v_roll_5_min'),
    pl.col('v').rolling_max(5).alias('v_roll_5_max'),
    pl.col('x').rolling_min(5).alias('x_roll_5_min'),
    pl.col('x').rolling_max(5).alias('x_roll_5_max'),
    pl.col('v').rolling_min(7500).alias('v_roll_7500_min'),
    pl.col('v').rolling_max(7500).alias('v_roll_7500_max'),
    pl.col('x').rolling_min(7500).alias('x_roll_7500_min'),
    pl.col('x').rolling_max(7500).alias('x_roll_7500_max')
])
assert rolls_dt.frame_equal(cmp_rolls_dt), "Plain rolls not equal"

for s in ['x', 'v']:
    for w in [5, 7500]:
        for op in ['min', 'max']:
            tt = timeit(f"{s}s.rolling_{op}({w})", number = 100, globals = globals())
            print(f'({s}, {w}, {op}): {tt}')

for pars in [('1h', '2h'), ('1d', '5d'), ('1h', '5d')]:
    gb = dat.groupby_dynamic(index_column='d', every = pars[0], period = pars[1])
    agg_dt = gb.agg(
    [pl.col('v').max().alias('vmax'),
     pl.col('v').min().alias('vmin'),
     pl.col('x').max().alias('xmax'),
     pl.col('x').min().alias('xmin'),
    ])
    cmp_dt = pl.read_parquet(f'./{pars[0]}_{pars[1]}_dyn.parquet')
    assert agg_dt.frame_equal(cmp_dt), f'{pars[0]}_{pars[1]} not equal'
    tt = timeit("""gb.agg(
    [pl.col('v').max().alias('vmax'),
     pl.col('v').min().alias('vmin'),
     pl.col('x').max().alias('xmax'),
     pl.col('x').min().alias('xmin'),
    ])""",
       globals = globals(), number = 100)
    print(f'(every = {pars[0]}, period = {pars[1]}): {tt}')

@ritchie46 (Member) commented:

Alright, thanks! I will take a look tomorrow.

@magarick (Contributor, Author) commented Jun 7, 2023

Cool. It's very similar to what was being done before. It just caches where the last min/max was seen to do less work in some cases. If it helps, here's my understanding of what's happening abstractly. Note that I'm playing a little fast and loose with indices and notation below for brevity.
You have 3 intervals $A, B, C$ where $A$ is the "leaving" interval, $B$ is the "overlap" and $C$ is new values. Any of these can be empty. Graphically it might look like

[AAAA[BB)CCC)

We always have to compute the extremum (WLOG say a min) over $C$ since they're new values. Call this $m_C$ and let $m$ be the previous min over $A \cup B$. Also let $i$ be the largest index where we saw $m$ and $j$ be the largest index where we saw $m_C$. Keeping positional information will help us do less work.

Now, if $m_C \leq m$, we set $m \leftarrow m_C$ and $i \leftarrow j$ and we're done. Nothing in $B$ is smaller than $m$ so it doesn't matter if it was only in $A$. The "or equal" part is important because it lets us update the most recent index where we saw a repeated min.

Next, if $m_C > m$ but $m$ is still in $B$ (i.e. its last index $i$ falls in $B$), we don't need to update anything. The previous min was the smallest and didn't drop off.

Only if $m_C > m$ and $m \notin B$ must we compute the min over $B$ and its position, then take the smaller of that and $m_C$. (A sketch of the full update step follows the list of edge cases below.)

Three pathological cases:

  1. $B = \emptyset$, i.e. we have [AAAA) [CCCC). You just take the min over $C$. Since $B$ is empty, no work has to be done and we'll return $m_C$.
  2. $C = \emptyset$, i.e. [AAA[BB). Since $C$ is empty, the first part of the function does no work and $B$ is only checked if the previous min isn't in it (which we know by keeping the index).
  3. $B \cup C = \emptyset$, i.e. both $B$ and $C$ are empty. This is especially pathological since it means our new interval has start == end. I don't know if this can occur in practice. Right now I'm just propagating forward the previous min and index. Other behavior could be to return NaN or make the value dependent on where start lies.
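
A minimal sketch of the full update step, in simplified Rust (illustrative only: the names are made up, it assumes forward-moving windows with start <= end and old_end <= end, and it is not the actual code in this PR, which works through an unsafe helper on the raw buffer):

// Cached state: the previous window's min `m` and the largest index `idx`
// where it occurred. The window moves from [old_start, old_end) to
// [start, end), so A = [old_start, start), B = [start, old_end) and
// C = [old_end, end).
fn update_min<T: PartialOrd + Copy>(
    slice: &[T],
    m: T,
    idx: usize,
    start: usize,
    end: usize,
    old_end: usize,
) -> (T, usize) {
    // Min over slice[lo..hi] plus the largest index attaining it.
    // Reverse iteration with a strict "<" update keeps the right-most occurrence.
    let min_and_last_idx = |lo: usize, hi: usize| -> Option<(usize, T)> {
        slice[lo..hi]
            .iter()
            .enumerate()
            .rev()
            .fold(None, |best, (i, &v)| match best {
                Some((_, bv)) if bv <= v => best,
                _ => Some((lo + i, v)),
            })
    };

    // 1. Always scan the new values C (possibly empty).
    let c = min_and_last_idx(old_end.clamp(start, end), end);
    if let Some((j, m_c)) = c {
        if m_c <= m {
            // New min, or a more recent occurrence of an equal min.
            return (m_c, j);
        }
    }
    // 2. The previous min still lies in the overlap B: keep it.
    if idx >= start {
        return (m, idx);
    }
    // 3. The previous min dropped off with A: rescan B and compare with C's min.
    let b = min_and_last_idx(start, old_end.clamp(start, end));
    match (b, c) {
        (Some((bi, bm)), Some((j, m_c))) => {
            if bm <= m_c {
                (bm, bi)
            } else {
                (m_c, j)
            }
        }
        (Some((bi, bm)), None) => (bm, bi),
        (None, Some((j, m_c))) => (m_c, j),
        // B and C both empty (start == end): propagate the previous min.
        (None, None) => (m, idx),
    }
}

Each window step would call this with the cached (m, idx) from the previous window and store the returned pair for the next step.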

I think the algorithm could be modified to run well on sorted and partially sorted values as follows (a rough sketch of step 1 appears after this list):

  1. For each interval of new values $C$, in addition to the min and its index, find the largest index where it was unsorted.
  2. If we have to recompute the min over $B$, check if the previous "largest unordered index" is in it
    a. If it isn't, we know that $B$ is sorted and its min value is the first value
    b. If it is, we still win a little bit. Let $k$ be the largest index in $B$ where $B_k > B_{k+1}$. Compute the minimum over positions $0$ through $k$, then take the minimum of that and $B_{k+1}$, which is all you have to do since you know the tail of $B$ is sorted.
    c. Compare with the min over $C$, and update values accordingly.

The above extension might still leave some performance on the table and could maybe use a little more caching and cleverness. But for now it seems like an improvement.
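
If useful, here's a rough sketch of what step 1 of that extension could look like (hypothetical; none of this is in the PR): while scanning the new values, also record the right-most position where the ordering breaks, so a later rescan can tell whether the remaining overlap is sorted.

// Min of `c` with the largest index attaining it, plus the right-most k
// with c[k] > c[k + 1] (None means `c` is non-decreasing).
fn min_last_idx_and_last_descent<T: PartialOrd + Copy>(
    c: &[T],
) -> (Option<(usize, T)>, Option<usize>) {
    let mut best: Option<(usize, T)> = None;
    let mut last_descent: Option<usize> = None;
    for (k, &v) in c.iter().enumerate() {
        // Track the min and the largest index where it occurs.
        match best {
            Some((_, bv)) if bv < v => {}
            _ => best = Some((k, v)),
        }
        // Track the right-most position where the sequence decreases.
        if k + 1 < c.len() && c[k] > c[k + 1] {
            last_descent = Some(k);
        }
    }
    (best, last_descent)
}

Step 2 would then compare the cached last-descent position against $B$'s bounds before deciding how much of $B$ actually needs to be rescanned.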

@ritchie46 (Member) left a comment

Thanks for your explanation! It would be great if you could add a link to this comment in the function comments.

I went through the code carefully and it makes sense and is indeed a lot simpler. 👍

I have only a small nit and a request for a comment, but other than that good to go. 👍

.get_unchecked(start..end)
.iter()
.enumerate()
.rev()
@ritchie46 (Member):

So we iterate in reverse to ensure we get the latest index? If so, could you add a comment about that in this function?
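
For illustration (a toy example, not the PR code): with a strict less-than update while walking the window right to left, ties keep the right-most position, i.e. the latest index.

fn min_and_last_idx(xs: &[i64]) -> Option<(usize, i64)> {
    xs.iter().enumerate().rev().fold(None, |best, (i, &v)| match best {
        // Only a strictly smaller value replaces the running best, so the
        // right-most occurrence of the minimum wins.
        Some((_, bv)) if bv <= v => best,
        _ => Some((i, v)),
    })
}

fn main() {
    // The minimum 1 occurs at indices 1 and 3; reverse iteration reports 3.
    assert_eq!(min_and_last_idx(&[3, 1, 2, 1, 5]), Some((3, 1)));
}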

@@ -18,252 +18,164 @@ impl<'a, T: NativeType> RollingAggWindowNoNulls<'a, T> for SortedMinMax<'a, T> {
}
}

#[inline]
unsafe fn get_min_and_ix<T>(slice: &[T], start: usize, end: usize) -> Option<(usize, &T)>
@ritchie46 (Member):

nit: can you rename all ix to idx to make it more consistent with how we name indexes in the code base?

@magarick (Contributor, Author) commented Jun 8, 2023

Alrighty, done. Out of curiosity, why isn't a min/max-heap used here? Is it to avoid potentially large memory allocation or is it a cache issue?

Also, I compared rolling performance to bottleneck: this version is faster on unsorted data, but bottleneck is better on partially sorted data. So the next step is improving performance on series with sorted runs.

@ritchie46 (Member) commented:

> the next step is improving performance on series with sorted runs.

Yes, for instance, we know up front whether data is sorted via the sorted flag. So we could just do a strided take in that case.
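
For instance (a minimal sketch assuming ascending-sorted input and a window that must be full before producing a value; not the existing code), the rolling minimum degenerates to a shifted copy of the input:

// For ascending-sorted data the minimum of window [i + 1 - w, i] is simply
// its first element, so no per-window scan is needed. Assumes w >= 1.
fn rolling_min_sorted_asc(values: &[f64], w: usize) -> Vec<Option<f64>> {
    (0..values.len())
        .map(|i| if i + 1 >= w { Some(values[i + 1 - w]) } else { None })
        .collect()
}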

> Alrighty, done. Out of curiosity, why isn't a min/max-heap used here? Is it to avoid potentially large memory allocation or is it a cache issue?

Memory, but if I recall correctly it wasn't fast either. I believe it does a binary-search insert on every element, which is much more expensive than what we do here.
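
For context, an illustrative sketch (not Polars code) of why a sorted window buffer is costly: every incoming element needs a binary-search insert, and the element leaving the window needs a matching removal, each of which can shift up to window-size elements.

fn insert_sorted(buf: &mut Vec<f64>, v: f64) {
    let pos = buf.partition_point(|&x| x < v); // binary search: O(log w)
    buf.insert(pos, v); // but the insert itself shifts elements: O(w)
}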

@ritchie46 ritchie46 merged commit ab3b49f into pola-rs:main Jun 9, 2023
@ritchie46 (Member) commented:

Thanks @magarick 🙌

@magarick magarick deleted the minmax-improvements branch June 9, 2023 19:43
c-peters pushed a commit to c-peters/polars that referenced this pull request Jul 14, 2023