GH-35515: [C++][Python] Add non decomposable aggregation UDF #35514
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -954,7 +954,9 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( | |
return Status::Invalid("Expected aggregate call ", call.id().uri, "#", | ||
call.id().name, " to have at least one argument"); | ||
} | ||
case 1: { | ||
default: { | ||
// Handles all arity > 0 | ||
|
||
std::shared_ptr<compute::FunctionOptions> options = nullptr; | ||
if (arrow_function_name == "stddev" || arrow_function_name == "variance") { | ||
// See the following URL for the spec of stddev and variance: | ||
|
@@ -981,21 +983,22 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( | |
} | ||
fixed_arrow_func += arrow_function_name; | ||
Review comment: This line was glued to the previous block before; it's part of the function name building process.
Reply: Sorry, I am not sure what you mean here. Do you want me to change anything?
Reply: Keep the line glued to the previous block, as it was before. It's a nitpick, but it makes a potential function extraction in the future more obvious.
Reply: I see. Updated.
||
|
||
ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0)); | ||
const FieldRef* arg_ref = arg.field_ref(); | ||
if (!arg_ref) { | ||
return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", | ||
call.id().name, " to have a direct reference"); | ||
std::vector<FieldRef> target; | ||
for (int i = 0; i < call.size(); i++) { | ||
ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i)); | ||
const FieldRef* arg_ref = arg.field_ref(); | ||
Review comment: Remove
Reply: @felipecrv My understanding is with
Reply: You're right. This is one of the many ways copies can creep in. But looking at the definition of
Reply: More details on the move const issue: https://stackoverflow.com/questions/27810535/why-does-calling-stdmove-on-a-const-object-call-the-copy-constructor-when-pass
Reply: Thanks! Removed.
||
if (!arg_ref) { | ||
return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", | ||
call.id().name, " to have a direct reference"); | ||
} | ||
// Copy arg_ref here because field_ref() returns a const FieldRef* | ||
target.emplace_back(*arg_ref); | ||
icexelloss marked this conversation as resolved.
|
||
} | ||
|
||
return compute::Aggregate{std::move(fixed_arrow_func), | ||
options ? std::move(options) : nullptr, *arg_ref, ""}; | ||
options ? std::move(options) : nullptr, | ||
std::move(target), ""}; | ||
} | ||
default: | ||
break; | ||
} | ||
return Status::NotImplemented( | ||
"Only nullary and unary aggregate functions are currently supported"); | ||
}; | ||
} | ||
|
||
|
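The loop in the diff above replaces the single-argument lookup with one that walks every argument of the aggregate call and collects a field reference for each. A minimal Python sketch of that control flow follows. `FieldRef`, `Expression`, and `collect_targets` are hypothetical stand-ins written for illustration; they are not Arrow or Substrait APIs, and only the error wording is taken from the diff.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class FieldRef:
    """Hypothetical stand-in for a direct column reference."""
    name: str


@dataclass(frozen=True)
class Expression:
    """Hypothetical stand-in: an argument that may or may not be a plain field reference."""
    field_ref: Optional[FieldRef] = None


def collect_targets(call_name: str, args: List[Expression]) -> List[FieldRef]:
    """Return one FieldRef per argument; fail on the first argument that is not a direct reference."""
    targets = []
    for arg in args:
        if arg.field_ref is None:
            raise ValueError(
                f"Expected an aggregate call {call_name} to have a direct reference")
        # Each reference is appended to the list, mirroring target.emplace_back(*arg_ref).
        targets.append(arg.field_ref)
    return targets


targets = collect_targets("f", [Expression(FieldRef("x")), Expression(FieldRef("y"))])
```

With two direct references the sketch yields two targets in argument order; any argument that is not a plain reference aborts the whole call, as in the diff.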
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,11 +21,11 @@ from pyarrow.lib cimport * | |
from pyarrow.includes.common cimport * | ||
from pyarrow.includes.libarrow cimport * | ||
|
||
cdef class ScalarUdfContext(_Weakrefable): | ||
Review comment: Renamed this because the context is now used for scalar, tabular, and aggregate UDFs.
||
cdef class UdfContext(_Weakrefable): | ||
cdef: | ||
CScalarUdfContext c_context | ||
CUdfContext c_context | ||
|
||
cdef void init(self, const CScalarUdfContext& c_context) | ||
cdef void init(self, const CUdfContext& c_context) | ||
|
||
|
||
cdef class FunctionOptions(_Weakrefable): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -278,3 +278,59 @@ def unary_function(ctx, x): | |
{"array": pa.int64()}, | ||
pa.int64()) | ||
return unary_function, func_name | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def unary_agg_func_fixture(): | ||
Review comment: This is used in both test_udf.py and test_substrait.py.
||
""" | ||
Register a unary aggregate function | ||
""" | ||
from pyarrow import compute as pc | ||
import numpy as np | ||
|
||
def func(ctx, x): | ||
return pa.scalar(np.nanmean(x)) | ||
|
||
func_name = "y=avg(x)" | ||
func_doc = {"summary": "y=avg(x)", | ||
"description": "find mean of x"} | ||
|
||
pc.register_aggregate_function(func, | ||
func_name, | ||
func_doc, | ||
{ | ||
"x": pa.float64(), | ||
}, | ||
pa.float64() | ||
) | ||
return func, func_name | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def varargs_agg_func_fixture(): | ||
""" | ||
Register a varargs aggregate function | ||
""" | ||
from pyarrow import compute as pc | ||
import numpy as np | ||
|
||
def func(ctx, *args): | ||
sum = 0.0 | ||
for arg in args: | ||
sum += np.nanmean(arg) | ||
return pa.scalar(sum) | ||
|
||
func_name = "y=sum_mean(x...)" | ||
func_doc = {"summary": "Varargs aggregate", | ||
"description": "Varargs aggregate"} | ||
|
||
pc.register_aggregate_function(func, | ||
func_name, | ||
func_doc, | ||
{ | ||
"x": pa.int64(), | ||
"y": pa.float64() | ||
Comment on lines +331 to +332
Review comment: What are
Reply: Admittedly this is weird/confusing, but here is why: this is not a truly "varargs" function, as it must take two arguments, x and y, with the specified types.
We would then wrap the foo into a function that Acero is expecting (a varargs UDF)
and also register it in Acero on the fly when the user executes the expression. In other words, this is a function with a known number of inputs (two in this case) but with a function signature that takes
Reply: (Realized I double explained this in https://github.com/apache/arrow/pull/35514/files/f1d3bcbd803a95e378d09121c5061787a68755d2#r1196685513)
Reply: Ok, so the test case is verifying that the python function can take in
Reply: That is correct. It is testing a UDF that must take a fixed number of inputs but with a
||
}, | ||
pa.float64() | ||
) | ||
return func, func_name |
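The review thread on this fixture makes the point that the function is written with a `*args` signature yet is registered with a fixed set of inputs (`x` and `y`). A minimal pure-Python sketch of that idea follows, with no pyarrow or NumPy involved. `nanmean` is a simplified stand-in for `numpy.nanmean`, and `with_fixed_arity` is a hypothetical helper invented for illustration; it does not describe how pyarrow or Acero enforce arity internally.

```python
import math


def nanmean(values):
    """Mean of the non-NaN entries (simplified stand-in for numpy.nanmean).

    Unlike NumPy, this raises ZeroDivisionError when every entry is NaN.
    """
    kept = [float(v) for v in values if not math.isnan(v)]
    return sum(kept) / len(kept)


def sum_mean(ctx, *args):
    """Same shape as the fixture: sum the NaN-ignoring mean of every argument."""
    total = 0.0
    for arg in args:
        total += nanmean(arg)
    return total


def with_fixed_arity(func, in_names):
    """Hypothetical wrapper: accept *args but insist on exactly len(in_names) inputs."""
    def wrapper(ctx, *args):
        if len(args) != len(in_names):
            raise TypeError(
                f"expected {len(in_names)} arguments {in_names}, got {len(args)}")
        return func(ctx, *args)
    return wrapper


# Two declared inputs, as in the fixture's {"x": ..., "y": ...} mapping.
checked = with_fixed_arity(sum_mean, ["x", "y"])
result = checked(None, [1, 2, 3], [1.0, float("nan"), 3.0])  # 2.0 + 2.0
```

Calling `checked` with one or three columns raises `TypeError`, which is the sense in which the function is varargs in signature only.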
Review comment: This is so that we can support UDFs which can have arity > 1
Reply: Yes