Add function state
liujiayi771 committed Mar 20, 2024
1 parent ca0f4a9 commit c7915ec
Showing 13 changed files with 261 additions and 80 deletions.
116 changes: 80 additions & 36 deletions velox/docs/develop/aggregate-functions.rst
@@ -152,13 +152,29 @@ A simple aggregation function is implemented as a class as the following.
using IntermediateType = Array<Generic<T1>>;
using OutputType = Array<Generic<T1>>;

// If UDAF does not require the use of FunctionState, it is necessary
// to declare an empty FunctionState struct.
struct FunctionState {
// Optional.
TypePtr resultType;
};

// Optional. Used only when the UDAF needs to use FunctionState.
static void initialize(
FunctionState& state,
const std::vector<TypePtr>& rawInputTypes,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
state.resultType = resultType;
}

// Optional. Default is true.
static constexpr bool default_null_behavior_ = false;

// Optional.
static bool toIntermediate(
exec::out_type<Array<Generic<T1>>>& out,
exec::optional_arg_type<Generic<T1>> in);
exec::out_type<Array<Generic<T1>>>& out,
exec::optional_arg_type<Generic<T1>> in);

struct AccumulatorType { ... };
};
@@ -169,6 +185,15 @@ function's argument type(s) wrapped in a Row<> even if the function only takes
one argument. This is needed for the SimpleAggregateAdapter to parse input
types for arbitrary aggregation functions properly.

A FunctionState struct must be declared in the simple aggregation function
class. It holds function-level variables that are typically computed once and
used at every row when adding inputs to accumulators or extracting values from
accumulators. For example, if the UDAF needs the result type or the raw input
types of the aggregation function, the author can store them in the
FunctionState struct and set them in the initialize() method. If the UDAF does
not need FunctionState, it must still declare an empty FunctionState struct.
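
As a rough illustration of the pattern described above, the following self-contained sketch shows a FunctionState that is filled once by initialize(). The `Type` and `TypePtr` definitions and the class names `ExampleAggregate` and `StatelessAggregate` are hypothetical stand-ins so the snippet compiles without Velox, and the `constantInputs` parameter of the real method is left out.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for Velox's Type and TypePtr, for illustration only.
struct Type {
  std::string name;
};
using TypePtr = std::shared_ptr<const Type>;

// Sketch of a UDAF class that keeps function-level data in FunctionState.
struct ExampleAggregate {
  struct FunctionState {
    std::vector<TypePtr> rawInputTypes;
    TypePtr resultType;
  };

  // Intended to run once per function instance, before rows are processed.
  // Assumption: the constantInputs parameter is omitted in this sketch.
  static void initialize(
      FunctionState& state,
      const std::vector<TypePtr>& rawInputTypes,
      const TypePtr& resultType) {
    state.rawInputTypes = rawInputTypes;
    state.resultType = resultType;
  }
};

// A UDAF with no function-level data still declares an empty FunctionState.
struct StatelessAggregate {
  struct FunctionState {};
};
```

After initialize() returns, the same state object can be passed by const reference to every accumulator call, so the types are looked up once rather than per row.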

The author can define an optional flag `default_null_behavior_` indicating
whether the aggregation function has default-null behavior. This flag is true
by default. Next, the class can have an optional method `toIntermediate()`
@@ -257,17 +282,21 @@ For aggregation functions of default-null behavior, the author defines an
// Optional. Default is false.
static constexpr bool aligned_accumulator_ = true;

explicit AccumulatorType(HashStringAllocator* allocator);
explicit AccumulatorType(HashStringAllocator* allocator, const FunctionState& state);

void addInput(HashStringAllocator* allocator, exec::arg_type<T1> value1, ...);
void addInput(
HashStringAllocator* allocator,
exec::arg_type<T1> value1, ...,
const FunctionState& state);

void combine(
HashStringAllocator* allocator,
exec::arg_type<IntermediateType> other);
exec::arg_type<IntermediateType> other,
const FunctionState& state);

bool writeIntermediateResult(exec::out_type<IntermediateType>& out);
bool writeIntermediateResult(exec::out_type<IntermediateType>& out, const FunctionState& state);

bool writeFinalResult(exec::out_type<OutputType>& out);
bool writeFinalResult(exec::out_type<OutputType>& out, const FunctionState& state);

// Optional. Called during destruction.
void destroy(HashStringAllocator* allocator);
@@ -296,7 +325,8 @@ addInput

This method adds raw input values to *this* accumulator. It receives a
`HashStringAllocator*` followed by `exec::arg_type<T1>`-typed values, one for
each argument type `Ti` wrapped in InputType.
each argument type `Ti` wrapped in InputType. The trailing
`const FunctionState&` argument holds the function-level variables.

With default-null behavior, raw-input rows where at least one column is null are
ignored before `addInput` is called. After `addInput` is called, *this*
@@ -306,31 +336,32 @@ combine
"""""""

This method adds an input intermediate state to *this* accumulator. It receives
a `HashStringAllocator*` and one `exec::arg_type<IntermediateType>` value. With
default-null behavior, nulls among the input intermediate states are ignored
before `combine` is called. After `combine` is called, *this* accumulator is
assumed to be non-null.
a `HashStringAllocator*` and one `exec::arg_type<IntermediateType>` value.
The trailing `const FunctionState&` argument holds the function-level
variables. With default-null behavior, nulls among the input intermediate
states are ignored before `combine` is called. After `combine` is called,
*this* accumulator is assumed to be non-null.

writeIntermediateResult
"""""""""""""""""""""""

This method writes *this* accumulator out to an intermediate state vector. It
has an out-parameter of the type `exec::out_type<IntermediateType>&`. This
method returns true if it writes a non-null value to `out`, or returns false
meaning a null should be written to the intermediate state vector. Accumulators
that are nulls (i.e., no value has been added to them) automatically become
nulls in the intermediate state vector without `writeIntermediateResult` being
called.
has an out-parameter of the type `exec::out_type<IntermediateType>&`.
The trailing `const FunctionState&` argument holds the function-level
variables. This method returns true if it writes a non-null value to `out`, or
returns false meaning a null should be written to the intermediate state
vector. Accumulators that are nulls (i.e., no value has been added to them)
automatically become nulls in the intermediate state vector without
`writeIntermediateResult` being called.

writeFinalResult
""""""""""""""""

This method writes *this* accumulator out to a final result vector. It
has an out-parameter of the type `exec::out_type<OutputType>&`. This
method returns true if it writes a non-null value to `out`, or returns false
meaning a null should be written to the final result vector. Accumulators
that are nulls (i.e., no value has been added to them) automatically become
nulls in the final result vector without `writeFinalResult` being called.
has an out-parameter of the type `exec::out_type<OutputType>&`.
The trailing `const FunctionState&` argument holds the function-level
variables. This method returns true if it writes a non-null value to `out`, or
returns false meaning a null should be written to the final result vector.
Accumulators that are nulls (i.e., no value has been added to them)
automatically become nulls in the final result vector without
`writeFinalResult` being called.
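
The four default-null methods described above can be sketched with plain types as follows. This is only an approximation of the calling pattern: the `HashStringAllocator*` parameter and the `exec::arg_type` / `exec::out_type` wrappers are replaced with `int64_t`, and the `scale` field and the name `SumAccumulator` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical function-level state: a scale factor computed once.
struct FunctionState {
  int64_t scale{1};
};

// Sketch of a default-null accumulator that sums scaled inputs.
// Assumption: Velox-specific parameter types are replaced with int64_t.
struct SumAccumulator {
  int64_t sum{0};

  explicit SumAccumulator(const FunctionState& /*state*/) {}

  // With default-null behavior, null rows are skipped before this call.
  void addInput(int64_t value, const FunctionState& state) {
    sum += value * state.scale;
  }

  // Merges an intermediate value that is already scaled.
  void combine(int64_t other, const FunctionState& /*state*/) {
    sum += other;
  }

  // Returns true because a non-null value is always written to out.
  bool writeIntermediateResult(
      int64_t& out,
      const FunctionState& /*state*/) const {
    out = sum;
    return true;
  }

  bool writeFinalResult(int64_t& out, const FunctionState& /*state*/) const {
    out = sum;
    return true;
  }
};
```

Passing the state by const reference to each call keeps the accumulator itself small, since the shared value is stored once per function rather than once per group.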

AccumulatorType of Non-Default-Null Behavior
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -355,15 +386,25 @@ For aggregation functions of non-default-null behavior, the author defines an

explicit AccumulatorType(HashStringAllocator* allocator);

bool addInput(HashStringAllocator* allocator, exec::optional_arg_type<T1> value1, ...);
bool addInput(
HashStringAllocator* allocator,
exec::optional_arg_type<T1> value1, ...,
const FunctionState& state);

bool combine(
HashStringAllocator* allocator,
exec::optional_arg_type<IntermediateType> other);
exec::optional_arg_type<IntermediateType> other,
const FunctionState& state);

bool writeIntermediateResult(bool nonNullGroup, exec::out_type<IntermediateType>& out);
bool writeIntermediateResult(
bool nonNullGroup,
exec::out_type<IntermediateType>& out,
const FunctionState& state);

bool writeFinalResult(bool nonNullGroup, exec::out_type<OutputType>& out);
bool writeFinalResult(
bool nonNullGroup,
exec::out_type<OutputType>& out,
const FunctionState& state);

// Optional.
void destroy(HashStringAllocator* allocator);
@@ -384,7 +425,7 @@ addInput

This method receives a `HashStringAllocator*` followed by
`exec::optional_arg_type<T1>` values, one for each argument type `Ti` wrapped
in InputType.
in InputType. The trailing `const FunctionState&` argument holds the
function-level variables.

This method is called on all raw-input rows even if some columns may be null.
It returns a boolean meaning whether *this* accumulator is non-null after the
@@ -397,26 +438,29 @@ combine
"""""""

This method receives a `HashStringAllocator*` and an
`exec::optional_arg_type<IntermediateType>` value. This method is called on
all intermediate states even if some are nulls. Same as `addInput`, this method
returns a boolean meaning whether *this* accumulator is non-null after the call.
`exec::optional_arg_type<IntermediateType>` value. The trailing
`const FunctionState&` argument holds the function-level variables. This
method is called on all intermediate states even if some are nulls. Same as
`addInput`, this method returns a boolean meaning whether *this* accumulator
is non-null after the call.

writeIntermediateResult
"""""""""""""""""""""""

This method has an out-parameter of the type `exec::out_type<IntermediateType>&`
and a boolean flag `nonNullGroup` indicating whether *this* accumulator is
non-null. This method returns true if it writes a non-null value to `out`, or
returns false meaning a null should be written to the intermediate state vector.
non-null. The trailing `const FunctionState&` argument holds the
function-level variables. This method returns true if it writes a non-null
value to `out`, or returns false meaning a null should be written to the
intermediate state vector.

writeFinalResult
""""""""""""""""

This method writes *this* accumulator out to a final result vector. It has an
out-parameter of the type `exec::out_type<OutputType>&` and a boolean flag
`nonNullGroup` indicating whether *this* accumulator is non-null. This method
returns true if it writes a non-null value to `out`, or returns false meaning a
null should be written to the final result vector.
`nonNullGroup` indicating whether *this* accumulator is non-null. The trailing
`const FunctionState&` argument holds the function-level variables. This
method returns true if it writes a non-null value to `out`, or returns false
meaning a null should be written to the final result vector.
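
The non-default-null variants described above can be sketched in a similar way. In this hedged example, `std::optional<int64_t>` stands in for `exec::optional_arg_type`, the allocator parameter is omitted, and the `nullAsZero` flag and the name `NullableSumAccumulator` are hypothetical. It is meant only to show how the boolean return values and `nonNullGroup` fit together.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical function-level flag: treat null inputs as zero.
struct FunctionState {
  bool nullAsZero{false};
};

// Sketch of a non-default-null accumulator. Assumption: std::optional
// replaces exec::optional_arg_type and the allocator parameter is dropped.
struct NullableSumAccumulator {
  int64_t sum{0};
  bool hasValue{false};

  // Called for every row, including rows where the input is null.
  // Returns whether this accumulator is non-null after the call.
  bool addInput(std::optional<int64_t> value, const FunctionState& state) {
    if (value.has_value()) {
      sum += *value;
      hasValue = true;
    } else if (state.nullAsZero) {
      hasValue = true;
    }
    return hasValue;
  }

  // Called for every intermediate state, including null ones.
  bool combine(std::optional<int64_t> other, const FunctionState& /*state*/) {
    if (other.has_value()) {
      sum += *other;
      hasValue = true;
    }
    return hasValue;
  }

  // Returns false to request a null in the output vector.
  bool writeFinalResult(
      bool nonNullGroup,
      int64_t& out,
      const FunctionState& /*state*/) const {
    if (!nonNullGroup) {
      return false;
    }
    out = sum;
    return true;
  }
};
```

Here the value returned by addInput() or combine() is what a caller would later pass back as `nonNullGroup`, which is why the write method does not need to track null-ness on its own.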

Limitations
^^^^^^^^^^^
10 changes: 10 additions & 0 deletions velox/exec/Aggregate.h
@@ -117,6 +117,16 @@ class Aggregate {
setOffsetsInternal(offset, nullByte, nullMask, rowSizeOffset);
}

// Initialize the function-level state of the simple function interface for
// UDAF.
// @param rawInputType The raw input type of the UDAF.
// @param resultType The result type of the UDAF.
// @param constantInputs Optional constant inputs.
virtual void initialize(
const std::vector<TypePtr>& rawInputType,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {}

// Initializes null flags and accumulators for newly encountered groups. This
// function should be called only once for each group.
//
16 changes: 16 additions & 0 deletions velox/exec/AggregateCompanionAdapter.cpp
@@ -60,6 +60,13 @@ void AggregateCompanionFunctionBase::clearInternal() {
fn_->clear();
}

void AggregateCompanionFunctionBase::initialize(
const std::vector<TypePtr>& rawInputType,
const facebook::velox::TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
fn_->initialize(rawInputType, resultType, constantInputs);
}

void AggregateCompanionFunctionBase::initializeNewGroups(
char** groups,
folly::Range<const vector_size_t*> indices) {
@@ -208,6 +215,15 @@ void AggregateCompanionAdapter::ExtractFunction::apply(
// Perform per-row aggregation.
std::vector<vector_size_t> allSelectedRange;
rows.applyToSelected([&](auto row) { allSelectedRange.push_back(row); });

// Get the raw input types.
std::vector<TypePtr> rawInputTypes;
rawInputTypes.reserve(args.size());
for (const auto& arg : args) {
rawInputTypes.emplace_back(arg->type());
}

fn_->initialize(rawInputTypes, outputType, {});
fn_->initializeNewGroups(groups, allSelectedRange);
fn_->enableValidateIntermediateInputs();
fn_->addIntermediateResults(groups, rows, args, false);
5 changes: 5 additions & 0 deletions velox/exec/AggregateCompanionAdapter.h
@@ -38,6 +38,11 @@ class AggregateCompanionFunctionBase : public Aggregate {

void destroy(folly::Range<char**> groups) override final;

void initialize(
const std::vector<TypePtr>& rawInputType,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) override;

void initializeNewGroups(
char** groups,
folly::Range<const vector_size_t*> indices) override final;
2 changes: 2 additions & 0 deletions velox/exec/AggregateInfo.cpp
@@ -103,6 +103,8 @@ std::vector<AggregateInfo> toAggregateInfo(
aggResultType,
operatorCtx.driverCtx()->queryConfig());

info.function->initialize(
aggregate.rawInputTypes, aggResultType, info.constantInputs);
auto lambdas = extractLambdaInputs(aggregate);
if (!lambdas.empty()) {
if (expressionEvaluator == nullptr) {
3 changes: 3 additions & 0 deletions velox/exec/AggregateWindow.cpp
@@ -149,6 +149,7 @@ class AggregateWindowFunction : public exec::WindowFunction {
// aggregate_ function object should be initialized.
auto singleGroup = std::vector<vector_size_t>{0};
aggregate_->clear();
aggregate_->initialize(argTypes_, resultType_, argVectors_);
aggregate_->initializeNewGroups(&rawSingleGroupRow_, singleGroup);
aggregateInitialized_ = true;
}
@@ -330,6 +331,7 @@ class AggregateWindowFunction : public exec::WindowFunction {
// the aggregation based on the frame changes with each row. This would
// require adding new APIs to the Aggregate framework.
aggregate_->clear();
aggregate_->initialize(argTypes_, resultType_, argVectors_);
aggregate_->initializeNewGroups(&rawSingleGroupRow_, kSingleGroup);
aggregateInitialized_ = true;

@@ -347,6 +349,7 @@ class AggregateWindowFunction : public exec::WindowFunction {
// This value is returned for rows with empty frames.
void computeDefaultAggregateValue(const TypePtr& resultType) {
aggregate_->clear();
aggregate_->initialize(argTypes_, resultType, argVectors_);
aggregate_->initializeNewGroups(
&rawSingleGroupRow_, std::vector<vector_size_t>{0});
aggregateInitialized_ = true;