Add stddev and variance #1525
Conversation
I will review it carefully again after I learn the algorithm behind these two functions.
Perhaps @Dandandan would be interested in this, as he was involved in db-benchmark, which this will help.
This is a really nice piece of work @realno - thank you so much ❤️
I especially like the thorough testing.
I did not review the referenced papers, but I did run some basic tests against postgres and I got some strange results using:
cargo run datafusion-cli
Sometimes the answers are different from query to query:
❯ select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq;
+--------------------+
| STDDEV(sq.column1) |
+--------------------+
| 0.7760297817881877 |
+--------------------+
1 row in set. Query took 0.008 seconds.
❯ select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq;
+--------------------+
| STDDEV(sq.column1) |
+--------------------+
| NaN |
+--------------------+
And neither of those answers matches postgres:
alamb=# select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq;
stddev
------------------------
0.95043849529221686157
(1 row)
if !is_empty {
    let new_count = self.count + 1;
    let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?;
Using `ScalarValue`s like this to accumulate each operation is likely to be quite slow at runtime. However, I think it would be fine to put this in as a first implementation and then implement an optimized version using `update_batch` and arrow compute kernels as a follow-on PR:
/// updates the accumulator's state from a vector of arrays.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
Completely agree. I plan to investigate this in a followup PR. The current challenge is that the algorithm will lose parallelizability if we use a batch-friendly algorithm, and I need to spend more time understanding the code. One question I have: is there a chance `update` and `update_batch` can be used in the same job, i.e. can one job call `update` on some data and `update_batch` on some other data? The reason I ask is that the online version of the algorithm requires an intermediate value to be calculated, so it is not compatible with batch mode - that is, we can only do all batch or all online.
I don't fully understand the discussion here about online and parallelism with respect to `update` and `update_batch`. `update_batch` is actually what the HashAggregator calls during execution. The default implementation of `update_batch` simply iterates over each row of the input array, converting it to a `ScalarValue`, and calls `update`: https://github.com/apache/arrow-datafusion/blob/415c5e124af18a05500514f78604366d860dcf5a/datafusion/src/physical_plan/mod.rs#L572-L583
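For reference, the default implementation being described looks roughly like this (a paraphrase of the linked mod.rs, not necessarily character-for-character; `ScalarValue::try_from_array` is the row-to-scalar conversion used in that revision):

```rust
/// Default: walk the batch row by row, build a ScalarValue per row,
/// and delegate to the scalar update() -- correct but slow.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    if values.is_empty() {
        return Ok(());
    }
    (0..values[0].len()).try_for_each(|index| {
        let v = values
            .iter()
            .map(|array| ScalarValue::try_from_array(array, index))
            .collect::<Result<Vec<_>>>()?;
        self.update(&v)
    })
}
```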
My mental picture of aggregation is like this:
┌───────────────────────┐ ┌──────────────────────┐
│ │ │ │
│ Input RecordBatch 0 │ update_batch│ Accumulator │
│ │────────────▶│ │───────┐
│ │ │ │ │
└───────────────────────┘ └──────────────────────┘ │
│
│ merge
│
┌───────────────────────┐ ┌──────────────────────┐ │ ┌──────────────────────┐
│ │ │ │ │ │ │
│ Input RecordBatch 1 │ update_batch│ Accumulator │ │ │ Accumulator │
│ │────────────▶│ │───────┼──────▶│ │
│ │ │ │ │ │ │
└───────────────────────┘ └──────────────────────┘ │ └──────────────────────┘
│
│
│
│
... ... │
│
│
│
┌───────────────────────┐ ┌──────────────────────┐ │
│ │ │ │ │
│ Input RecordBatch N │ update_batch│ Accumulator │ │
│ │────────────▶│ │───────┘
│ │ │ │
└───────────────────────┘ └──────────────────────┘
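The `merge` step in this picture is what keeps the online algorithm parallelizable: each partition accumulates its own `(count, mean, m2)`, and the partial states are combined pairwise. Here is a minimal sketch of that combination, assuming plain `f64` state rather than the PR's `ScalarValue`s (field names mirror the discussion, not the PR's exact code):

```rust
/// Combine two partial Welford states (n, mean, m2) into one.
/// This is the standard pairwise-combination formula (Chan et al.);
/// the PR's actual merge may differ in naming and ScalarValue handling.
fn merge_states(
    (n_a, mean_a, m2_a): (u64, f64, f64),
    (n_b, mean_b, m2_b): (u64, f64, f64),
) -> (u64, f64, f64) {
    let n = n_a + n_b;
    if n == 0 {
        return (0, 0.0, 0.0);
    }
    let delta = mean_b - mean_a;
    let mean = mean_a + delta * n_b as f64 / n as f64;
    let m2 = m2_a + m2_b + delta * delta * (n_a as f64 * n_b as f64) / n as f64;
    (n, mean, m2)
}
```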
Sorry, let me try to explain it a bit better.
The online algorithm uses an intermediate value for each record, called m2. Here is how m2 is calculated:
m2(n) = m2(n-1) + (delta1 * delta2)
So the calculation of m2(n) depends on the value of m2(n-1), and it is similar for delta1 and delta2 - they depend on mean(n-1). I.e., the algorithm is iterative. So what I meant was this algorithm can't benefit from complete vectorization - for each batch it has to be iterative.
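Spelled out as a minimal sketch with plain `f64`s (the PR itself works with `ScalarValue`s, and the variable names here just mirror the discussion):

```rust
/// One step of the Welford online update. Every quantity at step n
/// depends on the state at step n-1, so the loop is inherently sequential.
fn welford_step(count: &mut u64, mean: &mut f64, m2: &mut f64, value: f64) {
    *count += 1;
    let delta1 = value - *mean;      // uses mean(n-1)
    *mean += delta1 / *count as f64; // mean(n)
    let delta2 = value - *mean;      // uses mean(n)
    *m2 += delta1 * delta2;          // m2(n) = m2(n-1) + delta1 * delta2
}
```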
But I see your point that we can avoid the per-record type conversion and matching, since in `update_batch` we can use the raw array. I think I can make the change in this PR. Thanks for pointing it out.
> So what I meant was this algorithm can't benefit from complete vectorization - for each batch it has to be iterative.
Ah that makes sense 👍
@alamb All the comments are addressed; there are some open questions/discussion points, so please take a look when you have time. Otherwise the PR is ready.
I looked at the changes and ran some new tests -- very nice @realno
I think this PR is good enough to merge as is -- it is well tested and adds a new feature.
My only concern is the introduction of `ScalarValue` arithmetic functions rather than using the arrow compute kernels. I would love to know what @jimexist or @Dandandan think.
The way to avoid adding more logic to `ScalarValue` is to implement `update_batch` for `VarianceAccumulator` in terms of Arrays (and change `update` to convert its values into arrays). Perhaps we could merge this PR with this approach, and file a follow-on ticket (I would be happy to do so) to remove the scalar value math and improve the performance.
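A minimal sketch of that suggestion, assuming a `VarianceAccumulator` whose `count`, `mean`, and `m2` fields are plain `f64`s (the PR's accumulator stores `ScalarValue`s, so this is illustrative rather than a drop-in change):

```rust
use arrow::array::{ArrayRef, Float64Array};
use arrow::compute::cast;
use arrow::datatypes::DataType;

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    // Cast once and read the raw array, avoiding a ScalarValue per row;
    // the Welford recurrence itself still walks the rows in order.
    let array = cast(&values[0], &DataType::Float64)?;
    let array = array
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("cast to Float64 succeeded");
    for i in 0..array.len() {
        if array.is_valid(i) {
            let value = array.value(i);
            self.count += 1;
            let delta1 = value - self.mean;
            self.mean += delta1 / self.count as f64;
            let delta2 = value - self.mean;
            self.m2 += delta1 * delta2;
        }
    }
    Ok(())
}
```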
@alamb thanks for the thorough review! After it is merged I will open a followup PR to add the optimized `update_batch`. I want to discuss the aggregator interface a bit further. With that said, I am still new to the code base and could have missed something here. It would be nice to hear more people's opinions. Also, please let me know if there is a more appropriate forum to discuss this.
There is another rationale for the current `update` / `update_batch` interface. Perhaps we could add some comments to those methods to clarify the intended usage.
I see, that makes sense. It is a good idea to add some comments.
Unless anyone has any other thoughts on this PR, I plan to merge it later today.
Thanks again @realno -- would you like me to file a ticket for the follow-on work? (specifically removing the `ScalarValue` math)
I proposed some doc clarification in #1542
I have the changes ready and will open a PR today. If you prefer to have a ticket as a reference, I can create it too.
Which issue does this PR close?
This PR covers one of the functions (stddev) in #1486
Rationale for this change
This change is to add functions to calculate standard deviation and variance to DataFusion.
What changes are included in this PR?
Two aggregation functions, `variance` and `stddev`, are added to DataFusion.
Also added are some arithmetic functions (add, mul and div) on `ScalarValue`, since they are likely to be commonly used by many other operators. There are limitations to the current form of these functions - please see the code. It may be worth creating a new issue to capture the work properly.
The algorithm used is an online implementation and numerically stable. It is based on this paper:
Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
Another benefit of the algorithm is that it is parallelizable.
It has been analyzed here:
Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
Some limitations are noted in the code.
Are there any user-facing changes?
No change to existing functionality or APIs. Six new functions are added and are available through the SQL interface.
The functions added are the following:
- `var(col1)`: calculate the variance (sample) of col1
- `var_samp(col1)`: calculate the variance (sample) of col1
- `var_pop(col1)`: calculate the variance (population) of col1
- `stddev(col1)`: calculate the standard deviation (sample) of col1
- `stddev_samp(col1)`: calculate the standard deviation (sample) of col1
- `stddev_pop(col1)`: calculate the standard deviation (population) of col1
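For clarity on the sample/population split above: both variants share the same accumulated state, and only the divisor used at finalization differs. A minimal sketch in terms of the `count`/`m2` state from the algorithm discussion (not the PR's exact `evaluate()` code):

```rust
/// Finalize a Welford state: n observations, m2 = sum of squared deviations.
/// Sample forms divide by n - 1, population forms by n.
fn var_samp(n: u64, m2: f64) -> Option<f64> {
    (n > 1).then(|| m2 / (n - 1) as f64) // var, var_samp
}
fn var_pop(n: u64, m2: f64) -> Option<f64> {
    (n > 0).then(|| m2 / n as f64) // var_pop
}
fn stddev_samp(n: u64, m2: f64) -> Option<f64> {
    var_samp(n, m2).map(f64::sqrt) // stddev, stddev_samp
}
fn stddev_pop(n: u64, m2: f64) -> Option<f64> {
    var_pop(n, m2).map(f64::sqrt) // stddev_pop
}
```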