Skip to content

Commit

Permalink
Improve Decimal Aggregation State Serializer Performance
Browse files Browse the repository at this point in the history
Avoids creating heap allocated Slices and indirecting through
SliceInput/SliceOutput to serialize and deserialize
LongDecimalWithOverflowState and LongDecimalWithOverflowAndLongState
values since the Slice width is not dynamic but rather fixed to
some number of long fields followed by a 128 bit decimal.
  • Loading branch information
pettyjamesm authored and martint committed Aug 16, 2021
1 parent 5783c8c commit 0215166
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
package io.trino.operator.aggregation.state;

import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
Expand All @@ -28,6 +26,8 @@
public class LongDecimalWithOverflowAndLongStateSerializer
implements AccumulatorStateSerializer<LongDecimalWithOverflowAndLongState>
{
private static final int SERIALIZED_SIZE = (Long.BYTES * 2) + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;

@Override
public Type getSerializedType()
{
Expand All @@ -37,30 +37,33 @@ public Type getSerializedType()
@Override
public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder out)
{
if (state.getLongDecimal() == null) {
Slice decimal = state.getLongDecimal();
if (decimal == null) {
out.appendNull();
}
else {
Slice slice = Slices.allocate(Long.BYTES + Long.BYTES + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH);
SliceOutput output = slice.getOutput();

output.writeLong(state.getLong());
output.writeLong(state.getOverflow());
output.writeBytes(state.getLongDecimal());

VARBINARY.writeSlice(out, slice);
long count = state.getLong();
long overflow = state.getOverflow();
VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal.getLong(0), decimal.getLong(Long.BYTES)));
}
}

@Override
public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongState state)
{
if (!block.isNull(index)) {
SliceInput slice = VARBINARY.getSlice(block, index).getInput();
Slice slice = VARBINARY.getSlice(block, index);
if (slice.length() != SERIALIZED_SIZE) {
throw new IllegalStateException("Unexpected serialized state size: " + slice.length());
}

long count = slice.getLong(0);
long overflow = slice.getLong(Long.BYTES);
Slice decimal = Slices.wrappedLongArray(slice.getLong(Long.BYTES * 2), slice.getLong(Long.BYTES * 3));

state.setLong(slice.readLong());
state.setOverflow(slice.readLong());
state.setLongDecimal(Slices.copyOf(slice.readSlice(slice.available())));
state.setLong(count);
state.setOverflow(overflow);
state.setLongDecimal(decimal);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
package io.trino.operator.aggregation.state;

import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
Expand All @@ -28,6 +26,8 @@
public class LongDecimalWithOverflowStateSerializer
implements AccumulatorStateSerializer<LongDecimalWithOverflowState>
{
private static final int SERIALIZED_SIZE = Long.BYTES + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;

@Override
public Type getSerializedType()
{
Expand All @@ -37,28 +37,30 @@ public Type getSerializedType()
@Override
public void serialize(LongDecimalWithOverflowState state, BlockBuilder out)
{
if (state.getLongDecimal() == null) {
Slice decimal = state.getLongDecimal();
if (decimal == null) {
out.appendNull();
}
else {
Slice slice = Slices.allocate(Long.BYTES + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH);
SliceOutput output = slice.getOutput();

output.writeLong(state.getOverflow());
output.writeBytes(state.getLongDecimal());

VARBINARY.writeSlice(out, slice);
long overflow = state.getOverflow();
VARBINARY.writeSlice(out, Slices.wrappedLongArray(overflow, decimal.getLong(0), decimal.getLong(Long.BYTES)));
}
}

@Override
public void deserialize(Block block, int index, LongDecimalWithOverflowState state)
{
if (!block.isNull(index)) {
SliceInput slice = VARBINARY.getSlice(block, index).getInput();
Slice slice = VARBINARY.getSlice(block, index);
if (slice.length() != SERIALIZED_SIZE) {
throw new IllegalStateException("Unexpected serialized state size: " + slice.length());
}

long overflow = slice.getLong(0);
Slice decimal = Slices.wrappedLongArray(slice.getLong(Long.BYTES), slice.getLong(Long.BYTES * 2));

state.setOverflow(slice.readLong());
state.setLongDecimal(Slices.copyOf(slice.readSlice(slice.available())));
state.setOverflow(overflow);
state.setLongDecimal(decimal);
}
}
}

0 comments on commit 0215166

Please sign in to comment.