Skip to content

Commit

Permalink
Register merge extract companion agg functions without suffix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored and rui-mo committed Jan 18, 2024
1 parent c2d8371 commit 4601a75
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 45 deletions.
79 changes: 38 additions & 41 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
const core::QueryConfig& config)
-> std::unique_ptr<Aggregate> {
if (auto func = getAggregateFunctionEntry(name)) {
core::AggregationNode::Step usedStep{
core::AggregationNode::Step::kPartial};
if (!exec::isRawInput(step)) {
step = core::AggregationNode::Step::kIntermediate;
usedStep = core::AggregationNode::Step::kIntermediate;
}
auto fn = func->factory(step, argTypes, resultType, config);
auto fn =
func->factory(usedStep, argTypes, resultType, config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
AggregateCompanionAdapter::PartialFunction>(
Expand Down Expand Up @@ -366,56 +369,50 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
bool registered = false;
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
registered |=
registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
}

auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatures);
if (mergeExtractSignatures.empty()) {
return false;
return registered;
}

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return exec::registerAggregateFunction(
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
[name, mergeExtractFunctionName](
core::AggregationNode::Step /*step*/,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& config)
-> std::unique_ptr<Aggregate> {
const auto& [originalResultType, _] =
resolveAggregateFunction(mergeExtractFunctionName, argTypes);
if (!originalResultType) {
// TODO: limitation -- result type must be resolveable given
// intermediate type of the original UDAF.
VELOX_UNREACHABLE(
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
}

if (auto func = getAggregateFunctionEntry(name)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal,
argTypes,
originalResultType,
config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
AggregateCompanionAdapter::MergeExtractFunction>(
std::move(fn), resultType);
}
VELOX_FAIL(
"Original aggregation function {} not found: {}",
name,
mergeExtractFunctionName);
},
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
registered |=
exec::registerAggregateFunction(
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
[name, mergeExtractFunctionName](
core::AggregationNode::Step /*step*/,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
if (auto func = getAggregateFunctionEntry(name)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal,
argTypes,
resultType,
config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
AggregateCompanionAdapter::MergeExtractFunction>(
std::move(fn), resultType);
}
VELOX_FAIL(
"Original aggregation function {} not found: {}",
name,
mergeExtractFunctionName);
},
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
return registered;
}

bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ class BloomFilterAggAggregate : public exec::Aggregate {
} // namespace

exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
const std::string& name) {
const std::string& name,
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.argumentType("bigint")
Expand Down Expand Up @@ -318,6 +320,8 @@ exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
const TypePtr& resultType,
const core::QueryConfig& config) -> std::unique_ptr<exec::Aggregate> {
return std::make_unique<BloomFilterAggAggregate>(resultType, config);
});
},
withCompanionFunctions,
overwrite);
}
} // namespace facebook::velox::functions::aggregate::sparksql
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
namespace facebook::velox::functions::aggregate::sparksql {

exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
const std::string& name);
const std::string& name,
bool withCompanionFunctions,
bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
3 changes: 2 additions & 1 deletion velox/functions/sparksql/aggregates/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ void registerAggregateFunctions(
registerFirstLastAggregates(prefix, withCompanionFunctions, overwrite);
registerMinMaxByAggregates(prefix, withCompanionFunctions, overwrite);
registerBitwiseXorAggregate(prefix, withCompanionFunctions, overwrite);
registerBloomFilterAggAggregate(prefix + "bloom_filter_agg");
registerBloomFilterAggAggregate(
prefix + "bloom_filter_agg", withCompanionFunctions, overwrite);
registerAverage(prefix + "avg", withCompanionFunctions, overwrite);
registerSum(prefix + "sum", withCompanionFunctions, overwrite);
}
Expand Down

0 comments on commit 4601a75

Please sign in to comment.