Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Apr 24, 2024
1 parent 569ac66 commit 2988acd
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 15 deletions.
11 changes: 9 additions & 2 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,15 @@ class Aggregate {
// UDAF.
// @param step The aggregation step.
// @param rawInputType The raw input type of the UDAF.
// @param resultType The result type of the UDAF.
// @param constantInputs Optional constant inputs.
// @param resultType The result type of the current aggregation step.
// @param constantInputs Optional constant input values for aggregate
// function. constantInputs should be empty if there are no constant inputs,
// aligned with inputTypes if there is at least one constant input, with
// non-constant inputs represented as nullptr, and must be instances of
// ConstantVector.
// @param companionStep The step used to register aggregate companion
// functions. kPartial for partial companion function, kIntermediate for merge
// and merge extract companion function.
virtual void initialize(
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
Expand Down
16 changes: 15 additions & 1 deletion velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ void AggregateCompanionAdapter::MergeFunction::extractValues(
fn_->extractAccumulators(groups, numGroups, result);
}

void AggregateCompanionAdapter::MergeExtractFunction::initialize(
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const facebook::velox::TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> /*companionStep*/) {
fn_->initialize(
step,
rawInputType,
resultType,
constantInputs,
core::AggregationNode::Step::kFinal);
}

void AggregateCompanionAdapter::MergeExtractFunction::extractValues(
char** groups,
int32_t numGroups,
Expand Down Expand Up @@ -275,7 +289,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply(
rawInputTypes,
outputType,
constantInputs,
core::AggregationNode::Step::kIntermediate);
core::AggregationNode::Step::kFinal);
fn_->initializeNewGroups(groups, allSelectedRange);
fn_->enableValidateIntermediateInputs();
fn_->addIntermediateResults(groups, rows, args, false);
Expand Down
7 changes: 7 additions & 0 deletions velox/exec/AggregateCompanionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ struct AggregateCompanionAdapter {
const TypePtr& resultType)
: MergeFunction{std::move(fn), resultType} {}

void initialize(
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) override;

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override;
};
Expand Down
4 changes: 4 additions & 0 deletions velox/exec/AggregateWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,10 @@ class AggregateWindowFunction : public exec::WindowFunction {
std::vector<TypePtr> argTypes_;
std::vector<column_index_t> argIndices_;
std::vector<VectorPtr> argVectors_;
// Constant input values for aggregate function. it should be empty if there
// are no constant inputs, aligned with inputTypes if there is at least one
// constant input, with non-constant inputs represented as nullptr, and must
// be instances of ConstantVector.
std::vector<VectorPtr> constantInputs_;

// This is a single aggregate row needed by the aggregate function for its
Expand Down
17 changes: 5 additions & 12 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,17 +520,13 @@ class FunctionStateTestAggregate {
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) {
auto expectedRawInputTypes = {BIGINT(), BIGINT()};
std::vector<TypePtr> expectedRawInputTypes = {BIGINT(), BIGINT()};
auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()});

if constexpr (testCompanion) {
VELOX_CHECK(companionStep.has_value());
if (companionStep.value() == core::AggregationNode::Step::kPartial) {
VELOX_CHECK(std::equal(
rawInputTypes.begin(),
rawInputTypes.end(),
expectedRawInputTypes.begin(),
expectedRawInputTypes.end()));
VELOX_CHECK(rawInputTypes == expectedRawInputTypes);
if (step == core::AggregationNode::Step::kPartial ||
step == core::AggregationNode::Step::kSingle) {
// Only check constant inputs in partial and single step.
Expand All @@ -540,7 +536,8 @@ class FunctionStateTestAggregate {
VELOX_CHECK_NULL(constantInputs[0]);
}
} else if (
companionStep.value() == core::AggregationNode::Step::kIntermediate) {
companionStep.value() == core::AggregationNode::Step::kIntermediate ||
companionStep.value() == core::AggregationNode::Step::kFinal) {
VELOX_CHECK_EQ(rawInputTypes.size(), 1);
VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType));

Expand All @@ -550,11 +547,7 @@ class FunctionStateTestAggregate {
VELOX_FAIL("Unexpected aggregation step");
}
} else {
VELOX_CHECK(std::equal(
rawInputTypes.begin(),
rawInputTypes.end(),
expectedRawInputTypes.begin(),
expectedRawInputTypes.end()));
VELOX_CHECK(rawInputTypes == expectedRawInputTypes);
if (step == core::AggregationNode::Step::kPartial ||
step == core::AggregationNode::Step::kSingle) {
// Only check constant inputs in partial and single step.
Expand Down

0 comments on commit 2988acd

Please sign in to comment.