Skip to content

Commit

Permalink
Use unbound method references for input functions
Browse files Browse the repository at this point in the history
Avoids using a method handle bound to the specific input type for
DecimalSumAggregation/DecimalAverageAggregation input functions,
using a static constant inside of the methods instead since the
input functions are scale-invariant (all values are unscaled) and
method handles carrying arguments bound to specific instances can't
be inlined nearly as well by the JIT.
  • Loading branch information
pettyjamesm authored and martint committed Aug 16, 2021
1 parent c10cd01 commit 5783c8c
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static io.trino.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.trino.spi.type.Decimals.MAX_PRECISION;
import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION;
import static io.trino.spi.type.Decimals.writeBigDecimal;
import static io.trino.spi.type.Decimals.writeShortDecimal;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
Expand All @@ -63,9 +65,13 @@ public class DecimalAverageAggregation
{
public static final DecimalAverageAggregation DECIMAL_AVERAGE_AGGREGATION = new DecimalAverageAggregation();

// Constant references for short/long decimal types for use in operations that only manipulate unscaled values
private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_PRECISION, 0);
private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION, 0);

private static final String NAME = "avg";
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputShortDecimal", Type.class, LongDecimalWithOverflowAndLongState.class, Block.class, int.class);
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputLongDecimal", Type.class, LongDecimalWithOverflowAndLongState.class, Block.class, int.class);
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputShortDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class);
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputLongDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class);

private static final MethodHandle SHORT_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "outputShortDecimal", DecimalType.class, LongDecimalWithOverflowAndLongState.class, BlockBuilder.class);
private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "outputLongDecimal", DecimalType.class, LongDecimalWithOverflowAndLongState.class, BlockBuilder.class);
Expand Down Expand Up @@ -124,7 +130,6 @@ private static InternalAggregationFunction generateAggregation(Type type)
inputFunction = LONG_DECIMAL_INPUT_FUNCTION;
outputFunction = LONG_DECIMAL_OUTPUT_FUNCTION;
}
inputFunction = inputFunction.bindTo(type);
outputFunction = outputFunction.bindTo(type);

AggregationMetadata metadata = new AggregationMetadata(
Expand All @@ -150,7 +155,7 @@ private static List<ParameterMetadata> createInputParameterMetadata(Type type)
return ImmutableList.of(new ParameterMetadata(STATE), new ParameterMetadata(BLOCK_INPUT_CHANNEL, type), new ParameterMetadata(BLOCK_INDEX));
}

public static void inputShortDecimal(Type type, LongDecimalWithOverflowAndLongState state, Block block, int position)
public static void inputShortDecimal(LongDecimalWithOverflowAndLongState state, Block block, int position)
{
state.addLong(1); // row counter

Expand All @@ -159,11 +164,11 @@ public static void inputShortDecimal(Type type, LongDecimalWithOverflowAndLongSt
sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
state.setLongDecimal(sum);
}
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, UnscaledDecimal128Arithmetic.unscaledDecimal(type.getLong(block, position)), sum);
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, UnscaledDecimal128Arithmetic.unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(block, position)), sum);
state.addOverflow(overflow);
}

public static void inputLongDecimal(Type type, LongDecimalWithOverflowAndLongState state, Block block, int position)
public static void inputLongDecimal(LongDecimalWithOverflowAndLongState state, Block block, int position)
{
state.addLong(1); // row counter

Expand All @@ -172,7 +177,7 @@ public static void inputLongDecimal(Type type, LongDecimalWithOverflowAndLongSta
sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
state.setLongDecimal(sum);
}
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, type.getSlice(block, position), sum);
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, LONG_DECIMAL_TYPE.getSlice(block, position), sum);
state.addOverflow(overflow);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static io.trino.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.trino.spi.type.Decimals.MAX_PRECISION;
import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION;
import static io.trino.spi.type.TypeSignatureParameter.numericParameter;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.throwIfOverflows;
Expand All @@ -58,12 +60,16 @@
public class DecimalSumAggregation
extends SqlAggregationFunction
{
// Constant references for short/long decimal types for use in operations that only manipulate unscaled values
private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_PRECISION, 0);
private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION, 0);

public static final DecimalSumAggregation DECIMAL_SUM_AGGREGATION = new DecimalSumAggregation();
private static final String NAME = "sum";
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputShortDecimal", Type.class, LongDecimalWithOverflowState.class, Block.class, int.class);
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputLongDecimal", Type.class, LongDecimalWithOverflowState.class, Block.class, int.class);
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputShortDecimal", LongDecimalWithOverflowState.class, Block.class, int.class);
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputLongDecimal", LongDecimalWithOverflowState.class, Block.class, int.class);

private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "outputLongDecimal", DecimalType.class, LongDecimalWithOverflowState.class, BlockBuilder.class);
private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "outputLongDecimal", LongDecimalWithOverflowState.class, BlockBuilder.class);

private static final MethodHandle COMBINE_FUNCTION = methodHandle(DecimalSumAggregation.class, "combine", LongDecimalWithOverflowState.class, LongDecimalWithOverflowState.class);

Expand Down Expand Up @@ -118,10 +124,10 @@ private static InternalAggregationFunction generateAggregation(Type inputType, T
AggregationMetadata metadata = new AggregationMetadata(
generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
createInputParameterMetadata(inputType),
inputFunction.bindTo(inputType),
inputFunction,
Optional.empty(),
COMBINE_FUNCTION,
LONG_DECIMAL_OUTPUT_FUNCTION.bindTo(outputType),
LONG_DECIMAL_OUTPUT_FUNCTION,
ImmutableList.of(new AccumulatorStateDescriptor(
stateInterface,
stateSerializer,
Expand All @@ -138,25 +144,25 @@ private static List<ParameterMetadata> createInputParameterMetadata(Type type)
return ImmutableList.of(new ParameterMetadata(STATE), new ParameterMetadata(BLOCK_INPUT_CHANNEL, type), new ParameterMetadata(BLOCK_INDEX));
}

public static void inputShortDecimal(Type type, LongDecimalWithOverflowState state, Block block, int position)
public static void inputShortDecimal(LongDecimalWithOverflowState state, Block block, int position)
{
Slice sum = state.getLongDecimal();
if (sum == null) {
sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
state.setLongDecimal(sum);
}
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, unscaledDecimal(type.getLong(block, position)), sum);
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(block, position)), sum);
state.addOverflow(overflow);
}

public static void inputLongDecimal(Type type, LongDecimalWithOverflowState state, Block block, int position)
public static void inputLongDecimal(LongDecimalWithOverflowState state, Block block, int position)
{
Slice sum = state.getLongDecimal();
if (sum == null) {
sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
state.setLongDecimal(sum);
}
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, type.getSlice(block, position), sum);
long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, LONG_DECIMAL_TYPE.getSlice(block, position), sum);
state.addOverflow(overflow);
}

Expand All @@ -175,17 +181,18 @@ public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOv
state.addOverflow(overflow);
}

public static void outputLongDecimal(DecimalType type, LongDecimalWithOverflowState state, BlockBuilder out)
public static void outputLongDecimal(LongDecimalWithOverflowState state, BlockBuilder out)
{
if (state.getLongDecimal() == null) {
Slice decimal = state.getLongDecimal();
if (decimal == null) {
out.appendNull();
}
else {
if (state.getOverflow() != 0) {
throwOverflowException();
}
throwIfOverflows(state.getLongDecimal());
type.writeSlice(out, state.getLongDecimal());
throwIfOverflows(decimal);
LONG_DECIMAL_TYPE.writeSlice(out, decimal);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,11 @@ private static void addToState(LongDecimalWithOverflowAndLongState state, BigInt
{
BlockBuilder blockBuilder = TYPE.createFixedSizeBlockBuilder(1);
TYPE.writeSlice(blockBuilder, unscaledDecimal(value));

DecimalAverageAggregation.inputLongDecimal(TYPE, state, blockBuilder.build(), 0);
if (TYPE.isShort()) {
DecimalAverageAggregation.inputShortDecimal(state, blockBuilder.build(), 0);
}
else {
DecimalAverageAggregation.inputLongDecimal(state, blockBuilder.build(), 0);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,18 @@ public void testOverflowOnOutput()
addToState(state, TWO.pow(126));

assertEquals(state.getOverflow(), 1);
DecimalSumAggregation.outputLongDecimal(TYPE, state, new VariableWidthBlockBuilder(null, 10, 100));
DecimalSumAggregation.outputLongDecimal(state, new VariableWidthBlockBuilder(null, 10, 100));
}

private static void addToState(LongDecimalWithOverflowState state, BigInteger value)
{
BlockBuilder blockBuilder = TYPE.createFixedSizeBlockBuilder(1);
TYPE.writeSlice(blockBuilder, unscaledDecimal(value));

DecimalSumAggregation.inputLongDecimal(TYPE, state, blockBuilder.build(), 0);
if (TYPE.isShort()) {
DecimalSumAggregation.inputShortDecimal(state, blockBuilder.build(), 0);
}
else {
DecimalSumAggregation.inputLongDecimal(state, blockBuilder.build(), 0);
}
}
}

0 comments on commit 5783c8c

Please sign in to comment.