Skip to content

Commit

Permalink
Add metrics for Accumulator and GroupByHash
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrrzysko committed Nov 3, 2024
1 parent dcd9750 commit 9c84a9a
Show file tree
Hide file tree
Showing 26 changed files with 241 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,40 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import io.airlift.units.Duration;
import io.trino.plugin.base.metrics.DurationTiming;
import io.trino.plugin.base.metrics.LongCount;
import io.trino.spi.metrics.Metric;
import io.trino.spi.metrics.Metrics;

import static java.util.concurrent.TimeUnit.NANOSECONDS;

public class AggregationMetrics
{
@VisibleForTesting
static final String INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME = "Input rows processed without partial aggregation enabled";
private static final String ACCUMULATOR_TIME_METRIC_NAME = "Accumulator update wall time";
private static final String GROUP_BY_HASH_TIME_METRIC_NAME = "Group by hash update wall time";

private long accumulatorTimeNanos;
private boolean hasAccumulator;
private long groupByHashTimeNanos;
private boolean hasGroupByHash;
private long inputRowsProcessedWithPartialAggregationDisabled;
private boolean hasInputRowsProcessedWithPartialAggregationDisabled;

/**
 * Adds the wall time elapsed since {@code startNanos} to the accumulator update
 * timer and marks the accumulator timing metric as present, so it is reported
 * by {@code getMetrics()}.
 *
 * @param startNanos a {@link System#nanoTime()} value captured immediately before the accumulator update
 */
public void recordAccumulatorUpdateTimeSince(long startNanos)
{
    long elapsedNanos = System.nanoTime() - startNanos;
    accumulatorTimeNanos += elapsedNanos;
    hasAccumulator = true;
}

/**
 * Adds the wall time elapsed since {@code startNanos} to the group-by hash update
 * timer and marks the group-by hash timing metric as present, so it is reported
 * by {@code getMetrics()}.
 *
 * @param startNanos a {@link System#nanoTime()} value captured immediately before the group-by hash update
 */
public void recordGroupByHashUpdateTimeSince(long startNanos)
{
    long elapsedNanos = System.nanoTime() - startNanos;
    groupByHashTimeNanos += elapsedNanos;
    hasGroupByHash = true;
}

public void recordInputRowsProcessedWithPartialAggregationDisabled(long rows)
{
inputRowsProcessedWithPartialAggregationDisabled += rows;
Expand All @@ -37,6 +59,12 @@ public void recordInputRowsProcessedWithPartialAggregationDisabled(long rows)
public Metrics getMetrics()
{
ImmutableMap.Builder<String, Metric<?>> builder = ImmutableMap.builder();
if (hasAccumulator) {
builder.put(ACCUMULATOR_TIME_METRIC_NAME, new DurationTiming(new Duration(accumulatorTimeNanos, NANOSECONDS)));
}
if (hasGroupByHash) {
builder.put(GROUP_BY_HASH_TIME_METRIC_NAME, new DurationTiming(new Duration(groupByHashTimeNanos, NANOSECONDS)));
}
if (hasInputRowsProcessedWithPartialAggregationDisabled) {
builder.put(INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME, new LongCount(inputRowsProcessedWithPartialAggregationDisabled));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ private enum State
private final OperatorContext operatorContext;
private final LocalMemoryContext userMemoryContext;
private final List<Aggregator> aggregates;
private final AggregationMetrics aggregationMetrics = new AggregationMetrics();

private State state = State.NEEDS_INPUT;

Expand All @@ -90,7 +91,7 @@ public AggregationOperator(OperatorContext operatorContext, List<AggregatorFacto
this.userMemoryContext = operatorContext.localUserMemoryContext();

aggregates = aggregatorFactories.stream()
.map(AggregatorFactory::createAggregator)
.map(factory -> factory.createAggregator(aggregationMetrics))
.collect(toImmutableList());
}

Expand All @@ -111,6 +112,7 @@ public void finish()
@Override
public void close()
{
    // Publish the final metrics snapshot before tearing down, so timings recorded
    // up to close time are not lost.
    updateOperatorMetrics();
    // Release all user memory reserved by this operator.
    userMemoryContext.setBytes(0);
}

Expand Down Expand Up @@ -144,6 +146,7 @@ public void addInput(Page page)
public Page getOutput()
{
if (state != State.HAS_OUTPUT) {
updateOperatorMetrics();
return null;
}

Expand All @@ -162,6 +165,12 @@ public Page getOutput()
}

state = State.FINISHED;
updateOperatorMetrics();
return pageBuilder.build();
}

// Pushes a snapshot of the collected aggregation metrics to the operator context;
// invoked at reporting points (output production and close).
private void updateOperatorMetrics()
{
    operatorContext.setLatestMetrics(aggregationMetrics.getMetrics());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ public void addInput(Page page)
.map(PartialAggregationController::isPartialAggregationDisabled)
.orElse(false);
if (step.isOutputPartial() && partialAggregationDisabled) {
aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext);
aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext, aggregationMetrics);
}
else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
// TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling.
Expand All @@ -409,7 +409,8 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
return true;
}
return operatorContext.isWaitingForMemory().isDone();
});
},
aggregationMetrics);
}
else {
aggregationBuilder = new SpillableHashAggregationBuilder(
Expand All @@ -424,7 +425,8 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
memoryLimitForMergeWithMemory,
spillerFactory,
flatHashStrategyCompiler,
typeOperators);
typeOperators,
aggregationMetrics);
}

// assume initial aggregationBuilder is not full
Expand Down Expand Up @@ -582,7 +584,7 @@ private Page getGlobalAggregationOutput()
}

for (AggregatorFactory aggregatorFactory : aggregatorFactories) {
aggregatorFactory.createAggregator().evaluate(output.getBlockBuilder(channel));
aggregatorFactory.createAggregator(aggregationMetrics).evaluate(output.getBlockBuilder(channel));
channel++;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.trino.operator;

import static java.util.Objects.requireNonNull;

/**
 * A {@link Work} decorator that measures the wall time spent inside
 * {@link Work#process()} and records it as group-by hash update time
 * in the supplied {@link AggregationMetrics}.
 *
 * @param <T> result type produced by the wrapped work
 */
public class MeasuredGroupByHashWork<T>
        implements Work<T>
{
    private final Work<T> delegate;
    private final AggregationMetrics metrics;

    public MeasuredGroupByHashWork(Work<T> delegate, AggregationMetrics metrics)
    {
        this.delegate = requireNonNull(delegate, "delegate is null");
        this.metrics = requireNonNull(metrics, "metrics is null");
    }

    @Override
    public boolean process()
    {
        // Time only the delegate's processing; each call adds to the running total.
        long startNanos = System.nanoTime();
        boolean finished = delegate.process();
        metrics.recordGroupByHashUpdateTimeSince(startNanos);
        return finished;
    }

    @Override
    public T getResult()
    {
        // Result retrieval is not timed; only process() contributes to the metric.
        return delegate.getResult();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.metrics.Metrics;
import io.trino.spi.type.Type;
import io.trino.sql.gen.JoinCompiler;
import io.trino.sql.planner.plan.PlanNodeId;
Expand Down Expand Up @@ -134,6 +135,7 @@ public Factory duplicate()
}

private final WorkProcessor<Page> pages;
private final AggregationMetrics aggregationMetrics = new AggregationMetrics();

private StreamingAggregationOperator(
ProcessorContext processorContext,
Expand All @@ -151,7 +153,8 @@ private StreamingAggregationOperator(
groupByTypes,
groupByChannels,
aggregatorFactories,
joinCompiler));
joinCompiler,
aggregationMetrics));
}

@Override
Expand All @@ -160,6 +163,12 @@ public WorkProcessor<Page> getOutputPages()
return pages;
}

@Override
public Metrics getMetrics()
{
    // Expose a snapshot of the metrics gathered by this operator's aggregators.
    return aggregationMetrics.getMetrics();
}

private static class StreamingAggregation
implements Transformation<Page, Page>
{
Expand All @@ -168,6 +177,7 @@ private static class StreamingAggregation
private final int[] groupByChannels;
private final List<AggregatorFactory> aggregatorFactories;
private final PagesHashStrategy pagesHashStrategy;
private final AggregationMetrics aggregationMetrics;

private List<Aggregator> aggregates;
private final PageBuilder pageBuilder;
Expand All @@ -180,7 +190,8 @@ private StreamingAggregation(
List<Type> groupByTypes,
List<Integer> groupByChannels,
List<AggregatorFactory> aggregatorFactories,
JoinCompiler joinCompiler)
JoinCompiler joinCompiler,
AggregationMetrics aggregationMetrics)
{
requireNonNull(processorContext, "processorContext is null");
this.userMemoryContext = processorContext.getMemoryTrackingContext().localUserMemoryContext();
Expand All @@ -189,7 +200,7 @@ private StreamingAggregation(
this.aggregatorFactories = requireNonNull(aggregatorFactories, "aggregatorFactories is null");

this.aggregates = aggregatorFactories.stream()
.map(AggregatorFactory::createAggregator)
.map(factory -> factory.createAggregator(aggregationMetrics))
.collect(toImmutableList());
this.pageBuilder = new PageBuilder(toTypes(groupByTypes, aggregates));
requireNonNull(joinCompiler, "joinCompiler is null");
Expand All @@ -200,6 +211,7 @@ private StreamingAggregation(
sourceTypes.stream()
.map(type -> new ObjectArrayList<Block>())
.collect(toImmutableList()), OptionalInt.empty());
this.aggregationMetrics = requireNonNull(aggregationMetrics, "aggregationMetrics is null");
}

@Override
Expand Down Expand Up @@ -317,7 +329,7 @@ private void evaluateAndFlushGroup(Page page, int position)
}

aggregates = aggregatorFactories.stream()
.map(AggregatorFactory::createAggregator)
.map(factory -> factory.createAggregator(aggregationMetrics))
.collect(toImmutableList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.operator.aggregation;

import com.google.common.primitives.Ints;
import io.trino.operator.AggregationMetrics;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
Expand All @@ -36,6 +37,7 @@ public class Aggregator
private final int[] inputChannels;
private final OptionalInt maskChannel;
private final AggregationMaskBuilder maskBuilder;
private final AggregationMetrics metrics;

public Aggregator(
Accumulator accumulator,
Expand All @@ -44,7 +46,8 @@ public Aggregator(
Type finalType,
List<Integer> inputChannels,
OptionalInt maskChannel,
AggregationMaskBuilder maskBuilder)
AggregationMaskBuilder maskBuilder,
AggregationMetrics metrics)
{
this.accumulator = requireNonNull(accumulator, "accumulator is null");
this.step = requireNonNull(step, "step is null");
Expand All @@ -53,6 +56,7 @@ public Aggregator(
this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null"));
this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
this.maskBuilder = requireNonNull(maskBuilder, "maskBuilder is null");
this.metrics = requireNonNull(metrics, "metrics is null");
checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation");
}

Expand All @@ -77,10 +81,14 @@ public void processPage(Page page)
if (mask.isSelectNone()) {
return;
}
long start = System.nanoTime();
accumulator.addInput(arguments, mask);
metrics.recordAccumulatorUpdateTimeSince(start);
}
else {
long start = System.nanoTime();
accumulator.addIntermediate(page.getBlock(inputChannels[0]));
metrics.recordAccumulatorUpdateTimeSince(start);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.operator.AggregationMetrics;
import io.trino.spi.type.Type;
import io.trino.sql.planner.plan.AggregationNode.Step;

Expand Down Expand Up @@ -57,7 +58,7 @@ public AggregatorFactory(
checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation");
}

public Aggregator createAggregator()
public Aggregator createAggregator(AggregationMetrics metrics)
{
Accumulator accumulator;
if (step.isInputRaw()) {
Expand All @@ -66,10 +67,10 @@ public Aggregator createAggregator()
else {
accumulator = accumulatorFactory.createIntermediateAccumulator(lambdaProviders);
}
return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder());
return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder(), metrics);
}

public GroupedAggregator createGroupedAggregator()
public GroupedAggregator createGroupedAggregator(AggregationMetrics metrics)
{
GroupedAccumulator accumulator;
if (step.isInputRaw()) {
Expand All @@ -78,10 +79,10 @@ public GroupedAggregator createGroupedAggregator()
else {
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders);
}
return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder());
return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel, accumulatorFactory.createAggregationMaskBuilder(), metrics);
}

public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChannel)
public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChannel, AggregationMetrics metrics)
{
GroupedAccumulator accumulator;
if (step.isInputRaw()) {
Expand All @@ -90,7 +91,7 @@ public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChan
else {
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders);
}
return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel, accumulatorFactory.createAggregationMaskBuilder());
return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel, accumulatorFactory.createAggregationMaskBuilder(), metrics);
}

public boolean isSpillable()
Expand Down
Loading

0 comments on commit 9c84a9a

Please sign in to comment.