perf(rust, python) Improve rolling min and max for nonulls #9277
Conversation
I haven't reviewed yet, but the current functions are carefully constructed to prevent very bad behavior on sorted data. Random data is not the only thing we should look at. We also should look at sorted data in both directions and semi-sorted data. Could you show the code you used for benchmarking? It should not be done via …
I didn't touch the special-case branches that handle sorted data. I think I can handle them all at once as well as improve performance on mostly-sorted data, but I figured one thing at a time. Here are some pathological cases:

```python
>>> s1 = pl.Series(list(range(1000)) * 1000)
>>> s2 = s1.reverse()
>>> timeit('s1.rolling_min(200)', globals = globals(), number = 100)
8.955661284038797
>>> timeit('s1.rolling_min(2000)', globals = globals(), number = 100)
0.3911876070778817
>>> timeit('s1.rolling_max(200)', globals = globals(), number = 100)
0.2773919829633087
>>> timeit('s1.rolling_max(2000)', globals = globals(), number = 100)
0.27548131812363863
>>> timeit('s2.rolling_min(200)', globals = globals(), number = 100)
0.39969513611868024
>>> timeit('s2.rolling_min(2000)', globals = globals(), number = 100)
0.3881555439438671
>>> timeit('s2.rolling_max(200)', globals = globals(), number = 100)
10.08373599499464
>>> timeit('s2.rolling_max(2000)', globals = globals(), number = 100)
0.2590978641528636
>>> s3 = pl.Series(list(range(int(1e6))))
>>> s3 = s3.set_at_idx(s3.sample(1000, seed = 1), 0)
>>> timeit('s3.rolling_min(200)', globals = globals(), number = 100)
10.237212847918272
>>> timeit('s3.rolling_min(2000)', globals = globals(), number = 100)
13.379444011021405
>>> timeit('s3.rolling_min(20000)', globals = globals(), number = 100)
0.4033352010883391
>>> timeit('s3.rolling_max(200)', globals = globals(), number = 100)
0.2527862370479852
>>> timeit('s3.rolling_max(2000)', globals = globals(), number = 100)
0.27149013499729335
>>> timeit('s3.rolling_max(20000)', globals = globals(), number = 100)
0.26448553893715143
```

Current:

```python
>>> s1 = pl.Series(list(range(1000)) * 1000)
>>> s2 = s1.reverse()
>>> timeit('s1.rolling_min(200)', globals = globals(), number = 100)
16.829105872893706
>>> timeit('s1.rolling_min(2000)', globals = globals(), number = 100)
0.5299010709859431
>>> timeit('s1.rolling_max(200)', globals = globals(), number = 100)
0.603124001994729
>>> timeit('s1.rolling_max(2000)', globals = globals(), number = 100)
0.6017518460284919
>>> timeit('s2.rolling_min(200)', globals = globals(), number = 100)
0.5548417470417917
>>> timeit('s2.rolling_min(2000)', globals = globals(), number = 100)
0.5275844470597804
>>> timeit('s2.rolling_max(200)', globals = globals(), number = 100)
11.248067330103368
>>> timeit('s2.rolling_max(2000)', globals = globals(), number = 100)
0.6214083470404148
>>> s3 = pl.Series(list(range(int(1e6))))
>>> s3 = s3.set_at_idx(s3.sample(1000, seed = 1), 0)
>>> timeit('s3.rolling_min(200)', globals = globals(), number = 100)
17.742616538889706
>>> timeit('s3.rolling_min(2000)', globals = globals(), number = 100)
30.45354696107097
>>> timeit('s3.rolling_min(20000)', globals = globals(), number = 100)
1.1630480010062456
>>> timeit('s3.rolling_max(200)', globals = globals(), number = 100)
0.7935329640749842
>>> timeit('s3.rolling_max(2000)', globals = globals(), number = 100)
0.8272605510428548
>>> timeit('s3.rolling_max(20000)', globals = globals(), number = 100)
0.8058401739690453
```

Here's the testing script. It assumes you've created files with the correct results using the current version.

```python
import polars as pl
from datetime import datetime
import numpy as np
from timeit import timeit
import sys
from os import path
cmp_files = sys.argv[1]
np.random.seed(100)
ds = pl.date_range(datetime(2013, 1, 1), datetime(2023, 1, 1), '1h', eager = True)
vs = pl.Series(np.random.standard_normal(200 * len(ds)))
xs = pl.Series(np.random.randint(0, 50, len(ds)).repeat(200))
dat = pl.DataFrame(dict(d = ds.to_numpy().repeat(200), v = vs, x = xs))
dat = dat.set_sorted('d')
cmp_rolls_dt = pl.read_parquet(path.join(cmp_files, 'plain_rolls.parquet'))
rolls_dt = dat.select([
pl.col('v').rolling_min(5).alias('v_roll_5_min'),
pl.col('v').rolling_max(5).alias('v_roll_5_max'),
pl.col('x').rolling_min(5).alias('x_roll_5_min'),
pl.col('x').rolling_max(5).alias('x_roll_5_max'),
pl.col('v').rolling_min(7500).alias('v_roll_7500_min'),
pl.col('v').rolling_max(7500).alias('v_roll_7500_max'),
pl.col('x').rolling_min(7500).alias('x_roll_7500_min'),
pl.col('x').rolling_max(7500).alias('x_roll_7500_max')
])
assert rolls_dt.frame_equal(cmp_rolls_dt), "Plain rolls not equal"
for s in ['x', 'v']:
    for w in [5, 7500]:
        for op in ['min', 'max']:
            tt = timeit(f"{s}s.rolling_{op}({w})", number = 100, globals = globals())
            print(f'({s}, {w}, {op}): {tt}')
for pars in [('1h', '2h'), ('1d', '5d'), ('1h', '5d')]:
    gb = dat.groupby_dynamic(index_column='d', every = pars[0], period = pars[1])
    agg_dt = gb.agg(
        [pl.col('v').max().alias('vmax'),
         pl.col('v').min().alias('vmin'),
         pl.col('x').max().alias('xmax'),
         pl.col('x').min().alias('xmin'),
        ])
    cmp_dt = pl.read_parquet(f'./{pars[0]}_{pars[1]}_dyn.parquet')
    assert agg_dt.frame_equal(cmp_dt), f'{pars[0]}_{pars[1]} not equal'
    tt = timeit("""gb.agg(
        [pl.col('v').max().alias('vmax'),
         pl.col('v').min().alias('vmin'),
         pl.col('x').max().alias('xmax'),
         pl.col('x').min().alias('xmin'),
        ])""",
        globals = globals(), number = 100)
    print(f'(every = {pars[0]}, period = {pars[1]}): {tt}')
```
Alright, thanks! I will take a look tomorrow.
Cool. It's very similar to what was being done before. It just caches where the last min/max was seen to do less work in some cases. If it helps, here's my understanding of what's happening abstractly. Note that I'm playing a little fast and loose with indices and notation below for brevity.
We always have to compute the extremum (WLOG say a min) over the values that have just entered the window.

Now, if the previous window's min is still inside the current window and is no larger than the entering min, nothing changes: it remains the min.

Next, if the entering min is no larger than the previous min, it becomes the new min, since everything carried over from the previous window is at least the previous min.

Only now, if the previous min has fallen out of the window and the entering min is strictly larger, do we have to rescan the whole window.

Three pathological cases:

- Ascending data with a rolling min: the min is always about to leave the window, so it expires on every step and forces a full rescan (see `s1.rolling_min(200)` above).
- Descending data with a rolling max: the mirror image (see `s2.rolling_max(200)`).
- Mostly-ascending data with occasional small values, like `s3` above: the cached min keeps expiring while the entering values are larger, again forcing rescans.
I think the algorithm could be modified to run well on sorted and partially sorted values as well. That extension might still leave some performance on the table and could maybe use a little more caching and cleverness. But for now it seems like an improvement.
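For concreteness, here is a minimal sketch of that update rule, written from the description above. It is my own illustration, not the PR's code: the PR works generically over `NativeType` slices with unchecked indexing, while this assumes stride-1 windows over `f64` and uses hypothetical helper names.

```rust
/// Rolling min with a cached extremum index (sketch of the idea above;
/// not the PR's implementation).
fn rolling_min(vals: &[f64], window: usize) -> Vec<f64> {
    assert!(window >= 1 && window <= vals.len());
    let mut out = Vec::with_capacity(vals.len() - window + 1);
    // Full scan of the first window, remembering where the min was.
    let mut min_idx = last_min_idx(vals, 0, window);
    out.push(vals[min_idx]);
    for start in 1..=(vals.len() - window) {
        let end = start + window;
        let entering = end - 1; // stride-1: one new value per step
        if vals[entering] <= vals[min_idx] {
            // Entering value is no larger than the old min, so it is the
            // new min; `<=` prefers the later index, which survives longer.
            min_idx = entering;
        } else if min_idx < start {
            // The cached min fell out of the window and the entering value
            // is larger: only in this case must we rescan the whole window.
            min_idx = last_min_idx(vals, start, end);
        }
        // Otherwise the cached min is still inside the window and still
        // no larger than everything else, so it remains the answer.
        out.push(vals[min_idx]);
    }
    out
}

/// Index of the min of vals[start..end], preferring the latest index on
/// ties: scan in reverse and only replace on a strictly smaller value.
fn last_min_idx(vals: &[f64], start: usize, end: usize) -> usize {
    let mut best = end - 1;
    for i in (start..end - 1).rev() {
        if vals[i] < vals[best] {
            best = i;
        }
    }
    best
}

fn main() {
    assert_eq!(rolling_min(&[3.0, 1.0, 4.0, 1.0, 5.0], 3), vec![1.0, 1.0, 1.0]);
}
```

The rescan branch is exactly where the pathological cases above spend their time: on ascending data the cached min expires on every single step.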
Thanks for your explanation! It would be great if you could add a link to that comment in the function comments.
I went through the code carefully and it makes sense and is indeed a lot simpler. 👍
I have only a small nit and a request for a comment, but other than that good to go. 👍
```rust
    .get_unchecked(start..end)
    .iter()
    .enumerate()
    .rev()
```
So we iterate in reverse to ensure we get the latest index? If so, could you add a comment about that in this function.
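A toy illustration of that tie-breaking (my own example, not the PR's function): walking the slice in reverse and keeping the current best on ties means the first minimum encountered, i.e. the latest one in forward order, wins.

```rust
// Reverse enumeration plus "keep current best on ties" returns the
// latest index of the minimum (example only, not the PR's code).
fn min_and_last_idx(slice: &[i64]) -> Option<(usize, &i64)> {
    slice.iter().enumerate().rev().fold(None, |best, (i, v)| match best {
        // The stored best has a later index; keep it when values tie.
        Some((_, bv)) if *bv <= *v => best,
        _ => Some((i, v)),
    })
}

fn main() {
    let xs = [3, 1, 4, 1, 5];
    // The min value 1 occurs at indices 1 and 3; the later one wins.
    assert_eq!(min_and_last_idx(&xs), Some((3, &1)));
}
```

Keeping the latest index matters because it stays valid in the sliding window for the largest number of future steps.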
```rust
@@ -18,252 +18,164 @@ impl<'a, T: NativeType> RollingAggWindowNoNulls<'a, T> for SortedMinMax<'a, T> {
    }
}

#[inline]
unsafe fn get_min_and_ix<T>(slice: &[T], start: usize, end: usize) -> Option<(usize, &T)>
```
nit: can you rename all `ix` to `idx` to make it more consistent with how we name indexes in the code base?
Alrighty, done. Out of curiosity, why isn't a min/max-heap used here? Is it to avoid a potentially large memory allocation, or is it a cache issue? Also, I compared rolling performance to …
Yes, we know for instance up front if data is sorted via the sorted flag. So we could just do a strided take in that case.

Memory, but if I recall it wasn't fast either. I believe it does a binary search insert on every element, which is much more expensive than what we do here.
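As a sketch of that sorted-flag shortcut (my own illustration with a hypothetical helper, not polars' API): on data flagged as ascending-sorted, the rolling min over a fixed window is just the element at each window's start, so the whole operation reduces to a strided read.

```rust
// Sketch only: for an ascending-sorted series, every window's min is its
// first element, so no scanning is needed (hypothetical helper).
fn rolling_min_sorted_asc(vals: &[f64], window: usize) -> Vec<f64> {
    assert!(window >= 1 && window <= vals.len());
    (0..=vals.len() - window).map(|start| vals[start]).collect()
}

fn main() {
    assert_eq!(
        rolling_min_sorted_asc(&[1.0, 2.0, 3.0, 4.0], 2),
        vec![1.0, 2.0, 3.0]
    );
}
```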
Thanks @magarick 🙌
This PR speeds up rolling min and max for the case when there are no nulls. It also makes the code clearer and more concise (in my opinion anyway). I've verified that it produces the same results and is faster on some sample data.
The data was generated as in the testing script above.
The test was plain rolling windows of lengths 5 and 7500 on each of the two series, as well as a `groupby_dynamic` that computed the max and min for each over every window with the listed `every` and `period` parameters. Timings for 100 repetitions (in seconds). New is first, current is second: