diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index 5501889d7a20f..d89248383b722 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -954,7 +954,9 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
         return Status::Invalid("Expected aggregate call ", call.id().uri, "#",
                                call.id().name, " to have at least one argument");
       }
-      case 1: {
+      default: {
+        // Handles all arity > 0
+        std::shared_ptr<compute::FunctionOptions> options = nullptr;
         if (arrow_function_name == "stddev" || arrow_function_name == "variance") {
           // See the following URL for the spec of stddev and variance:
@@ -981,21 +983,22 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
         }
         fixed_arrow_func += arrow_function_name;
 
-        ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
-        const FieldRef* arg_ref = arg.field_ref();
-        if (!arg_ref) {
-          return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
-                                 call.id().name, " to have a direct reference");
+        std::vector<FieldRef> target;
+        for (int i = 0; i < call.size(); i++) {
+          ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i));
+          const FieldRef* arg_ref = arg.field_ref();
+          if (!arg_ref) {
+            return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
+                                   call.id().name, " to have a direct reference");
+          }
+          // Copy arg_ref here because field_ref() returns a const FieldRef*
+          target.emplace_back(*arg_ref);
         }
-        return compute::Aggregate{std::move(fixed_arrow_func),
-                                  options ? std::move(options) : nullptr, *arg_ref, ""};
+        return compute::Aggregate{std::move(fixed_arrow_func),
+                                  options ? std::move(options) : nullptr,
+                                  std::move(target), ""};
       }
-      default:
-        break;
     }
-    return Status::NotImplemented(
-        "Only nullary and unary aggregate functions are currently supported");
   };
 }
 
diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd
index 2dc0de2d0bfec..29b37da3ac4ef 100644
--- a/python/pyarrow/_compute.pxd
+++ b/python/pyarrow/_compute.pxd
@@ -21,11 +21,11 @@ from pyarrow.lib cimport *
 from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *
 
-cdef class ScalarUdfContext(_Weakrefable):
+cdef class UdfContext(_Weakrefable):
     cdef:
-        CScalarUdfContext c_context
+        CUdfContext c_context
 
-    cdef void init(self, const CScalarUdfContext& c_context)
+    cdef void init(self, const CUdfContext& c_context)
 
 
 cdef class FunctionOptions(_Weakrefable):
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index a5db5be551456..eaf9d1dfb65cb 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2559,7 +2559,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *:
         deref(pyarrow_unwrap_schema(schema).get())))
 
 
-cdef class ScalarUdfContext:
+cdef class UdfContext:
     """
     Per-invocation function context/state.
@@ -2571,7 +2571,7 @@ cdef class ScalarUdfContext: raise TypeError("Do not call {}'s constructor directly" .format(self.__class__.__name__)) - cdef void init(self, const CScalarUdfContext &c_context): + cdef void init(self, const CUdfContext &c_context): self.c_context = c_context @property @@ -2620,26 +2620,26 @@ cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *: return f_doc -cdef object box_scalar_udf_context(const CScalarUdfContext& c_context): - cdef ScalarUdfContext context = ScalarUdfContext.__new__(ScalarUdfContext) +cdef object box_udf_context(const CUdfContext& c_context): + cdef UdfContext context = UdfContext.__new__(UdfContext) context.init(c_context) return context -cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs): +cdef _udf_callback(user_function, const CUdfContext& c_context, inputs): """ - Helper callback function used to wrap the ScalarUdfContext from Python to C++ + Helper callback function used to wrap the UdfContext from Python to C++ execution. """ - context = box_scalar_udf_context(c_context) + context = box_udf_context(c_context) return user_function(context, *inputs) -def _get_scalar_udf_context(memory_pool, batch_length): - cdef CScalarUdfContext c_context +def _get_udf_context(memory_pool, batch_length): + cdef CUdfContext c_context c_context.pool = maybe_unbox_memory_pool(memory_pool) c_context.batch_length = batch_length - context = box_scalar_udf_context(c_context) + context = box_udf_context(c_context) return context @@ -2665,11 +2665,19 @@ cdef get_register_tabular_function(): return reg +cdef get_register_aggregate_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterAggregateFunction + return reg + + def register_scalar_function(func, function_name, function_doc, in_types, out_type, func_registry=None): """ Register a user-defined scalar function. + This API is EXPERIMENTAL. + A scalar function is a function that executes elementwise operations on arrays or scalars, i.e. a scalar function must be computed row-by-row with no state where each output row @@ -2684,17 +2692,18 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty func : callable A callable implementing the user-defined function. The first argument is the context argument of type - ScalarUdfContext. + UdfContext. Then, it must take arguments equal to the number of in_types defined. It must return an Array or Scalar matching the out_type. It must return a Scalar if all arguments are scalar, else it must return an Array. To define a varargs function, pass a callable that takes - varargs. The last in_type will be the type of all varargs + *args. The last in_type will be the type of all varargs arguments. function_name : str - Name of the function. This name must be globally unique. + Name of the function. There should only be one function + registered with this name in the function registry. function_doc : dict A dictionary object with keys "summary" (str), and "description" (str). 
@@ -2738,9 +2747,86 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
       21
     ]
     """
-    return _register_scalar_like_function(get_register_scalar_function(),
-                                          func, function_name, function_doc, in_types,
-                                          out_type, func_registry)
+    return _register_user_defined_function(get_register_scalar_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
+
+
+def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
+                                func_registry=None):
+    """
+    Register a user-defined non-decomposable aggregate function.
+
+    This API is EXPERIMENTAL.
+
+    A non-decomposable aggregation function is a function that executes
+    aggregate operations on the whole data that it is aggregating.
+    In other words, a non-decomposable aggregate function cannot be
+    split into consume/merge/finalize steps.
+
+    This is often used with ordered or segmented aggregation where groups
+    can be emitted before accumulating all of the input data.
+
+    Parameters
+    ----------
+    func : callable
+        A callable implementing the user-defined function.
+        The first argument is the context argument of type
+        UdfContext.
+        Then, it must take arguments equal to the number of
+        in_types defined. It must return a Scalar matching the
+        out_type.
+        To define a varargs function, pass a callable that takes
+        *args. The in_types need to match the types of the inputs
+        when the function is called.
+    function_name : str
+        Name of the function. This name must be unique, i.e.,
+        there should only be one function registered with
+        this name in the function registry.
+    function_doc : dict
+        A dictionary object with keys "summary" (str),
+        and "description" (str).
+    in_types : Dict[str, DataType]
+        A dictionary mapping function argument names to
+        their respective DataType.
+        The argument names will be used to generate
+        documentation for the function. The number of
+        arguments specified here determines the function
+        arity.
+    out_type : DataType
+        Output type of the function.
+    func_registry : FunctionRegistry
+        Optional function registry to use instead of the default global one.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> import pyarrow as pa
+    >>> import pyarrow.compute as pc
+    >>>
+    >>> func_doc = {}
+    >>> func_doc["summary"] = "simple median udf"
+    >>> func_doc["description"] = "compute median"
+    >>>
+    >>> def compute_median(ctx, array):
+    ...     return pa.scalar(np.median(array))
+    >>>
+    >>> func_name = "py_compute_median"
+    >>> in_types = {"array": pa.int64()}
+    >>> out_type = pa.float64()
+    >>> pc.register_aggregate_function(compute_median, func_name, func_doc,
+    ...                                in_types, out_type)
+    >>>
+    >>> func = pc.get_function(func_name)
+    >>> func.name
+    'py_compute_median'
+    >>> answer = pc.call_function(func_name, [pa.array([20, 40])])
+    >>> answer
+    <pyarrow.DoubleScalar: 30.0>
+    """
+    return _register_user_defined_function(get_register_aggregate_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
 
 
 def register_tabular_function(func, function_name, function_doc, in_types, out_type,
@@ -2748,8 +2834,10 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
     """
     Register a user-defined tabular function.
 
+    This API is EXPERIMENTAL.
+
     A tabular function is one accepting a context argument of type
-    ScalarUdfContext and returning a generator of struct arrays.
+    UdfContext and returning a generator of struct arrays.
     The in_types argument must be empty and the out_type argument
     specifies a schema.
 Each struct array must have field types corresponding to the schema.
@@ -2759,11 +2847,12 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
     func : callable
         A callable implementing the user-defined function.
         The only argument is the context argument of type
-        ScalarUdfContext. It must return a callable that
+        UdfContext. It must return a callable that
        returns on each invocation a StructArray matching
        the out_type, where an empty array indicates end.
    function_name : str
-        Name of the function. This name must be globally unique.
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
    function_doc : dict
        A dictionary object with keys "summary" (str),
        and "description" (str).
@@ -2783,46 +2872,34 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
     with nogil:
         c_type = make_shared[CStructType](deref(c_schema).fields())
     out_type = pyarrow_wrap_data_type(c_type)
-    return _register_scalar_like_function(get_register_tabular_function(),
-                                          func, function_name, function_doc, in_types,
-                                          out_type, func_registry)
+    return _register_user_defined_function(get_register_tabular_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
 
 
-def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types,
-                                   out_type, func_registry=None):
+def _register_user_defined_function(register_func, func, function_name, function_doc, in_types,
+                                    out_type, func_registry=None):
     """
-    Register a user-defined scalar-like function.
+    Register a user-defined function.
 
-    A scalar-like function is a callable accepting a first
-    context argument of type ScalarUdfContext as well as
-    possibly additional Arrow arguments, and returning a
-    an Arrow result appropriate for the kind of function.
-    A scalar function and a tabular function are examples
-    for scalar-like functions.
-    This function is normally not called directly but via
-    register_scalar_function or register_tabular_function.
+    This method itself doesn't care about the type of the UDF
+    (i.e., scalar vs tabular vs aggregate).
 
     Parameters
     ----------
     register_func: object
-        An object holding a CRegisterUdf in a "register_func" attribute. Use
-        get_register_scalar_function() for a scalar function and
-        get_register_tabular_function() for a tabular function.
+        An object holding a CRegisterUdf in a "register_func" attribute.
     func : callable
         A callable implementing the user-defined function.
-        See register_scalar_function and
-        register_tabular_function for details.
     function_name : str
-        Name of the function. This name must be globally unique.
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
     function_doc : dict
         A dictionary object with keys "summary" (str),
         and "description" (str).
     in_types : Dict[str, DataType]
         A dictionary mapping function argument names to
         their respective DataType.
-        See register_scalar_function and
-        register_tabular_function for details.
     out_type : DataType
         Output type of the function.
     func_registry : FunctionRegistry
         Optional function registry to use instead of the default global one.
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index e299d44c04e16..e92f09354771f 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -84,7 +84,8 @@
     call_tabular_function,
     register_scalar_function,
     register_tabular_function,
-    ScalarUdfContext,
+    register_aggregate_function,
+    UdfContext,
     # Expressions
     Expression,
 )
diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py
index ef09393cfbd6a..f32cbf01efcd6 100644
--- a/python/pyarrow/conftest.py
+++ b/python/pyarrow/conftest.py
@@ -278,3 +278,59 @@ def unary_function(ctx, x):
                                 {"array": pa.int64()},
                                 pa.int64())
     return unary_function, func_name
+
+
+@pytest.fixture(scope="session")
+def unary_agg_func_fixture():
+    """
+    Register a unary aggregate function
+    """
+    from pyarrow import compute as pc
+    import numpy as np
+
+    def func(ctx, x):
+        return pa.scalar(np.nanmean(x))
+
+    func_name = "y=avg(x)"
+    func_doc = {"summary": "y=avg(x)",
+                "description": "find mean of x"}
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.float64(),
+                                   },
+                                   pa.float64()
+                                   )
+    return func, func_name
+
+
+@pytest.fixture(scope="session")
+def varargs_agg_func_fixture():
+    """
+    Register a varargs aggregate function
+    """
+    from pyarrow import compute as pc
+    import numpy as np
+
+    def func(ctx, *args):
+        total = 0.0
+        for arg in args:
+            total += np.nanmean(arg)
+        return pa.scalar(total)
+
+    func_name = "y=sum_mean(x...)"
+    func_doc = {"summary": "Varargs aggregate",
+                "description": "Varargs aggregate"}
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                       "y": pa.float64()
+                                   },
+                                   pa.float64()
+                                   )
+    return func, func_name
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 3190877ea0997..86f21f4b528e8 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2775,7 +2775,7 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil:
     int64_t TotalBufferSize(const CRecordBatch& record_batch)
     int64_t TotalBufferSize(const CTable& table)
 
-ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs)
+ctypedef PyObject* CallbackUdf(object user_function, const CUdfContext& context, object inputs)
 
 
 cdef extern from "arrow/api.h" namespace "arrow" nogil:
@@ -2786,7 +2786,7 @@
 
 cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
 
-    cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext":
+    cdef cppclass CUdfContext" arrow::py::UdfContext":
         CMemoryPool *pool
         int64_t batch_length
 
@@ -2805,5 +2805,9 @@
         function[CallbackUdf] wrapper, const CUdfOptions& options,
         CFunctionRegistry* registry)
 
+    CStatus RegisterAggregateFunction(PyObject* function,
+        function[CallbackUdf] wrapper, const CUdfOptions& options,
+        CFunctionRegistry* registry)
+
     CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
         const c_string& func_name, const vector[CDatum]& args,
         CFunctionRegistry* registry)
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 7d63adb8352e8..06c116af820db 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -16,14 +16,16 @@
 // under the License.
#include "arrow/python/udf.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/function.h" #include "arrow/compute/kernel.h" #include "arrow/python/common.h" +#include "arrow/table.h" #include "arrow/util/checked_cast.h" namespace arrow { +using internal::checked_cast; namespace py { - namespace { struct PythonUdfKernelState : public compute::KernelState { @@ -65,6 +67,26 @@ struct PythonUdfKernelInit { std::shared_ptr function; }; +struct ScalarUdfAggregator : public compute::KernelState { + virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) = 0; + virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) = 0; + virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0; +}; + +arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, + const compute::ExecSpan& batch) { + return checked_cast(ctx->state())->Consume(ctx, batch); +} + +arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, + compute::KernelState* dst) { + return checked_cast(dst)->MergeFrom(ctx, std::move(src)); +} + +arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { + return checked_cast(ctx->state())->Finalize(ctx, out); +} + struct PythonTableUdfKernelInit { PythonTableUdfKernelInit(std::shared_ptr function_maker, UdfWrapperCallback cb) @@ -82,12 +104,12 @@ struct PythonTableUdfKernelInit { Result> operator()( compute::KernelContext* ctx, const compute::KernelInitArgs&) { - ScalarUdfContext scalar_udf_context{ctx->memory_pool(), /*batch_length=*/0}; + UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; std::unique_ptr function; - RETURN_NOT_OK(SafeCallIntoPython([this, &scalar_udf_context, &function] { + RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] { OwnedRef empty_tuple(PyTuple_New(0)); function = std::make_unique( - cb(function_maker->obj(), scalar_udf_context, empty_tuple.obj())); + cb(function_maker->obj(), udf_context, empty_tuple.obj())); RETURN_NOT_OK(CheckPyError()); return Status::OK(); })); @@ -101,6 +123,105 @@ struct PythonTableUdfKernelInit { UdfWrapperCallback cb; }; +struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { + PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb, + std::shared_ptr agg_function, + std::vector> input_types, + std::shared_ptr output_type) + : agg_cb(std::move(agg_cb)), + agg_function(agg_function), + output_type(std::move(output_type)) { + Py_INCREF(agg_function->obj()); + std::vector> fields; + for (size_t i = 0; i < input_types.size(); i++) { + fields.push_back(field("", input_types[i])); + } + input_schema = schema(std::move(fields)); + }; + + ~PythonUdfScalarAggregatorImpl() override { + if (_Py_IsFinalizing()) { + agg_function->detach(); + } + } + + Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override { + ARROW_ASSIGN_OR_RAISE( + auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); + values.push_back(std::move(rb)); + return Status::OK(); + } + + Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { + auto& other_values = checked_cast(src).values; + values.insert(values.end(), std::make_move_iterator(other_values.begin()), + std::make_move_iterator(other_values.end())); + + other_values.erase(other_values.begin(), other_values.end()); + return Status::OK(); + } + + Status Finalize(compute::KernelContext* ctx, Datum* out) override { + auto state = + 
arrow::internal::checked_cast(ctx->state()); + std::shared_ptr& function = state->agg_function; + const int num_args = input_schema->num_fields(); + + // Note: The way that batches are concatenated together + // would result in using double amount of the memory. + // This is OK for now because non decomposable aggregate + // UDF is supposed to be used with segmented aggregation + // where the size of the segment is more or less constant + // so doubling that is not a big deal. This can be also + // improved in the future to use more efficient way to + // concatenate. + ARROW_ASSIGN_OR_RAISE(auto table, + arrow::Table::FromRecordBatches(input_schema, values)); + ARROW_ASSIGN_OR_RAISE(table, table->CombineChunks(ctx->memory_pool())); + UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; + + if (table->num_rows() == 0) { + return Status::Invalid("Finalized is called with empty inputs"); + } + + RETURN_NOT_OK(SafeCallIntoPython([&] { + std::unique_ptr result; + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); + + for (int arg_id = 0; arg_id < num_args; arg_id++) { + // Since we combined chunks there is only one chunk + std::shared_ptr c_data = table->column(arg_id)->chunk(0); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + result = std::make_unique( + agg_cb(function->obj(), udf_context, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + // unwrapping the output for expected output type + if (is_scalar(result->obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); + if (*output_type != *val->type) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type->ToString()); + } + out->value = std::move(val); + return Status::OK(); + } + return Status::TypeError("Unexpected output type: ", + Py_TYPE(result->obj())->tp_name, " (expected Scalar)"); + })); + return Status::OK(); + } + + UdfWrapperCallback agg_cb; + std::vector> values; + std::shared_ptr agg_function; + std::shared_ptr input_schema; + std::shared_ptr output_type; +}; + struct PythonUdf : public PythonUdfKernelState { PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, std::vector input_types, compute::OutputType output_type) @@ -130,7 +251,7 @@ struct PythonUdf : public PythonUdfKernelState { auto state = arrow::internal::checked_cast(ctx->state()); std::shared_ptr& function = state->function; const int num_args = batch.num_values(); - ScalarUdfContext scalar_udf_context{ctx->memory_pool(), batch.length}; + UdfContext udf_context{ctx->memory_pool(), batch.length}; OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); @@ -146,7 +267,7 @@ struct PythonUdf : public PythonUdfKernelState { } } - OwnedRef result(cb(function->obj(), scalar_udf_context, arg_tuple.obj())); + OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type if (is_array(result.obj())) { @@ -234,6 +355,61 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp wrapper, options, registry); } +Status AddAggKernel(std::shared_ptr sig, + compute::KernelInit init, compute::ScalarAggregateFunction* func) { + compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), + AggregateUdfConsume, AggregateUdfMerge, + AggregateUdfFinalize, /*ordered=*/false); + RETURN_NOT_OK(func->AddKernel(std::move(kernel))); + return Status::OK(); +} + 
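To make the kernel lifecycle wired up above concrete: Consume only buffers incoming batches, MergeFrom splices buffers from other kernel states, and the registered Python callable runs exactly once per group, inside Finalize, against the fully concatenated input. A minimal sketch of that contract from the Python side — the function and registration names here are illustrative, not part of this patch:

```python
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

def segment_mean(ctx, x):
    # Finalize combines all buffered batches first, so `x` arrives as one
    # contiguous pyarrow.Array and ctx.batch_length is its total row count.
    assert len(x) == ctx.batch_length
    return pa.scalar(float(np.mean(x.to_numpy())))

pc.register_aggregate_function(
    segment_mean, "example_segment_mean",  # hypothetical name
    {"summary": "mean", "description": "mean over the whole segment"},
    {"x": pa.float64()}, pa.float64())

print(pc.call_function("example_segment_mean", [pa.array([1.0, 2.0, 3.0])]))
# -> 2.0
```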
+Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper,
+                                 const UdfOptions& options,
+                                 compute::FunctionRegistry* registry) {
+  if (!PyCallable_Check(agg_function)) {
+    return Status::TypeError("Expected a callable Python object.");
+  }
+
+  if (registry == NULLPTR) {
+    registry = compute::GetFunctionRegistry();
+  }
+
+  // Py_INCREF here so that once a function is registered
+  // its refcount gets increased by 1 and it doesn't get
+  // garbage-collected if all existing references are gone
+  Py_INCREF(agg_function);
+
+  static auto default_scalar_aggregate_options =
+      compute::ScalarAggregateOptions::Defaults();
+  auto aggregate_func = std::make_shared<compute::ScalarAggregateFunction>(
+      options.func_name, options.arity, options.func_doc,
+      &default_scalar_aggregate_options);
+
+  std::vector<compute::InputType> input_types;
+  for (const auto& in_dtype : options.input_types) {
+    input_types.emplace_back(in_dtype);
+  }
+  compute::OutputType output_type(options.output_type);
+
+  compute::KernelInit init = [agg_wrapper, agg_function, options](
+                                 compute::KernelContext* ctx,
+                                 const compute::KernelInitArgs& args)
+      -> Result<std::unique_ptr<compute::KernelState>> {
+    return std::make_unique<PythonUdfScalarAggregatorImpl>(
+        agg_wrapper, std::make_shared<OwnedRefNoGIL>(agg_function), options.input_types,
+        options.output_type);
+  };
+
+  RETURN_NOT_OK(AddAggKernel(
+      compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
+                                     options.arity.is_varargs),
+      init, aggregate_func.get()));
+
+  RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func)));
+  return Status::OK();
+}
+
 Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction(
     const std::string& func_name, const std::vector<Datum>& args,
     compute::FunctionRegistry* registry) {
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index b3dcc9ccf44e9..682cbb2ffe8d5 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -43,14 +43,14 @@ struct ARROW_PYTHON_EXPORT UdfOptions {
   std::shared_ptr<DataType> output_type;
 };
 
-/// \brief A context passed as the first argument of scalar UDF functions.
-struct ARROW_PYTHON_EXPORT ScalarUdfContext {
+/// \brief A context passed as the first argument of UDF functions.
+struct ARROW_PYTHON_EXPORT UdfContext {
   MemoryPool* pool;
   int64_t batch_length;
 };
 
 using UdfWrapperCallback = std::function<PyObject*(
-    PyObject* user_function, const ScalarUdfContext& context, PyObject* inputs)>;
+    PyObject* user_function, const UdfContext& context, PyObject* inputs)>;
 
 /// \brief register a Scalar user-defined-function from Python
 Status ARROW_PYTHON_EXPORT RegisterScalarFunction(
@@ -62,6 +62,11 @@ Status ARROW_PYTHON_EXPORT RegisterTabularFunction(
     PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
     compute::FunctionRegistry* registry = NULLPTR);
 
+/// \brief register an aggregate user-defined-function from Python
+Status ARROW_PYTHON_EXPORT RegisterAggregateFunction(
+    PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
+    compute::FunctionRegistry* registry = NULLPTR);
+
 Result<std::shared_ptr<RecordBatchReader>> ARROW_PYTHON_EXPORT
 CallTabularFunction(const std::string& func_name, const std::vector<Datum>& args,
                     compute::FunctionRegistry* registry = NULLPTR);
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
index d0da517ea7f12..34faaa157af4d 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -34,9 +34,9 @@
 pytestmark = [pytest.mark.dataset, pytest.mark.substrait]
 
 
-def mock_scalar_udf_context(batch_length=10):
-    from pyarrow._compute import _get_scalar_udf_context
-    return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)
+def mock_udf_context(batch_length=10):
+    from pyarrow._compute import _get_udf_context
+    return _get_udf_context(pa.default_memory_pool(), batch_length)
 
 
 def _write_dummy_data_to_disk(tmpdir, file_name, table):
@@ -442,7 +442,7 @@ def table_provider(names, _):
 
     function, name = unary_func_fixture
     expected_tb = test_table.add_column(1, 'y', function(
-        mock_scalar_udf_context(10), test_table['x']))
+        mock_udf_context(10), test_table['x']))
     assert res_tb == expected_tb
 
 
@@ -605,3 +605,151 @@ def table_provider(names, schema):
 
     expected = pa.Table.from_pydict({"out": [1, 2, 3]})
     assert res_tb == expected
+
+
+def test_aggregate_udf_basic(varargs_agg_func_fixture):
+
+    test_table = pa.Table.from_pydict(
+        {"k": [1, 1, 2, 2], "v1": [1, 2, 3, 4],
+         "v2": [1.0, 1.0, 1.0, 1.0]}
+    )
+
+    def table_provider(names, _):
+        return test_table
+
+    substrait_query = b"""
+{
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1,
+      "uri": "urn:arrow:substrait_simple_extension_function"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 1,
+        "functionAnchor": 1,
+        "name": "y=sum_mean(x...)"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "extensionSingle": {
+            "common": {
+              "emit": {
+                "outputMapping": [
+                  0,
+                  1
+                ]
+              }
+            },
+            "input": {
+              "read": {
+                "baseSchema": {
+                  "names": [
+                    "k",
+                    "v1",
+                    "v2"
+                  ],
+                  "struct": {
+                    "types": [
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_NULLABLE"
+                        }
+                      },
+                      {
+                        "fp64": {
+                          "nullability": "NULLABILITY_NULLABLE"
+                        }
+                      }
+                    ],
+                    "nullability": "NULLABILITY_REQUIRED"
+                  }
+                },
+                "namedTable": {
+                  "names": ["t1"]
+                }
+              }
+            },
+            "detail": {
+              "@type": "/arrow.substrait_ext.SegmentedAggregateRel",
+              "segmentKeys": [
+                {
+                  "directReference": {
+                    "structField": {}
+                  },
+                  "rootReference": {}
+                }
+              ],
+              "measures": [
+                {
+                  "measure": {
+                    "functionReference": 1,
+                    "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+                    "outputType": {
+                      "fp64": {
+                        "nullability": "NULLABILITY_NULLABLE"
+                      }
+                    },
+                    "arguments": [
+                      {
+                        "value": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 1
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        }
+                      },
+                      {
+                        "value": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 2
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        }
+                      }
+                    ]
+                  }
+                }
+              ]
+            }
+          }
+        },
+        "names": [
+          "k",
+          "v_avg"
+        ]
+      }
+    }
+  ]
+}
+"""
+    buf = pa._substrait._parse_json_plan(substrait_query)
+    reader = pa.substrait.run_query(
+        buf, table_provider=table_provider, use_threads=False)
+    res_tb = reader.read_all()
+
+    expected_tb = pa.Table.from_pydict({
+        'k': [1, 2],
+        'v_avg': [2.5, 4.5]
+    })
+
+    assert res_tb == expected_tb
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index 0f336555f7647..c0cfd3d26e800 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -24,21 +24,82 @@
 # UDFs are all tested with a dataset scan
 pytestmark = pytest.mark.dataset
 
+# For convenience, most of the tests here don't care about UDF function docs
+empty_udf_doc = {"summary": "", "description": ""}
+
 try:
     import pyarrow.dataset as ds
 except ImportError:
     ds = None
 
 
-def mock_scalar_udf_context(batch_length=10):
-    from pyarrow._compute import _get_scalar_udf_context
-    return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)
+def mock_udf_context(batch_length=10):
+    from pyarrow._compute import _get_udf_context
+    return _get_udf_context(pa.default_memory_pool(), batch_length)
 
 
 class MyError(RuntimeError):
     pass
 
 
+@pytest.fixture(scope="session")
+def exception_agg_func_fixture():
+    def func(ctx, x):
+        raise RuntimeError("Oops")
+        return pa.scalar(len(x))
+
+    func_name = "y=exception_len(x)"
+    func_doc = empty_udf_doc
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                   },
+                                   pa.int64()
+                                   )
+    return func, func_name
+
+
+@pytest.fixture(scope="session")
+def wrong_output_dtype_agg_func_fixture():
+    def func(ctx, x):
+        return pa.scalar(len(x), pa.int32())
+
+    func_name = "y=wrong_output_dtype(x)"
+    func_doc = empty_udf_doc
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                   },
+                                   pa.int64()
+                                   )
+    return func, func_name
+
+
+@pytest.fixture(scope="session")
+def wrong_output_type_agg_func_fixture():
+    def func(ctx, x):
+        return len(x)
+
+    func_name = "y=wrong_output_type(x)"
+    func_doc = empty_udf_doc
+
+    pc.register_aggregate_function(func,
+                                   func_name,
+                                   func_doc,
+                                   {
+                                       "x": pa.int64(),
+                                   },
+                                   pa.int64()
+                                   )
+    return func, func_name
+
+
 @pytest.fixture(scope="session")
 def binary_func_fixture():
     """
@@ -228,11 +289,11 @@ def check_scalar_function(func_fixture,
     if all_scalar:
         batch_length = 1
 
-    expected_output = function(mock_scalar_udf_context(batch_length), *inputs)
     func = pc.get_function(name)
     assert func.name == name
 
     result = pc.call_function(name, inputs, length=batch_length)
+    expected_output = function(mock_udf_context(batch_length), *inputs)
     assert result == expected_output
     # At the moment there is an issue when handling nullary functions.
     # See: ARROW-15286 and ARROW-16290.
@@ -593,3 +654,47 @@ def test_udt_datasource1_generator(): def test_udt_datasource1_exception(): with pytest.raises(RuntimeError, match='datasource1_exception'): _test_datasource1_udt(datasource1_exception) + + +def test_agg_basic(unary_agg_func_fixture): + arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) + result = pc.call_function("y=avg(x)", [arr]) + expected = pa.scalar(30.0) + assert result == expected + + +def test_agg_empty(unary_agg_func_fixture): + empty = pa.array([], pa.float64()) + + with pytest.raises(pa.ArrowInvalid, match='empty inputs'): + pc.call_function("y=avg(x)", [empty]) + + +def test_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50], pa.int64()) + with pytest.raises(pa.ArrowTypeError, match="output datatype"): + pc.call_function("y=wrong_output_dtype(x)", [arr]) + + +def test_agg_wrong_output_type(wrong_output_type_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50], pa.int64()) + with pytest.raises(pa.ArrowTypeError, match="output type"): + pc.call_function("y=wrong_output_type(x)", [arr]) + + +def test_agg_varargs(varargs_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64()) + + result = pc.call_function( + "y=sum_mean(x...)", [arr1, arr2] + ) + expected = pa.scalar(33.0) + assert result == expected + + +def test_agg_exception(exception_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) + + with pytest.raises(RuntimeError, match='Oops'): + pc.call_function("y=exception_len(x)", [arr])
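For reference, the grouped result asserted in test_aggregate_udf_basic can be reproduced segment by segment with direct calls, since a non-decomposable UDF sees all rows of a segment at once. The sketch below mirrors varargs_agg_func_fixture under an illustrative registration name:

```python
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

def sum_mean(ctx, *args):
    # Mirrors varargs_agg_func_fixture: sum of the per-argument means.
    return pa.scalar(float(sum(np.nanmean(a) for a in args)))

pc.register_aggregate_function(
    sum_mean, "demo_sum_mean",  # illustrative name
    {"summary": "varargs demo", "description": "sum of per-column means"},
    {"x": pa.int64(), "y": pa.float64()}, pa.float64())

# Segment k == 1 of the test table above: v1 = [1, 2], v2 = [1.0, 1.0]
print(pc.call_function("demo_sum_mean",
                       [pa.array([1, 2]), pa.array([1.0, 1.0])]))
# -> 2.5, the first `v_avg` row asserted in test_aggregate_udf_basic
```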
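One behavioral consequence of the Py_INCREF in RegisterAggregateFunction is also worth noting: the registry keeps the callable alive even after the defining scope drops its reference. A sketch:

```python
import pyarrow as pa
import pyarrow.compute as pc

def transient_count(ctx, x):
    return pa.scalar(float(len(x)))

pc.register_aggregate_function(
    transient_count, "demo_count",  # illustrative name
    {"summary": "count", "description": "row count as float"},
    {"x": pa.int64()}, pa.float64())

del transient_count  # the registry still holds a strong reference
print(pc.call_function("demo_count", [pa.array([1, 2, 3])]))
# -> 3.0
```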