GH-35515: [C++][Python] Add non decomposable aggregation UDF #35514

Merged: 19 commits, Jun 8, 2023
27 changes: 15 additions & 12 deletions cpp/src/arrow/engine/substrait/extension_set.cc
@@ -954,7 +954,9 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
return Status::Invalid("Expected aggregate call ", call.id().uri, "#",
call.id().name, " to have at least one argument");
}
case 1: {
default: {
Contributor Author: This is so that we can support UDFs, which can have arity > 1.

Contributor: Yes.

// Handles all arity > 0

std::shared_ptr<compute::FunctionOptions> options = nullptr;
if (arrow_function_name == "stddev" || arrow_function_name == "variance") {
// See the following URL for the spec of stddev and variance:
@@ -981,21 +983,22 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
}
fixed_arrow_func += arrow_function_name;
Contributor: This line was glued to the previous block before; it's part of the function name building process.

Contributor Author: Sorry, I am not sure what you mean here. Do you want me to change anything?

Contributor: To keep the line glued to the previous block as it was before. It's a nitpick, but it makes a potential function extraction in the future more obvious.

Contributor Author: I see. Updated.


ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
const FieldRef* arg_ref = arg.field_ref();
if (!arg_ref) {
return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
call.id().name, " to have a direct reference");
std::vector<FieldRef> target;
for (int i = 0; i < call.size(); i++) {
ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i));
const FieldRef* arg_ref = arg.field_ref();
Contributor: Remove const so it can be moved below.

Contributor Author: @felipecrv My understanding is that with const FieldRef* the move below has no effect and the data must be copied (because it is a pointer to const FieldRef), while with FieldRef* the move can actually move the object instead of copying it?

Contributor (@felipecrv, Jun 2, 2023): You're right. This is one of the many ways copies can creep in.

But looking at the definition of field_ref I noticed it can only return const. Removing const won't really be possible. So I recommend removing the std::move(*arg_ref) below so it doesn't give the wrong impression that a move is happening.

Contributor Author: Thanks! Removed std::move(*arg_ref).

if (!arg_ref) {
return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
call.id().name, " to have a direct reference");
}
// Copy arg_ref here because field_ref() returns const FieldRef*
target.emplace_back(*arg_ref);
}

return compute::Aggregate{std::move(fixed_arrow_func),
options ? std::move(options) : nullptr, *arg_ref, ""};
options ? std::move(options) : nullptr,
std::move(target), ""};
}
default:
break;
}
return Status::NotImplemented(
"Only nullary and unary aggregate functions are currently supported");
};
}

Expand Down
6 changes: 3 additions & 3 deletions python/pyarrow/_compute.pxd
@@ -21,11 +21,11 @@ from pyarrow.lib cimport *
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *

cdef class ScalarUdfContext(_Weakrefable):
Contributor Author: Renamed this because the Context is now used for scalar, tabular, and aggregate UDFs.

cdef class UdfContext(_Weakrefable):
cdef:
CScalarUdfContext c_context
CUdfContext c_context

cdef void init(self, const CScalarUdfContext& c_context)
cdef void init(self, const CUdfContext& c_context)


cdef class FunctionOptions(_Weakrefable):
161 changes: 119 additions & 42 deletions python/pyarrow/_compute.pyx
@@ -2559,7 +2559,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *:
deref(pyarrow_unwrap_schema(schema).get())))


cdef class ScalarUdfContext:
cdef class UdfContext:
"""
Per-invocation function context/state.

@@ -2571,7 +2571,7 @@ cdef class ScalarUdfContext:
raise TypeError("Do not call {}'s constructor directly"
.format(self.__class__.__name__))

cdef void init(self, const CScalarUdfContext &c_context):
cdef void init(self, const CUdfContext &c_context):
self.c_context = c_context

@property
@@ -2620,26 +2620,26 @@ cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *:
return f_doc


cdef object box_scalar_udf_context(const CScalarUdfContext& c_context):
cdef ScalarUdfContext context = ScalarUdfContext.__new__(ScalarUdfContext)
cdef object box_udf_context(const CUdfContext& c_context):
cdef UdfContext context = UdfContext.__new__(UdfContext)
context.init(c_context)
return context


cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
cdef _udf_callback(user_function, const CUdfContext& c_context, inputs):
"""
Helper callback function used to wrap the ScalarUdfContext from Python to C++
Helper callback function used to wrap the UdfContext from Python to C++
execution.
"""
context = box_scalar_udf_context(c_context)
context = box_udf_context(c_context)
return user_function(context, *inputs)


def _get_scalar_udf_context(memory_pool, batch_length):
cdef CScalarUdfContext c_context
def _get_udf_context(memory_pool, batch_length):
cdef CUdfContext c_context
c_context.pool = maybe_unbox_memory_pool(memory_pool)
c_context.batch_length = batch_length
context = box_scalar_udf_context(c_context)
context = box_udf_context(c_context)
return context


@@ -2665,11 +2665,19 @@ cdef get_register_tabular_function():
return reg


cdef get_register_aggregate_function():
cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
reg.register_func = RegisterAggregateFunction
return reg


def register_scalar_function(func, function_name, function_doc, in_types, out_type,
func_registry=None):
"""
Register a user-defined scalar function.

This API is EXPERIMENTAL.

A scalar function is a function that executes elementwise
operations on arrays or scalars, i.e. a scalar function must
be computed row-by-row with no state where each output row
@@ -2684,17 +2692,18 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
func : callable
A callable implementing the user-defined function.
The first argument is the context argument of type
ScalarUdfContext.
UdfContext.
Then, it must take arguments equal to the number of
in_types defined. It must return an Array or Scalar
matching the out_type. It must return a Scalar if
all arguments are scalar, else it must return an Array.

To define a varargs function, pass a callable that takes
varargs. The last in_type will be the type of all varargs
*args. The last in_type will be the type of all varargs
arguments.
function_name : str
Name of the function. This name must be globally unique.
Name of the function. There should only be one function
registered with this name in the function registry.
function_doc : dict
A dictionary object with keys "summary" (str),
and "description" (str).
@@ -2738,18 +2747,97 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
21
]
"""
return _register_scalar_like_function(get_register_scalar_function(),
func, function_name, function_doc, in_types,
out_type, func_registry)
return _register_user_defined_function(get_register_scalar_function(),
func, function_name, function_doc, in_types,
out_type, func_registry)


def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
func_registry=None):
"""
Register a user-defined non-decomposable aggregate function.

This API is EXPERIMENTAL.

A non-decomposable aggregate function is a function that executes
aggregate operations on the whole of the data that it is aggregating.
In other words, a non-decomposable aggregate function cannot be
split into consume/merge/finalize steps.

This is often used with ordered or segmented aggregation, where groups
can be emitted before accumulating all of the input data.

Parameters
----------
func : callable
A callable implementing the user-defined function.
The first argument is the context argument of type
UdfContext.
Then, it must take arguments equal to the number of
in_types defined. It must return a Scalar matching the
out_type.
To define a varargs function, pass a callable that takes
*args. The in_types need to match the types of the inputs when
the function gets called.
function_name : str
Name of the function. This name must be unique, i.e.,
there should only be one function registered with
this name in the function registry.
function_doc : dict
A dictionary object with keys "summary" (str),
and "description" (str).
in_types : Dict[str, DataType]
A dictionary mapping function argument names to
their respective DataType.
The argument names will be used to generate
documentation for the function. The number of
arguments specified here determines the function
arity.
out_type : DataType
Output type of the function.
func_registry : FunctionRegistry
Optional function registry to use instead of the default global one.

Examples
--------
>>> import numpy as np
>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>>
>>> func_doc = {}
>>> func_doc["summary"] = "simple median udf"
>>> func_doc["description"] = "compute median"
>>>
>>> def compute_median(ctx, array):
... return pa.scalar(np.median(array))
>>>
>>> func_name = "py_compute_median"
>>> in_types = {"array": pa.int64()}
>>> out_type = pa.float64()
>>> pc.register_aggregate_function(compute_median, func_name, func_doc,
... in_types, out_type)
>>>
>>> func = pc.get_function(func_name)
>>> func.name
'py_compute_median'
>>> answer = pc.call_function(func_name, [pa.array([20, 40])])
>>> answer
<pyarrow.DoubleScalar: 30.0>
"""
return _register_user_defined_function(get_register_aggregate_function(),
func, function_name, function_doc, in_types,
out_type, func_registry)


def register_tabular_function(func, function_name, function_doc, in_types, out_type,
func_registry=None):
"""
Register a user-defined tabular function.

This API is EXPERIMENTAL.

A tabular function is one accepting a context argument of type
ScalarUdfContext and returning a generator of struct arrays.
UdfContext and returning a generator of struct arrays.
The in_types argument must be empty and the out_type argument
specifies a schema. Each struct array must have field types
correspoding to the schema.
Expand All @@ -2759,11 +2847,12 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
func : callable
A callable implementing the user-defined function.
The only argument is the context argument of type
ScalarUdfContext. It must return a callable that
UdfContext. It must return a callable that
returns on each invocation a StructArray matching
the out_type, where an empty array indicates end.
function_name : str
Name of the function. This name must be globally unique.
Name of the function. There should only be one function
registered with this name in the function registry.
function_doc : dict
A dictionary object with keys "summary" (str),
and "description" (str).
@@ -2783,46 +2872,34 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
with nogil:
c_type = <shared_ptr[CDataType]>make_shared[CStructType](deref(c_schema).fields())
out_type = pyarrow_wrap_data_type(c_type)
return _register_scalar_like_function(get_register_tabular_function(),
func, function_name, function_doc, in_types,
out_type, func_registry)
return _register_user_defined_function(get_register_tabular_function(),
func, function_name, function_doc, in_types,
out_type, func_registry)


def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types,
out_type, func_registry=None):
def _register_user_defined_function(register_func, func, function_name, function_doc, in_types,
out_type, func_registry=None):
"""
Register a user-defined scalar-like function.
Register a user-defined function.

A scalar-like function is a callable accepting a first
context argument of type ScalarUdfContext as well as
possibly additional Arrow arguments, and returning a
an Arrow result appropriate for the kind of function.
A scalar function and a tabular function are examples
for scalar-like functions.
This function is normally not called directly but via
register_scalar_function or register_tabular_function.
This method itself doesn't care about the type of the UDF
(i.e., scalar vs tabular vs aggregate)

Parameters
----------
register_func: object
An object holding a CRegisterUdf in a "register_func" attribute. Use
get_register_scalar_function() for a scalar function and
get_register_tabular_function() for a tabular function.
An object holding a CRegisterUdf in a "register_func" attribute.
func : callable
A callable implementing the user-defined function.
See register_scalar_function and
register_tabular_function for details.

function_name : str
Name of the function. This name must be globally unique.
Name of the function. There should only be one function
registered with this name in the function registry.
function_doc : dict
A dictionary object with keys "summary" (str),
and "description" (str).
in_types : Dict[str, DataType]
A dictionary mapping function argument names to
their respective DataType.
See register_scalar_function and
register_tabular_function for details.
out_type : DataType
Output type of the function.
func_registry : FunctionRegistry
3 changes: 2 additions & 1 deletion python/pyarrow/compute.py
@@ -84,7 +84,8 @@
call_tabular_function,
register_scalar_function,
register_tabular_function,
ScalarUdfContext,
register_aggregate_function,
UdfContext,
# Expressions
Expression,
)
56 changes: 56 additions & 0 deletions python/pyarrow/conftest.py
@@ -278,3 +278,59 @@ def unary_function(ctx, x):
{"array": pa.int64()},
pa.int64())
return unary_function, func_name


@pytest.fixture(scope="session")
def unary_agg_func_fixture():
Contributor Author: This is used in both test_udf.py and test_substrait.py.

"""
Register a unary aggregate function
"""
from pyarrow import compute as pc
import numpy as np

def func(ctx, x):
return pa.scalar(np.nanmean(x))

func_name = "y=avg(x)"
func_doc = {"summary": "y=avg(x)",
"description": "find mean of x"}

pc.register_aggregate_function(func,
func_name,
func_doc,
{
"x": pa.float64(),
},
pa.float64()
)
return func, func_name


@pytest.fixture(scope="session")
def varargs_agg_func_fixture():
"""
Register a varargs aggregate function
"""
from pyarrow import compute as pc
import numpy as np

def func(ctx, *args):
    total = 0.0  # avoid shadowing the builtin sum
    for arg in args:
        total += np.nanmean(arg)
    return pa.scalar(total)

func_name = "y=sum_mean(x...)"
func_doc = {"summary": "Varargs aggregate",
"description": "Varargs aggregate"}

pc.register_aggregate_function(func,
func_name,
func_doc,
{
"x": pa.int64(),
"y": pa.float64()
Member (on lines +331 to +332): What are x and y if this is a varargs function?

Contributor Author (@icexelloss, Jun 1, 2023): Admittedly this is weird/confusing, but here is why:

This is not a truly "varargs" function, as this function must take two arguments x and y with the specified types.
This test case matches how we would use this internally. The end user would define something like

def foo(x: pd.Series, y: pd.Series):
    return np.nanmean(x) + np.nanmean(y)

summarize(table, agg=foo, columns=['x', 'y'], by='time')

We would then wrap foo into a function that Acero is expecting (a varargs UDF)

def get_acero_func(func):
    # This wrapper turns func into what Acero is expecting
    def acero_func(ctx, *args):
        return pa.scalar(func(*[arg.to_pandas() for arg in args]))

    return acero_func

And also register it in Acero on the fly when the user executes the expression.

In other words, this is a function with a known number of inputs (two in this case) but with a function signature that takes *args, where the args are always x and y.

Member: Ok, so the test case is verifying that the Python function can take in *args if needed (even though it still lists the args when registering)?

Contributor Author (@icexelloss, Jun 8, 2023): That is correct. It is testing a UDF that takes a fixed number of inputs but has a *args function signature.

},
pa.float64()
)
return func, func_name