From 9c84a9a9b342dd59baa5e94df80f7321a118d004 Mon Sep 17 00:00:00 2001 From: Piotr Rzysko Date: Sun, 3 Nov 2024 20:33:34 +0100 Subject: [PATCH] Add metrics for Accumulator and GroupByHash --- .../io/trino/operator/AggregationMetrics.java | 28 ++++++++++++ .../trino/operator/AggregationOperator.java | 11 ++++- .../operator/HashAggregationOperator.java | 10 +++-- .../operator/MeasuredGroupByHashWork.java | 45 +++++++++++++++++++ .../StreamingAggregationOperator.java | 20 +++++++-- .../operator/aggregation/Aggregator.java | 10 ++++- .../aggregation/AggregatorFactory.java | 13 +++--- .../aggregation/GroupedAggregator.java | 10 ++++- .../InMemoryHashAggregationBuilder.java | 23 ++++++---- .../MergingHashAggregationBuilder.java | 10 ++++- .../SpillableHashAggregationBuilder.java | 13 ++++-- .../partial/SkipAggregationBuilder.java | 8 +++- .../aggregation/AggregationTestUtils.java | 17 +++---- .../BenchmarkArrayAggregation.java | 3 +- .../BenchmarkDecimalAggregation.java | 9 ++-- .../BenchmarkGroupedTypedHistogram.java | 3 +- .../aggregation/TestArrayAggregation.java | 5 ++- .../TestDoubleHistogramAggregation.java | 4 +- .../operator/aggregation/TestHistogram.java | 5 ++- .../TestMultimapAggAggregation.java | 13 ++++-- .../TestRealHistogramAggregation.java | 4 +- .../groupby/AggregationTestInput.java | 3 +- ...patialPartitioningInternalAggregation.java | 5 ++- .../plugin/hive/BaseHiveConnectorTest.java | 23 ++++++++++ .../ml/TestEvaluateClassifierPredictions.java | 3 +- .../plugin/ml/TestLearnAggregations.java | 5 ++- 26 files changed, 241 insertions(+), 62 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/operator/MeasuredGroupByHashWork.java diff --git a/core/trino-main/src/main/java/io/trino/operator/AggregationMetrics.java b/core/trino-main/src/main/java/io/trino/operator/AggregationMetrics.java index e8b4cca804c73..dfffc982ca663 100644 --- a/core/trino-main/src/main/java/io/trino/operator/AggregationMetrics.java +++ b/core/trino-main/src/main/java/io/trino/operator/AggregationMetrics.java @@ -16,18 +16,40 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; +import io.trino.plugin.base.metrics.DurationTiming; import io.trino.plugin.base.metrics.LongCount; import io.trino.spi.metrics.Metric; import io.trino.spi.metrics.Metrics; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + public class AggregationMetrics { @VisibleForTesting static final String INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME = "Input rows processed without partial aggregation enabled"; + private static final String ACCUMULATOR_TIME_METRIC_NAME = "Accumulator update wall time"; + private static final String GROUP_BY_HASH_TIME_METRIC_NAME = "Group by hash update wall time"; + private long accumulatorTimeNanos; + private boolean hasAccumulator; + private long groupByHashTimeNanos; + private boolean hasGroupByHash; private long inputRowsProcessedWithPartialAggregationDisabled; private boolean hasInputRowsProcessedWithPartialAggregationDisabled; + public void recordAccumulatorUpdateTimeSince(long startNanos) + { + accumulatorTimeNanos += System.nanoTime() - startNanos; + hasAccumulator = true; + } + + public void recordGroupByHashUpdateTimeSince(long startNanos) + { + groupByHashTimeNanos += System.nanoTime() - startNanos; + hasGroupByHash = true; + } + public void recordInputRowsProcessedWithPartialAggregationDisabled(long rows) { inputRowsProcessedWithPartialAggregationDisabled += rows; @@ -37,6 +59,12 @@ public void recordInputRowsProcessedWithPartialAggregationDisabled(long rows) public Metrics getMetrics() { ImmutableMap.Builder> builder = ImmutableMap.builder(); + if (hasAccumulator) { + builder.put(ACCUMULATOR_TIME_METRIC_NAME, new DurationTiming(new Duration(accumulatorTimeNanos, NANOSECONDS))); + } + if (hasGroupByHash) { + builder.put(GROUP_BY_HASH_TIME_METRIC_NAME, new DurationTiming(new Duration(groupByHashTimeNanos, NANOSECONDS))); + } if (hasInputRowsProcessedWithPartialAggregationDisabled) { builder.put(INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME, new LongCount(inputRowsProcessedWithPartialAggregationDisabled)); } diff --git a/core/trino-main/src/main/java/io/trino/operator/AggregationOperator.java b/core/trino-main/src/main/java/io/trino/operator/AggregationOperator.java index 042b638fb981f..b35a56a248a05 100644 --- a/core/trino-main/src/main/java/io/trino/operator/AggregationOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/AggregationOperator.java @@ -81,6 +81,7 @@ private enum State private final OperatorContext operatorContext; private final LocalMemoryContext userMemoryContext; private final List aggregates; + private final AggregationMetrics aggregationMetrics = new AggregationMetrics(); private State state = State.NEEDS_INPUT; @@ -90,7 +91,7 @@ public AggregationOperator(OperatorContext operatorContext, List factory.createAggregator(aggregationMetrics)) .collect(toImmutableList()); } @@ -111,6 +112,7 @@ public void finish() @Override public void close() { + updateOperatorMetrics(); userMemoryContext.setBytes(0); } @@ -144,6 +146,7 @@ public void addInput(Page page) public Page getOutput() { if (state != State.HAS_OUTPUT) { + updateOperatorMetrics(); return null; } @@ -162,6 +165,12 @@ public Page getOutput() } state = State.FINISHED; + updateOperatorMetrics(); return pageBuilder.build(); } + + private void updateOperatorMetrics() + { + operatorContext.setLatestMetrics(aggregationMetrics.getMetrics()); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java b/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java index c4451f449431c..964fa2f6d7647 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java @@ -388,7 +388,7 @@ public void addInput(Page page) .map(PartialAggregationController::isPartialAggregationDisabled) .orElse(false); if (step.isOutputPartial() && partialAggregationDisabled) { - aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext); + aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext, aggregationMetrics); } else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) { // TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling. @@ -409,7 +409,8 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) { return true; } return operatorContext.isWaitingForMemory().isDone(); - }); + }, + aggregationMetrics); } else { aggregationBuilder = new SpillableHashAggregationBuilder( @@ -424,7 +425,8 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) { memoryLimitForMergeWithMemory, spillerFactory, flatHashStrategyCompiler, - typeOperators); + typeOperators, + aggregationMetrics); } // assume initial aggregationBuilder is not full @@ -582,7 +584,7 @@ private Page getGlobalAggregationOutput() } for (AggregatorFactory aggregatorFactory : aggregatorFactories) { - aggregatorFactory.createAggregator().evaluate(output.getBlockBuilder(channel)); + aggregatorFactory.createAggregator(aggregationMetrics).evaluate(output.getBlockBuilder(channel)); channel++; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/MeasuredGroupByHashWork.java b/core/trino-main/src/main/java/io/trino/operator/MeasuredGroupByHashWork.java new file mode 100644 index 0000000000000..32383c5c998d0 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/MeasuredGroupByHashWork.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.operator; + +import static java.util.Objects.requireNonNull; + +public class MeasuredGroupByHashWork + implements Work +{ + private final Work delegate; + private final AggregationMetrics metrics; + + public MeasuredGroupByHashWork(Work delegate, AggregationMetrics metrics) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.metrics = requireNonNull(metrics, "metrics is null"); + } + + @Override + public boolean process() + { + long start = System.nanoTime(); + boolean result = delegate.process(); + metrics.recordGroupByHashUpdateTimeSince(start); + return result; + } + + @Override + public T getResult() + { + return delegate.getResult(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/StreamingAggregationOperator.java b/core/trino-main/src/main/java/io/trino/operator/StreamingAggregationOperator.java index e0bf64eb8d302..a86ca48b97d7d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/StreamingAggregationOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/StreamingAggregationOperator.java @@ -23,6 +23,7 @@ import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; +import io.trino.spi.metrics.Metrics; import io.trino.spi.type.Type; import io.trino.sql.gen.JoinCompiler; import io.trino.sql.planner.plan.PlanNodeId; @@ -134,6 +135,7 @@ public Factory duplicate() } private final WorkProcessor pages; + private final AggregationMetrics aggregationMetrics = new AggregationMetrics(); private StreamingAggregationOperator( ProcessorContext processorContext, @@ -151,7 +153,8 @@ private StreamingAggregationOperator( groupByTypes, groupByChannels, aggregatorFactories, - joinCompiler)); + joinCompiler, + aggregationMetrics)); } @Override @@ -160,6 +163,12 @@ public WorkProcessor getOutputPages() return pages; } + @Override + public Metrics getMetrics() + { + return aggregationMetrics.getMetrics(); + } + private static class StreamingAggregation implements Transformation { @@ -168,6 +177,7 @@ private static class StreamingAggregation private final int[] groupByChannels; private final List aggregatorFactories; private final PagesHashStrategy pagesHashStrategy; + private final AggregationMetrics aggregationMetrics; private List aggregates; private final PageBuilder pageBuilder; @@ -180,7 +190,8 @@ private StreamingAggregation( List groupByTypes, List groupByChannels, List aggregatorFactories, - JoinCompiler joinCompiler) + JoinCompiler joinCompiler, + AggregationMetrics aggregationMetrics) { requireNonNull(processorContext, "processorContext is null"); this.userMemoryContext = processorContext.getMemoryTrackingContext().localUserMemoryContext(); @@ -189,7 +200,7 @@ private StreamingAggregation( this.aggregatorFactories = requireNonNull(aggregatorFactories, "aggregatorFactories is null"); this.aggregates = aggregatorFactories.stream() - .map(AggregatorFactory::createAggregator) + .map(factory -> factory.createAggregator(aggregationMetrics)) .collect(toImmutableList()); this.pageBuilder = new PageBuilder(toTypes(groupByTypes, aggregates)); requireNonNull(joinCompiler, "joinCompiler is null"); @@ -200,6 +211,7 @@ private StreamingAggregation( sourceTypes.stream() .map(type -> new ObjectArrayList()) .collect(toImmutableList()), OptionalInt.empty()); + this.aggregationMetrics = requireNonNull(aggregationMetrics, "aggregationMetrics is null"); } @Override @@ -317,7 +329,7 @@ private void evaluateAndFlushGroup(Page page, int position) } aggregates = aggregatorFactories.stream() - .map(AggregatorFactory::createAggregator) + .map(factory -> factory.createAggregator(aggregationMetrics)) .collect(toImmutableList()); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java index b6b49fb294b2b..b937c677616df 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/Aggregator.java @@ -14,6 +14,7 @@ package io.trino.operator.aggregation; import com.google.common.primitives.Ints; +import io.trino.operator.AggregationMetrics; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; @@ -36,6 +37,7 @@ public class Aggregator private final int[] inputChannels; private final OptionalInt maskChannel; private final AggregationMaskBuilder maskBuilder; + private final AggregationMetrics metrics; public Aggregator( Accumulator accumulator, @@ -44,7 +46,8 @@ public Aggregator( Type finalType, List inputChannels, OptionalInt maskChannel, - AggregationMaskBuilder maskBuilder) + AggregationMaskBuilder maskBuilder, + AggregationMetrics metrics) { this.accumulator = requireNonNull(accumulator, "accumulator is null"); this.step = requireNonNull(step, "step is null"); @@ -53,6 +56,7 @@ public Aggregator( this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null")); this.maskChannel = requireNonNull(maskChannel, "maskChannel is null"); this.maskBuilder = requireNonNull(maskBuilder, "maskBuilder is null"); + this.metrics = requireNonNull(metrics, "metrics is null"); checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation"); } @@ -77,10 +81,14 @@ public void processPage(Page page) if (mask.isSelectNone()) { return; } + long start = System.nanoTime(); accumulator.addInput(arguments, mask); + metrics.recordAccumulatorUpdateTimeSince(start); } else { + long start = System.nanoTime(); accumulator.addIntermediate(page.getBlock(inputChannels[0])); + metrics.recordAccumulatorUpdateTimeSince(start); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java index 057faab35c05d..34caf7e63c33a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java @@ -14,6 +14,7 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; +import io.trino.operator.AggregationMetrics; import io.trino.spi.type.Type; import io.trino.sql.planner.plan.AggregationNode.Step; @@ -57,7 +58,7 @@ public AggregatorFactory( checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation"); } - public Aggregator createAggregator() + public Aggregator createAggregator(AggregationMetrics metrics) { Accumulator accumulator; if (step.isInputRaw()) { @@ -66,10 +67,10 @@ public Aggregator createAggregator() else { accumulator = accumulatorFactory.createIntermediateAccumulator(lambdaProviders); } - return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder()); + return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder(), metrics); } - public GroupedAggregator createGroupedAggregator() + public GroupedAggregator createGroupedAggregator(AggregationMetrics metrics) { GroupedAccumulator accumulator; if (step.isInputRaw()) { @@ -78,10 +79,10 @@ public GroupedAggregator createGroupedAggregator() else { accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders); } - return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder()); + return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder(), metrics); } - public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChannel) + public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChannel, AggregationMetrics metrics) { GroupedAccumulator accumulator; if (step.isInputRaw()) { @@ -90,7 +91,7 @@ public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChan else { accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders); } - return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel, accumulatorFactory.createAggregationMaskBuilder()); + return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel, accumulatorFactory.createAggregationMaskBuilder(), metrics); } public boolean isSpillable() diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java index 998a914830d62..18434899ab132 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java @@ -14,6 +14,7 @@ package io.trino.operator.aggregation; import com.google.common.primitives.Ints; +import io.trino.operator.AggregationMetrics; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; @@ -37,6 +38,7 @@ public class GroupedAggregator private final int[] inputChannels; private final OptionalInt maskChannel; private final AggregationMaskBuilder maskBuilder; + private final AggregationMetrics metrics; public GroupedAggregator( GroupedAccumulator accumulator, @@ -45,7 +47,8 @@ public GroupedAggregator( Type finalType, List inputChannels, OptionalInt maskChannel, - AggregationMaskBuilder maskBuilder) + AggregationMaskBuilder maskBuilder, + AggregationMetrics metrics) { this.accumulator = requireNonNull(accumulator, "accumulator is null"); this.step = requireNonNull(step, "step is null"); @@ -54,6 +57,7 @@ public GroupedAggregator( this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null")); this.maskChannel = requireNonNull(maskChannel, "maskChannel is null"); this.maskBuilder = requireNonNull(maskBuilder, "maskBuilder is null"); + this.metrics = requireNonNull(metrics, "metrics is null"); checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation"); } @@ -87,10 +91,14 @@ public void processPage(int groupCount, int[] groupIds, Page page) } // Unwrap any LazyBlock values before evaluating the accumulator arguments = arguments.getLoadedPage(); + long start = System.nanoTime(); accumulator.addInput(groupIds, arguments, mask); + metrics.recordAccumulatorUpdateTimeSince(start); } else { + long start = System.nanoTime(); accumulator.addIntermediate(groupIds, page.getBlock(inputChannels[0])); + metrics.recordAccumulatorUpdateTimeSince(start); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index 8ea87ae6025d6..fbe557beb0a39 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -19,8 +19,10 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.trino.array.IntBigArray; +import io.trino.operator.AggregationMetrics; import io.trino.operator.FlatHashStrategyCompiler; import io.trino.operator.GroupByHash; +import io.trino.operator.MeasuredGroupByHashWork; import io.trino.operator.OperatorContext; import io.trino.operator.TransformWork; import io.trino.operator.UpdateMemory; @@ -57,6 +59,7 @@ public class InMemoryHashAggregationBuilder private final boolean partial; private final OptionalLong maxPartialMemory; private final UpdateMemory updateMemory; + private final AggregationMetrics aggregationMetrics; private boolean full; @@ -70,7 +73,8 @@ public InMemoryHashAggregationBuilder( OperatorContext operatorContext, Optional maxPartialMemory, FlatHashStrategyCompiler hashStrategyCompiler, - UpdateMemory updateMemory) + UpdateMemory updateMemory, + AggregationMetrics aggregationMetrics) { this(aggregatorFactories, step, @@ -82,7 +86,8 @@ public InMemoryHashAggregationBuilder( maxPartialMemory, Optional.empty(), hashStrategyCompiler, - updateMemory); + updateMemory, + aggregationMetrics); } public InMemoryHashAggregationBuilder( @@ -96,7 +101,8 @@ public InMemoryHashAggregationBuilder( Optional maxPartialMemory, Optional unspillIntermediateChannelOffset, FlatHashStrategyCompiler hashStrategyCompiler, - UpdateMemory updateMemory) + UpdateMemory updateMemory, + AggregationMetrics aggregationMetrics) { if (hashChannel.isPresent()) { this.groupByOutputTypes = ImmutableList.builderWithExpectedSize(groupByTypes.size() + 1) @@ -131,13 +137,14 @@ public InMemoryHashAggregationBuilder( for (int i = 0; i < aggregatorFactories.size(); i++) { AggregatorFactory accumulatorFactory = aggregatorFactories.get(i); if (unspillIntermediateChannelOffset.isPresent()) { - builder.add(accumulatorFactory.createUnspillGroupedAggregator(step, unspillIntermediateChannelOffset.get() + i)); + builder.add(accumulatorFactory.createUnspillGroupedAggregator(step, unspillIntermediateChannelOffset.get() + i, aggregationMetrics)); } else { - builder.add(accumulatorFactory.createGroupedAggregator()); + builder.add(accumulatorFactory.createGroupedAggregator(aggregationMetrics)); } } groupedAggregators = builder.build(); + this.aggregationMetrics = requireNonNull(aggregationMetrics, "aggregationMetrics is null"); } @Override @@ -147,10 +154,10 @@ public void close() {} public Work processPage(Page page) { if (groupedAggregators.isEmpty()) { - return groupByHash.addPage(page.getLoadedPage(groupByChannels)); + return new MeasuredGroupByHashWork<>(groupByHash.addPage(page.getLoadedPage(groupByChannels)), aggregationMetrics); } return new TransformWork<>( - groupByHash.getGroupIds(page.getLoadedPage(groupByChannels)), + new MeasuredGroupByHashWork<>(groupByHash.getGroupIds(page.getLoadedPage(groupByChannels)), aggregationMetrics), groupByIdBlock -> { int groupCount = groupByHash.getGroupCount(); for (GroupedAggregator groupedAggregator : groupedAggregators) { @@ -339,7 +346,7 @@ public static List toTypes(List groupByType, List aggregatorFactories, @@ -61,7 +64,8 @@ public MergingHashAggregationBuilder( AggregatedMemoryContext aggregatedMemoryContext, long memoryLimitForMerge, int overwriteIntermediateChannelOffset, - FlatHashStrategyCompiler hashStrategyCompiler) + FlatHashStrategyCompiler hashStrategyCompiler, + AggregationMetrics aggregationMetrics) { ImmutableList.Builder groupByPartialChannels = ImmutableList.builderWithExpectedSize(groupByTypes.size()); for (int i = 0; i < groupByTypes.size(); i++) { @@ -80,6 +84,7 @@ public MergingHashAggregationBuilder( this.memoryLimitForMerge = memoryLimitForMerge; this.overwriteIntermediateChannelOffset = overwriteIntermediateChannelOffset; this.hashStrategyCompiler = hashStrategyCompiler; + this.aggregationMetrics = requireNonNull(aggregationMetrics, "aggregationMetrics is null"); rebuildHashAggregationBuilder(); } @@ -151,6 +156,7 @@ private void rebuildHashAggregationBuilder() Optional.of(overwriteIntermediateChannelOffset), hashStrategyCompiler, // TODO: merging should also yield on memory reservations - () -> true); + () -> true, + aggregationMetrics); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java index 8edfc891cd5e6..8268ca975d2d7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -18,6 +18,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.trino.memory.context.LocalMemoryContext; +import io.trino.operator.AggregationMetrics; import io.trino.operator.FlatHashStrategyCompiler; import io.trino.operator.MergeHashSort; import io.trino.operator.OperatorContext; @@ -42,6 +43,7 @@ import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.trino.operator.Operator.NOT_BLOCKED; import static java.lang.Math.max; +import static java.util.Objects.requireNonNull; public class SpillableHashAggregationBuilder implements HashAggregationBuilder @@ -65,6 +67,7 @@ public class SpillableHashAggregationBuilder private ListenableFuture spillInProgress = immediateVoidFuture(); private final FlatHashStrategyCompiler hashStrategyCompiler; private final TypeOperators typeOperators; + private final AggregationMetrics aggregationMetrics; // todo get rid of that and only use revocable memory private long emptyHashAggregationBuilderSize; @@ -83,7 +86,8 @@ public SpillableHashAggregationBuilder( DataSize memoryLimitForMergeWithMemory, SpillerFactory spillerFactory, FlatHashStrategyCompiler hashStrategyCompiler, - TypeOperators typeOperators) + TypeOperators typeOperators, + AggregationMetrics aggregationMetrics) { this.aggregatorFactories = aggregatorFactories; this.step = step; @@ -99,6 +103,7 @@ public SpillableHashAggregationBuilder( this.spillerFactory = spillerFactory; this.hashStrategyCompiler = hashStrategyCompiler; this.typeOperators = typeOperators; + this.aggregationMetrics = requireNonNull(aggregationMetrics, "aggregationMetrics is null"); rebuildHashAggregationBuilder(); } @@ -301,7 +306,8 @@ private WorkProcessor mergeSortedPages(WorkProcessor sortedPages, lo operatorContext.aggregateUserMemoryContext(), memoryLimitForMerge, hashAggregationBuilder.getKeyChannels(), - hashStrategyCompiler)); + hashStrategyCompiler, + aggregationMetrics)); return merger.get().buildResult(); } @@ -326,7 +332,8 @@ private void rebuildHashAggregationBuilder() updateMemory(); // TODO: Support GroupByHash yielding in spillable hash aggregation (https://github.com/trinodb/trino/issues/460) return true; - }); + }, + aggregationMetrics); emptyHashAggregationBuilderSize = hashAggregationBuilder.getSizeInMemory(); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/SkipAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/SkipAggregationBuilder.java index 8c2891459b04a..8a75492f59b12 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/SkipAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/SkipAggregationBuilder.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.trino.memory.context.LocalMemoryContext; +import io.trino.operator.AggregationMetrics; import io.trino.operator.CompletedWork; import io.trino.operator.Work; import io.trino.operator.WorkProcessor; @@ -43,6 +44,7 @@ public class SkipAggregationBuilder { private final LocalMemoryContext memoryContext; private final List aggregatorFactories; + private final AggregationMetrics aggregationMetrics; @Nullable private Page currentPage; private final int[] hashChannels; @@ -51,7 +53,8 @@ public SkipAggregationBuilder( List groupByChannels, Optional inputHashChannel, List aggregatorFactories, - LocalMemoryContext memoryContext) + LocalMemoryContext memoryContext, + AggregationMetrics aggregationMetrics) { this.memoryContext = requireNonNull(memoryContext, "memoryContext is null"); this.aggregatorFactories = ImmutableList.copyOf(requireNonNull(aggregatorFactories, "aggregatorFactories is null")); @@ -60,6 +63,7 @@ public SkipAggregationBuilder( hashChannels[i] = groupByChannels.get(i); } inputHashChannel.ifPresent(channelIndex -> hashChannels[groupByChannels.size()] = channelIndex); + this.aggregationMetrics = requireNonNull(aggregationMetrics, "aggregationMetrics is null"); } @Override @@ -130,7 +134,7 @@ private Page buildOutputPage(Page page) // Evaluate each grouped aggregator into its own output block for (int i = 0; i < aggregatorFactories.size(); i++) { - GroupedAggregator groupedAggregator = aggregatorFactories.get(i).createGroupedAggregator(); + GroupedAggregator groupedAggregator = aggregatorFactories.get(i).createGroupedAggregator(aggregationMetrics); groupedAggregator.processPage(positionCount, groupIds, page); BlockBuilder outputBuilder = groupedAggregator.getType().createBlockBuilder(null, positionCount); for (int position = 0; position < positionCount; position++) { diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java index 4221a222fd7fa..35da2f68fec40 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/AggregationTestUtils.java @@ -17,6 +17,7 @@ import com.google.common.primitives.Ints; import io.trino.block.BlockAssertions; import io.trino.metadata.TestingFunctionResolution; +import io.trino.operator.AggregationMetrics; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; @@ -232,7 +233,7 @@ public static Object aggregation(TestingAggregationFunction function, Page... pa private static Object aggregation(TestingAggregationFunction function, int[] args, OptionalInt maskChannel, Page... pages) { - Aggregator aggregator = function.createAggregatorFactory(SINGLE, Ints.asList(args), maskChannel).createAggregator(); + Aggregator aggregator = function.createAggregatorFactory(SINGLE, Ints.asList(args), maskChannel).createAggregator(new AggregationMetrics()); for (Page page : pages) { if (page.getPositionCount() > 0) { aggregator.processPage(page); @@ -269,16 +270,16 @@ public static Object partialAggregation(TestingAggregationFunction function, Pag private static Object partialAggregation(TestingAggregationFunction function, int[] args, Page... pages) { AggregatorFactory finalAggregatorFactory = function.createAggregatorFactory(FINAL, Ints.asList(0), OptionalInt.empty()); - Aggregator finalAggregator = finalAggregatorFactory.createAggregator(); + Aggregator finalAggregator = finalAggregatorFactory.createAggregator(new AggregationMetrics()); // Test handling of empty intermediate blocks AggregatorFactory partialAggregatorFactory = function.createAggregatorFactory(PARTIAL, Ints.asList(args), OptionalInt.empty()); - Block emptyBlock = getIntermediateBlock(function.getIntermediateType(), partialAggregatorFactory.createAggregator()); + Block emptyBlock = getIntermediateBlock(function.getIntermediateType(), partialAggregatorFactory.createAggregator(new AggregationMetrics())); finalAggregator.processPage(new Page(emptyBlock)); for (Page page : pages) { - Aggregator partialAggregation = partialAggregatorFactory.createAggregator(); + Aggregator partialAggregation = partialAggregatorFactory.createAggregator(new AggregationMetrics()); if (page.getPositionCount() > 0) { partialAggregation.processPage(page); } @@ -318,7 +319,7 @@ private static Object groupedAggregation(BiFunction isE public static Object groupedAggregation(TestingAggregationFunction function, int[] args, Page... pages) { - GroupedAggregator groupedAggregator = function.createAggregatorFactory(SINGLE, Ints.asList(args), OptionalInt.empty()).createGroupedAggregator(); + GroupedAggregator groupedAggregator = function.createAggregatorFactory(SINGLE, Ints.asList(args), OptionalInt.empty()).createGroupedAggregator(new AggregationMetrics()); for (Page page : pages) { groupedAggregator.processPage(0, createGroupByIdBlock(0, page.getPositionCount()), page); } @@ -357,16 +358,16 @@ private static Object groupedPartialAggregation(BiFunction expectedValues = new ArrayList<>(); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleHistogramAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleHistogramAggregation.java index e47a98ba8cf52..326313fa1837d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleHistogramAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDoubleHistogramAggregation.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import io.trino.metadata.TestingFunctionResolution; +import io.trino.operator.AggregationMetrics; import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; @@ -117,7 +118,8 @@ public void testBadNumberOfBuckets() private Aggregator getAggregator(Step step) { - return function.createAggregatorFactory(step, step.isInputRaw() ? ImmutableList.of(0, 1, 2) : ImmutableList.of(0), OptionalInt.empty()).createAggregator(); + return function.createAggregatorFactory(step, step.isInputRaw() ? ImmutableList.of(0, 1, 2) : ImmutableList.of(0), OptionalInt.empty()) + .createAggregator(new AggregationMetrics()); } private static Map extractSingleValue(Block block) diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestHistogram.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestHistogram.java index 04989c974320e..d82c10b4ff844 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestHistogram.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestHistogram.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; import io.trino.metadata.TestingFunctionResolution; +import io.trino.operator.AggregationMetrics; import io.trino.operator.aggregation.groupby.AggregationTestInput; import io.trino.operator.aggregation.groupby.AggregationTestInputBuilder; import io.trino.operator.aggregation.groupby.AggregationTestOutput; @@ -236,7 +237,7 @@ public void testEmptyHistogramOutputsNull() { TestingAggregationFunction function = getInternalDefaultVarCharAggregation(); GroupedAggregator groupedAggregator = function.createAggregatorFactory(SINGLE, Ints.asList(new int[] {}), OptionalInt.empty()) - .createGroupedAggregator(); + .createGroupedAggregator(new AggregationMetrics()); BlockBuilder blockBuilder = function.getFinalType().createBlockBuilder(null, 1000); groupedAggregator.evaluate(0, blockBuilder); @@ -292,7 +293,7 @@ private static void testManyValuesInducingRehash(TestingAggregationFunction aggr int itemCount = 30; Random random = new Random(); GroupedAggregator groupedAggregator = aggregationFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0), OptionalInt.empty()) - .createGroupedAggregator(); + .createGroupedAggregator(new AggregationMetrics()); for (int j = 0; j < numGroups; j++) { Map expectedValues = new HashMap<>(); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java index 301f97d7a3c18..ef8a843a1bfc6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java @@ -19,6 +19,7 @@ import com.google.common.primitives.Ints; import io.trino.RowPageBuilder; import io.trino.metadata.TestingFunctionResolution; +import io.trino.operator.AggregationMetrics; import io.trino.operator.aggregation.groupby.AggregationTestInput; import io.trino.operator.aggregation.groupby.AggregationTestInputBuilder; import io.trino.operator.aggregation.groupby.AggregationTestOutput; @@ -126,7 +127,8 @@ public void testDoubleRowMap() public void testMultiplePages() { TestingAggregationFunction aggFunction = getAggregationFunction(BIGINT, BIGINT); - GroupedAggregator groupedAggregator = aggFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()).createGroupedAggregator(); + GroupedAggregator groupedAggregator = aggFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()) + .createGroupedAggregator(new AggregationMetrics()); testMultimapAggWithGroupBy(aggFunction, groupedAggregator, 0, BIGINT, ImmutableList.of(1L, 1L), BIGINT, ImmutableList.of(2L, 3L)); } @@ -135,7 +137,8 @@ public void testMultiplePages() public void testMultiplePagesAndGroups() { TestingAggregationFunction aggFunction = getAggregationFunction(BIGINT, BIGINT); - GroupedAggregator groupedAggregator = aggFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()).createGroupedAggregator(); + GroupedAggregator groupedAggregator = aggFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()) + .createGroupedAggregator(new AggregationMetrics()); testMultimapAggWithGroupBy(aggFunction, groupedAggregator, 0, BIGINT, ImmutableList.of(1L, 1L), BIGINT, ImmutableList.of(2L, 3L)); testMultimapAggWithGroupBy(aggFunction, groupedAggregator, 300, BIGINT, ImmutableList.of(7L, 7L), BIGINT, ImmutableList.of(8L, 9L)); @@ -145,7 +148,8 @@ public void testMultiplePagesAndGroups() public void testManyValues() { TestingAggregationFunction aggFunction = getAggregationFunction(BIGINT, BIGINT); - GroupedAggregator groupedAggregator = aggFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()).createGroupedAggregator(); + GroupedAggregator groupedAggregator = aggFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()) + .createGroupedAggregator(new AggregationMetrics()); int numGroups = 30000; int numKeys = 10; @@ -171,7 +175,8 @@ public void testManyValues() public void testEmptyStateOutputIsNull() { TestingAggregationFunction aggregationFunction = getAggregationFunction(BIGINT, BIGINT); - GroupedAggregator groupedAggregator = aggregationFunction.createAggregatorFactory(SINGLE, Ints.asList(), OptionalInt.empty()).createGroupedAggregator(); + GroupedAggregator groupedAggregator = aggregationFunction.createAggregatorFactory(SINGLE, Ints.asList(), OptionalInt.empty()) + .createGroupedAggregator(new AggregationMetrics()); BlockBuilder blockBuilder = aggregationFunction.getFinalType().createBlockBuilder(null, 1); groupedAggregator.evaluate(0, blockBuilder); assertThat(blockBuilder.build().isNull(0)).isTrue(); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestRealHistogramAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestRealHistogramAggregation.java index 0e890ebe60282..caa9176f8f055 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestRealHistogramAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestRealHistogramAggregation.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import io.trino.metadata.TestingFunctionResolution; +import io.trino.operator.AggregationMetrics; import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; @@ -94,7 +95,8 @@ public void testMerge() private Aggregator createAggregator(Step step) { - return function.createAggregatorFactory(step, step.isInputRaw() ? ImmutableList.of(0, 1, 2) : ImmutableList.of(0), OptionalInt.empty()).createAggregator(); + return function.createAggregatorFactory(step, step.isInputRaw() ? ImmutableList.of(0, 1, 2) : ImmutableList.of(0), OptionalInt.empty()) + .createAggregator(new AggregationMetrics()); } @Test diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/groupby/AggregationTestInput.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/groupby/AggregationTestInput.java index 592fa0b3ee619..e6309952c1d0d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/groupby/AggregationTestInput.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/groupby/AggregationTestInput.java @@ -15,6 +15,7 @@ package io.trino.operator.aggregation.groupby; import com.google.common.primitives.Ints; +import io.trino.operator.AggregationMetrics; import io.trino.operator.aggregation.AggregationTestUtils; import io.trino.operator.aggregation.GroupedAggregator; import io.trino.operator.aggregation.TestingAggregationFunction; @@ -66,6 +67,6 @@ private Page[] getPages() public GroupedAggregator createGroupedAggregator() { return function.createAggregatorFactory(SINGLE, Ints.asList(args), OptionalInt.empty()) - .createGroupedAggregator(); + .createGroupedAggregator(new AggregationMetrics()); } } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java index 7b96173c64dbe..f33d93858cb26 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java @@ -23,6 +23,7 @@ import io.trino.geospatial.KdbTreeUtils; import io.trino.geospatial.Rectangle; import io.trino.metadata.TestingFunctionResolution; +import io.trino.operator.AggregationMetrics; import io.trino.operator.aggregation.Aggregator; import io.trino.operator.aggregation.AggregatorFactory; import io.trino.operator.aggregation.GroupedAggregator; @@ -82,12 +83,12 @@ public void test(int partitionCount) AggregatorFactory aggregatorFactory = function.createAggregatorFactory(SINGLE, Ints.asList(0, 1), OptionalInt.empty()); Page page = new Page(geometryBlock, partitionCountBlock); - Aggregator aggregator = aggregatorFactory.createAggregator(); + Aggregator aggregator = aggregatorFactory.createAggregator(new AggregationMetrics()); aggregator.processPage(page); String aggregation = (String) BlockAssertions.getOnlyValue(function.getFinalType(), getFinalBlock(function.getFinalType(), aggregator)); assertThat(aggregation).isEqualTo(expectedValue); - GroupedAggregator groupedAggregator = aggregatorFactory.createGroupedAggregator(); + GroupedAggregator groupedAggregator = aggregatorFactory.createGroupedAggregator(new AggregationMetrics()); groupedAggregator.processPage(0, createGroupByIdBlock(0, page.getPositionCount()), page); String groupValue = (String) getGroupValue(function.getFinalType(), groupedAggregator, 0); assertThat(groupValue).isEqualTo(expectedValue); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index e0839a7b1b90e..2bf0ac31c5b95 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -9069,6 +9069,29 @@ public void testExplainAnalyzeFilterProjectWallTime() "'Projection CPU time' = \\{duration=.*}"); } + @Test + public void testExplainAnalyzeAccumulatorUpdateWallTime() + { + assertExplainAnalyze( + "EXPLAIN ANALYZE VERBOSE SELECT count(*) FROM nation", + "'Accumulator update wall time' = \\{duration=.*}"); + assertExplainAnalyze( + "EXPLAIN ANALYZE VERBOSE SELECT name, (SELECT max(name) FROM region WHERE regionkey > nation.regionkey) FROM nation", + "'Accumulator update wall time' = \\{duration=.*}"); + } + + @Test + public void testExplainAnalyzeGroupByHashUpdateWallTime() + { + assertExplainAnalyze( + "EXPLAIN ANALYZE VERBOSE SELECT nationkey FROM nation GROUP BY nationkey", + "'Group by hash update wall time' = \\{duration=.*}"); + assertExplainAnalyze( + "EXPLAIN ANALYZE VERBOSE SELECT count(*), nationkey FROM nation GROUP BY nationkey", + "'Accumulator update wall time' = \\{duration=.*}", + "'Group by hash update wall time' = \\{duration=.*}"); + } + @Test public void testCreateAcidTableUnsupported() { diff --git a/plugin/trino-ml/src/test/java/io/trino/plugin/ml/TestEvaluateClassifierPredictions.java b/plugin/trino-ml/src/test/java/io/trino/plugin/ml/TestEvaluateClassifierPredictions.java index bb58185011d3c..fbe1514cd1c5d 100644 --- a/plugin/trino-ml/src/test/java/io/trino/plugin/ml/TestEvaluateClassifierPredictions.java +++ b/plugin/trino-ml/src/test/java/io/trino/plugin/ml/TestEvaluateClassifierPredictions.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import io.trino.RowPageBuilder; import io.trino.metadata.TestingFunctionResolution; +import io.trino.operator.AggregationMetrics; import io.trino.operator.aggregation.Aggregator; import io.trino.operator.aggregation.TestingAggregationFunction; import io.trino.spi.Page; @@ -41,7 +42,7 @@ public void testEvaluateClassifierPredictions() { TestingFunctionResolution functionResolution = new TestingFunctionResolution(extractFunctions(new MLPlugin().getFunctions())); TestingAggregationFunction aggregation = functionResolution.getAggregateFunction("evaluate_classifier_predictions", fromTypes(BIGINT, BIGINT)); - Aggregator aggregator = aggregation.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()).createAggregator(); + Aggregator aggregator = aggregation.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()).createAggregator(new AggregationMetrics()); aggregator.processPage(getPage()); BlockBuilder finalOut = VARCHAR.createBlockBuilder(null, 1); aggregator.evaluate(finalOut); diff --git a/plugin/trino-ml/src/test/java/io/trino/plugin/ml/TestLearnAggregations.java b/plugin/trino-ml/src/test/java/io/trino/plugin/ml/TestLearnAggregations.java index 7707328c68f9d..24d70865c584b 100644 --- a/plugin/trino-ml/src/test/java/io/trino/plugin/ml/TestLearnAggregations.java +++ b/plugin/trino-ml/src/test/java/io/trino/plugin/ml/TestLearnAggregations.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.RowPageBuilder; import io.trino.metadata.TestingFunctionResolution; +import io.trino.operator.AggregationMetrics; import io.trino.operator.aggregation.Aggregator; import io.trino.operator.aggregation.TestingAggregationFunction; import io.trino.plugin.ml.type.ClassifierParametricType; @@ -69,7 +70,7 @@ public void testLearn() TestingAggregationFunction aggregationFunction = FUNCTION_RESOLUTION.getAggregateFunction( "learn_classifier", fromTypeSignatures(BIGINT.getTypeSignature(), mapType(BIGINT.getTypeSignature(), DOUBLE.getTypeSignature()))); - assertLearnClassifier(aggregationFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()).createAggregator()); + assertLearnClassifier(aggregationFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1), OptionalInt.empty()).createAggregator(new AggregationMetrics())); } @Test @@ -78,7 +79,7 @@ public void testLearnLibSvm() TestingAggregationFunction aggregationFunction = FUNCTION_RESOLUTION.getAggregateFunction( "learn_libsvm_classifier", fromTypeSignatures(BIGINT.getTypeSignature(), mapType(BIGINT.getTypeSignature(), DOUBLE.getTypeSignature()), VARCHAR.getTypeSignature())); - assertLearnClassifier(aggregationFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1, 2), OptionalInt.empty()).createAggregator()); + assertLearnClassifier(aggregationFunction.createAggregatorFactory(SINGLE, ImmutableList.of(0, 1, 2), OptionalInt.empty()).createAggregator(new AggregationMetrics())); } private static void assertLearnClassifier(Aggregator aggregator)