Skip to content

Commit

Permalink
Add registerCompanionFunctions and overwrite as parameters in agg reg…
Browse files Browse the repository at this point in the history
…istration (7110)
  • Loading branch information
PHILO-HE authored and FelixYBW committed Nov 13, 2023
1 parent 4100874 commit 7e10929
Show file tree
Hide file tree
Showing 18 changed files with 193 additions and 72 deletions.
9 changes: 7 additions & 2 deletions velox/functions/lib/aggregates/BitwiseAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ class BitwiseAggregateBase : public SimpleNumericAggregate<T, T, T> {
};

template <template <typename U> class T>
exec::AggregateRegistrationResult registerBitwise(const std::string& name) {
exec::AggregateRegistrationResult registerBitwise(
const std::string& name,
bool registerCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) {
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
Expand Down Expand Up @@ -106,7 +109,9 @@ exec::AggregateRegistrationResult registerBitwise(const std::string& name) {
name,
inputType->kindName());
}
});
},
registerCompanionFunctions,
overwrite);
}

} // namespace facebook::velox::functions::aggregate
11 changes: 8 additions & 3 deletions velox/functions/prestosql/aggregates/BitwiseAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,14 @@ class BitwiseAndAggregate : public BitwiseAggregateBase<T> {

} // namespace

void registerBitwiseAggregates(const std::string& prefix) {
registerBitwise<BitwiseOrAggregate>(prefix + kBitwiseOr);
registerBitwise<BitwiseAndAggregate>(prefix + kBitwiseAnd);
void registerBitwiseAggregates(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite) {
registerBitwise<BitwiseOrAggregate>(
prefix + kBitwiseOr, registerCompanionFunctions, overwrite);
registerBitwise<BitwiseAndAggregate>(
prefix + kBitwiseAnd, registerCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
8 changes: 6 additions & 2 deletions velox/functions/prestosql/aggregates/CountAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ class CountAggregate : public SimpleNumericAggregate<bool, int64_t, int64_t> {
} // namespace

exec::AggregateRegistrationResult registerCountAggregate(
const std::string& prefix) {
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.returnType("bigint")
Expand All @@ -178,7 +180,9 @@ exec::AggregateRegistrationResult registerCountAggregate(
VELOX_CHECK_LE(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<CountAggregate>();
});
},
registerCompanionFunctions,
overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
29 changes: 21 additions & 8 deletions velox/functions/prestosql/aggregates/CovarianceAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,10 @@ template <
typename TIntermediateInput,
typename TIntermediateResult,
typename TResultAccessor>
exec::AggregateRegistrationResult registerCovariance(const std::string& name) {
exec::AggregateRegistrationResult registerCovariance(
const std::string& name,
bool registerCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures = {
// (double, double) -> double
exec::AggregateFunctionSignatureBuilder()
Expand Down Expand Up @@ -607,37 +610,47 @@ exec::AggregateRegistrationResult registerCovariance(const std::string& name) {
"Unsupported raw input type: {}. Expected DOUBLE or REAL.",
rawInputType->toString())
}
});
},
registerCompanionFunctions,
overwrite);
}

} // namespace

void registerCovarianceAggregates(const std::string& prefix) {
void registerCovarianceAggregates(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite) {
registerCovariance<
CovarAccumulator,
CovarIntermediateInput,
CovarIntermediateResult,
CovarPopResultAccessor>(prefix + kCovarPop);
CovarPopResultAccessor>(
prefix + kCovarPop, registerCompanionFunctions, overwrite);
registerCovariance<
CovarAccumulator,
CovarIntermediateInput,
CovarIntermediateResult,
CovarSampResultAccessor>(prefix + kCovarSamp);
CovarSampResultAccessor>(
prefix + kCovarSamp, registerCompanionFunctions, overwrite);
registerCovariance<
CorrAccumulator,
CorrIntermediateInput,
CorrIntermediateResult,
CorrResultAccessor>(prefix + kCorr);
CorrResultAccessor>(
prefix + kCorr, registerCompanionFunctions, overwrite);
registerCovariance<
RegrAccumulator,
RegrIntermediateInput,
RegrIntermediateResult,
RegrInterceptResultAccessor>(prefix + kRegrIntercept);
RegrInterceptResultAccessor>(
prefix + kRegrIntercept, registerCompanionFunctions, overwrite);
registerCovariance<
RegrAccumulator,
RegrIntermediateInput,
RegrIntermediateResult,
RegrSlopeResultAccessor>(prefix + kRegrSlop);
RegrSlopeResultAccessor>(
prefix + kRegrSlop, registerCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
18 changes: 13 additions & 5 deletions velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,10 @@ template <
typename TNonNumeric,
template <typename T>
class TNumericN>
exec::AggregateRegistrationResult registerMinMax(const std::string& name) {
exec::AggregateRegistrationResult registerMinMax(
const std::string& name,
bool registerCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
.orderableTypeVariable("T")
Expand Down Expand Up @@ -1008,16 +1011,21 @@ exec::AggregateRegistrationResult registerMinMax(const std::string& name) {
inputType->kindName());
}
}
});
},
registerCompanionFunctions,
overwrite);
}

} // namespace

void registerMinMaxAggregates(const std::string& prefix) {
void registerMinMaxAggregates(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite) {
registerMinMax<MinAggregate, NonNumericMinAggregate, MinNAggregate>(
prefix + kMin);
prefix + kMin, registerCompanionFunctions, overwrite);
registerMinMax<MaxAggregate, NonNumericMaxAggregate, MaxNAggregate>(
prefix + kMax);
prefix + kMax, registerCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
39 changes: 28 additions & 11 deletions velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ extern exec::AggregateRegistrationResult registerBitwiseXorAggregate(
extern exec::AggregateRegistrationResult registerChecksumAggregate(
const std::string& prefix);
extern exec::AggregateRegistrationResult registerCountAggregate(
const std::string& prefix);
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite);
extern exec::AggregateRegistrationResult registerCountIfAggregate(
const std::string& prefix);
extern exec::AggregateRegistrationResult registerEntropyAggregate(
Expand Down Expand Up @@ -62,30 +64,45 @@ extern exec::AggregateRegistrationResult registerSetUnionAggregate(
const std::string& prefix);

extern void registerApproxDistinctAggregates(const std::string& prefix);
extern void registerBitwiseAggregates(const std::string& prefix);
extern void registerBitwiseAggregates(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite);
extern void registerBoolAggregates(const std::string& prefix);
extern void registerCentralMomentsAggregates(const std::string& prefix);
extern void registerCovarianceAggregates(const std::string& prefix);
extern void registerMinMaxAggregates(const std::string& prefix);
extern void registerCovarianceAggregates(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite);
extern void registerMinMaxAggregates(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite);
extern void registerMinMaxByAggregates(const std::string& prefix);
extern void registerSumAggregate(const std::string& prefix);
extern void registerVarianceAggregates(const std::string& prefix);
extern void registerVarianceAggregates(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite);

void registerAllAggregateFunctions(const std::string& prefix) {
void registerAllAggregateFunctions(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite) {
registerApproxDistinctAggregates(prefix);
registerApproxMostFrequentAggregate(prefix);
registerApproxPercentileAggregate(prefix);
registerArbitraryAggregate(prefix);
registerArrayAggAggregate(prefix);
registerAverageAggregate(prefix);
registerBitwiseAggregates(prefix);
registerBitwiseAggregates(prefix, registerCompanionFunctions, overwrite);
registerBitwiseXorAggregate(prefix);
registerBoolAggregates(prefix);
registerCentralMomentsAggregates(prefix);
registerChecksumAggregate(prefix);
registerCountAggregate(prefix);
registerCountAggregate(prefix, registerCompanionFunctions, overwrite);
registerCountIfAggregate(prefix);
registerCovarianceAggregates(prefix);
registerCovarianceAggregates(prefix, registerCompanionFunctions, overwrite);
registerEntropyAggregate(prefix);
registerGeometricMeanAggregate(prefix);
registerHistogramAggregate(prefix);
Expand All @@ -95,13 +112,13 @@ void registerAllAggregateFunctions(const std::string& prefix) {
registerMaxDataSizeForStatsAggregate(prefix);
registerMultiMapAggAggregate(prefix);
registerSumDataSizeForStatsAggregate(prefix);
registerMinMaxAggregates(prefix);
registerMinMaxAggregates(prefix, registerCompanionFunctions, overwrite);
registerMinMaxByAggregates(prefix);
registerReduceAgg(prefix);
registerSetAggAggregate(prefix);
registerSetUnionAggregate(prefix);
registerSumAggregate(prefix);
registerVarianceAggregates(prefix);
registerVarianceAggregates(prefix, registerCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

namespace facebook::velox::aggregate::prestosql {

void registerAllAggregateFunctions(const std::string& prefix = "");
void registerAllAggregateFunctions(
const std::string& prefix = "",
bool registerCompanionFunctions = false,
bool overwrite = false);

} // namespace facebook::velox::aggregate::prestosql
32 changes: 23 additions & 9 deletions velox/functions/prestosql/aggregates/VarianceAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,10 @@ void checkSumCountRowType(
}

template <template <typename TInput> class TClass>
exec::AggregateRegistrationResult registerVariance(const std::string& name) {
exec::AggregateRegistrationResult registerVariance(
const std::string& name,
bool registerCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
std::vector<std::string> inputTypes = {
"smallint", "integer", "bigint", "real", "double"};
Expand Down Expand Up @@ -508,18 +511,29 @@ exec::AggregateRegistrationResult registerVariance(const std::string& name) {
"(count:bigint, mean:double, m2:double) struct");
return std::make_unique<TClass<int64_t>>(resultType);
}
});
},
registerCompanionFunctions,
overwrite);
}

} // namespace

void registerVarianceAggregates(const std::string& prefix) {
registerVariance<StdDevSampAggregate>(prefix + kStdDev);
registerVariance<StdDevPopAggregate>(prefix + kStdDevPop);
registerVariance<StdDevSampAggregate>(prefix + kStdDevSamp);
registerVariance<VarSampAggregate>(prefix + kVariance);
registerVariance<VarPopAggregate>(prefix + kVarPop);
registerVariance<VarSampAggregate>(prefix + kVarSamp);
void registerVarianceAggregates(
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite) {
registerVariance<StdDevSampAggregate>(
prefix + kStdDev, registerCompanionFunctions, overwrite);
registerVariance<StdDevPopAggregate>(
prefix + kStdDevPop, registerCompanionFunctions, overwrite);
registerVariance<StdDevSampAggregate>(
prefix + kStdDevSamp, registerCompanionFunctions, overwrite);
registerVariance<VarSampAggregate>(
prefix + kVariance, registerCompanionFunctions, overwrite);
registerVariance<VarPopAggregate>(
prefix + kVarPop, registerCompanionFunctions, overwrite);
registerVariance<VarSampAggregate>(
prefix + kVarSamp, registerCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
8 changes: 6 additions & 2 deletions velox/functions/sparksql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,10 @@ class DecimalAverageAggregate : public DecimalAggregate<TInputType> {
/// REAL | DOUBLE | DOUBLE
/// ALL INTs | DOUBLE | DOUBLE
/// DECIMAL | DECIMAL | DECIMAL
exec::AggregateRegistrationResult registerAverage(const std::string& name) {
exec::AggregateRegistrationResult registerAverage(
const std::string& name,
bool registerCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;

for (const auto& inputType :
Expand Down Expand Up @@ -495,7 +498,8 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) {
}
}
},
/*registerCompanionFunctions*/ true);
registerCompanionFunctions,
overwrite);
}

} // namespace facebook::velox::functions::aggregate::sparksql
5 changes: 4 additions & 1 deletion velox/functions/sparksql/aggregates/AverageAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

namespace facebook::velox::functions::aggregate::sparksql {

exec::AggregateRegistrationResult registerAverage(const std::string& name);
exec::AggregateRegistrationResult registerAverage(
const std::string& name,
bool registerCompanionFunctions,
bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
6 changes: 4 additions & 2 deletions velox/functions/sparksql/aggregates/BitwiseXorAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ class BitwiseXorAggregate : public BitwiseAggregateBase<T> {
} // namespace

exec::AggregateRegistrationResult registerBitwiseXorAggregate(
const std::string& prefix) {
const std::string& prefix,
bool registerCompanionFunctions,
bool overwrite) {
return functions::aggregate::registerBitwise<BitwiseXorAggregate>(
prefix + "bit_xor");
prefix + "bit_xor", registerCompanionFunctions, overwrite);
}

} // namespace facebook::velox::functions::aggregate::sparksql
4 changes: 3 additions & 1 deletion velox/functions/sparksql/aggregates/BitwiseXorAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
namespace facebook::velox::functions::aggregate::sparksql {

exec::AggregateRegistrationResult registerBitwiseXorAggregate(
const std::string& name);
const std::string& name,
bool registerCompanionFunctions,
bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
Loading

0 comments on commit 7e10929

Please sign in to comment.