diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java index 9b24e6505b17f..e609b1c6e640f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java @@ -14,8 +14,6 @@ package io.trino.operator.aggregation.state; import io.airlift.slice.Slice; -import io.airlift.slice.SliceInput; -import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; @@ -28,6 +26,8 @@ public class LongDecimalWithOverflowAndLongStateSerializer implements AccumulatorStateSerializer { + private static final int SERIALIZED_SIZE = (Long.BYTES * 2) + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH; + @Override public Type getSerializedType() { @@ -37,18 +37,14 @@ public Type getSerializedType() @Override public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder out) { - if (state.getLongDecimal() == null) { + Slice decimal = state.getLongDecimal(); + if (decimal == null) { out.appendNull(); } else { - Slice slice = Slices.allocate(Long.BYTES + Long.BYTES + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH); - SliceOutput output = slice.getOutput(); - - output.writeLong(state.getLong()); - output.writeLong(state.getOverflow()); - output.writeBytes(state.getLongDecimal()); - - VARBINARY.writeSlice(out, slice); + long count = state.getLong(); + long overflow = state.getOverflow(); + VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal.getLong(0), decimal.getLong(Long.BYTES))); } } @@ -56,11 +52,18 @@ public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder ou public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongState state) { if (!block.isNull(index)) { - SliceInput slice = VARBINARY.getSlice(block, index).getInput(); + Slice slice = VARBINARY.getSlice(block, index); + if (slice.length() != SERIALIZED_SIZE) { + throw new IllegalStateException("Unexpected serialized state size: " + slice.length()); + } + + long count = slice.getLong(0); + long overflow = slice.getLong(Long.BYTES); + Slice decimal = Slices.wrappedLongArray(slice.getLong(Long.BYTES * 2), slice.getLong(Long.BYTES * 3)); - state.setLong(slice.readLong()); - state.setOverflow(slice.readLong()); - state.setLongDecimal(Slices.copyOf(slice.readSlice(slice.available()))); + state.setLong(count); + state.setOverflow(overflow); + state.setLongDecimal(decimal); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java index d3f7955903bc5..ef7c14a9d5e06 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java @@ -14,8 +14,6 @@ package io.trino.operator.aggregation.state; import io.airlift.slice.Slice; -import io.airlift.slice.SliceInput; -import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; @@ -28,6 +26,8 @@ public class LongDecimalWithOverflowStateSerializer implements AccumulatorStateSerializer { + private static final int SERIALIZED_SIZE = Long.BYTES + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH; + @Override public Type getSerializedType() { @@ -37,17 +37,13 @@ public Type getSerializedType() @Override public void serialize(LongDecimalWithOverflowState state, BlockBuilder out) { - if (state.getLongDecimal() == null) { + Slice decimal = state.getLongDecimal(); + if (decimal == null) { out.appendNull(); } else { - Slice slice = Slices.allocate(Long.BYTES + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH); - SliceOutput output = slice.getOutput(); - - output.writeLong(state.getOverflow()); - output.writeBytes(state.getLongDecimal()); - - VARBINARY.writeSlice(out, slice); + long overflow = state.getOverflow(); + VARBINARY.writeSlice(out, Slices.wrappedLongArray(overflow, decimal.getLong(0), decimal.getLong(Long.BYTES))); } } @@ -55,10 +51,16 @@ public void serialize(LongDecimalWithOverflowState state, BlockBuilder out) public void deserialize(Block block, int index, LongDecimalWithOverflowState state) { if (!block.isNull(index)) { - SliceInput slice = VARBINARY.getSlice(block, index).getInput(); + Slice slice = VARBINARY.getSlice(block, index); + if (slice.length() != SERIALIZED_SIZE) { + throw new IllegalStateException("Unexpected serialized state size: " + slice.length()); + } + + long overflow = slice.getLong(0); + Slice decimal = Slices.wrappedLongArray(slice.getLong(Long.BYTES), slice.getLong(Long.BYTES * 2)); - state.setOverflow(slice.readLong()); - state.setLongDecimal(Slices.copyOf(slice.readSlice(slice.available()))); + state.setOverflow(overflow); + state.setLongDecimal(decimal); } } }