From 923fa7a6aaf44ce91e2430bf88fd489929c2aed4 Mon Sep 17 00:00:00 2001 From: Lukasz Stec Date: Tue, 9 Aug 2022 09:56:25 +0200 Subject: [PATCH] Optimize decimal state serializers for small value case Given that many decimal aggregations (sum, avg) stay in the long range, the aggregation state serializers can be optimized for this case, significantly reducing the number of bytes per position (3-4X) at the cost of a small CPU overhead during serialization and deserialization. --- ...malWithOverflowAndLongStateSerializer.java | 58 ++++++++--- ...ongDecimalWithOverflowStateSerializer.java | 40 +++++--- ...malWithOverflowAndLongStateSerializer.java | 96 +++++++++++++++++++ ...ongDecimalWithOverflowStateSerializer.java | 86 +++++++++++++++++ .../testing/AbstractTestAggregations.java | 12 +++ 5 files changed, 266 insertions(+), 26 deletions(-) create mode 100644 core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java create mode 100644 core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java index cced798a52591..7556b4a207a8c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java @@ -18,7 +18,6 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.Int128; import io.trino.spi.type.Type; import static io.trino.spi.type.VarbinaryType.VARBINARY; @@ -26,8 +25,6 @@ public class LongDecimalWithOverflowAndLongStateSerializer 
implements AccumulatorStateSerializer { - private static final int SERIALIZED_SIZE = (Long.BYTES * 2) + Int128.SIZE; - @Override public Type getSerializedType() { @@ -42,7 +39,27 @@ public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder ou long overflow = state.getOverflow(); long[] decimal = state.getDecimalArray(); int offset = state.getDecimalArrayOffset(); - VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal[offset], decimal[offset + 1])); + long[] buffer = new long[4]; + long high = decimal[offset]; + long low = decimal[offset + 1]; + + buffer[0] = low; + buffer[1] = high; + // if high = 0, the count will overwrite it + int countOffset = 1 + (high == 0 ? 0 : 1); + // append count, overflow + buffer[countOffset] = count; + buffer[countOffset + 1] = overflow; + + // cases + // high == 0 (countOffset = 1) + // overflow == 0 & count == 1 -> bufferLength = 1 + // overflow != 0 || count != 1 -> bufferLength = 3 + // high != 0 (countOffset = 2) + // overflow == 0 & count == 1 -> bufferLength = 2 + // overflow != 0 || count != 1 -> bufferLength = 4 + int bufferLength = countOffset + ((overflow == 0 & count == 1) ? 
0 : 2); + VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, bufferLength)); } else { out.appendNull(); @@ -54,20 +71,33 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt { if (!block.isNull(index)) { Slice slice = VARBINARY.getSlice(block, index); - if (slice.length() != SERIALIZED_SIZE) { - throw new IllegalStateException("Unexpected serialized state size: " + slice.length()); - } + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); - long count = slice.getLong(0); - long overflow = slice.getLong(Long.BYTES); + int sliceLength = slice.length(); + long low = slice.getLong(0); + long high = 0; + long overflow = 0; + long count = 1; - state.setLong(count); + switch (sliceLength) { + case 4 * Long.BYTES: + overflow = slice.getLong(Long.BYTES * 3); + count = slice.getLong(Long.BYTES * 2); + // fall through + case 2 * Long.BYTES: + high = slice.getLong(Long.BYTES); + break; + case 3 * Long.BYTES: + overflow = slice.getLong(Long.BYTES * 2); + count = slice.getLong(Long.BYTES); + } + + decimal[offset + 1] = low; + decimal[offset] = high; state.setOverflow(overflow); + state.setLong(count); state.setNotNull(); - long[] decimal = state.getDecimalArray(); - int offset = state.getDecimalArrayOffset(); - decimal[offset] = slice.getLong(Long.BYTES * 2); - decimal[offset + 1] = slice.getLong(Long.BYTES * 3); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java index c8c3498f75388..5ddfb51324ad0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java @@ -18,7 +18,6 @@ import io.trino.spi.block.Block; import 
io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.Int128; import io.trino.spi.type.Type; import static io.trino.spi.type.VarbinaryType.VARBINARY; @@ -26,8 +25,6 @@ public class LongDecimalWithOverflowStateSerializer implements AccumulatorStateSerializer { - private static final int SERIALIZED_SIZE = Long.BYTES + Int128.SIZE; - @Override public Type getSerializedType() { @@ -41,7 +38,18 @@ public void serialize(LongDecimalWithOverflowState state, BlockBuilder out) long overflow = state.getOverflow(); long[] decimal = state.getDecimalArray(); int offset = state.getDecimalArrayOffset(); - VARBINARY.writeSlice(out, Slices.wrappedLongArray(overflow, decimal[offset], decimal[offset + 1])); + long[] buffer = new long[3]; + long low = decimal[offset + 1]; + long high = decimal[offset]; + buffer[0] = low; + buffer[1] = high; + buffer[2] = overflow; + // if high == 0 and overflow == 0 we only write low (bufferLength = 1) + // if high != 0 and overflow == 0 we write both low and high (bufferLength = 2) + // if overflow != 0 we write all values (bufferLength = 3) + int decimalsCount = 1 + (high == 0 ? 0 : 1); + int bufferLength = overflow == 0 ? 
decimalsCount : 3; + VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, bufferLength)); } else { out.appendNull(); @@ -53,18 +61,26 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowState sta { if (!block.isNull(index)) { Slice slice = VARBINARY.getSlice(block, index); - if (slice.length() != SERIALIZED_SIZE) { - throw new IllegalStateException("Unexpected serialized state size: " + slice.length()); - } + long[] decimal = state.getDecimalArray(); + int offset = state.getDecimalArrayOffset(); - long overflow = slice.getLong(0); + long low = slice.getLong(0); + int sliceLength = slice.length(); + long high = 0; + long overflow = 0; + switch (sliceLength) { + case 3 * Long.BYTES: + overflow = slice.getLong(Long.BYTES * 2); + // fall through + case 2 * Long.BYTES: + high = slice.getLong(Long.BYTES); + } + + decimal[offset + 1] = low; + decimal[offset] = high; state.setOverflow(overflow); state.setNotNull(); - long[] decimal = state.getDecimalArray(); - int offset = state.getDecimalArrayOffset(); - decimal[offset] = slice.getLong(Long.BYTES); - decimal[offset + 1] = slice.getLong(Long.BYTES * 2); } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java new file mode 100644 index 0000000000000..31949aec15591 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.state; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlockBuilder; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestLongDecimalWithOverflowAndLongStateSerializer +{ + private static final LongDecimalWithOverflowAndLongStateFactory STATE_FACTORY = new LongDecimalWithOverflowAndLongStateFactory(); + + @Test(dataProvider = "input") + public void testSerde(long low, long high, long overflow, long count, int expectedLength) + { + LongDecimalWithOverflowAndLongState state = STATE_FACTORY.createSingleState(); + state.getDecimalArray()[0] = high; + state.getDecimalArray()[1] = low; + state.setOverflow(overflow); + state.setLong(count); + state.setNotNull(); + + LongDecimalWithOverflowAndLongState outState = roundTrip(state, expectedLength); + + assertTrue(outState.isNotNull()); + assertEquals(outState.getDecimalArray()[0], high); + assertEquals(outState.getDecimalArray()[1], low); + assertEquals(outState.getOverflow(), overflow); + assertEquals(outState.getLong(), count); + } + + @Test + public void testNullSerde() + { + // state is created null + LongDecimalWithOverflowAndLongState state = STATE_FACTORY.createSingleState(); + + LongDecimalWithOverflowAndLongState outState = roundTrip(state, 0); + + assertFalse(outState.isNotNull()); + } + + private 
LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength) + { + LongDecimalWithOverflowAndLongStateSerializer serializer = new LongDecimalWithOverflowAndLongStateSerializer(); + BlockBuilder out = new VariableWidthBlockBuilder(null, 1, 0); + + serializer.serialize(state, out); + + Block serialized = out.build(); + assertEquals(serialized.getSliceLength(0), expectedLength * Long.BYTES); + LongDecimalWithOverflowAndLongState outState = STATE_FACTORY.createSingleState(); + serializer.deserialize(serialized, 0, outState); + return outState; + } + + @DataProvider + public Object[][] input() + { + return new Object[][] { + {3, 0, 0, 1, 1}, + {3, 5, 0, 1, 2}, + {3, 5, 7, 1, 4}, + {3, 0, 0, 2, 3}, + {3, 5, 0, 2, 4}, + {3, 5, 7, 2, 4}, + {3, 0, 7, 1, 3}, + {3, 0, 7, 2, 3}, + {0, 0, 0, 1, 1}, + {0, 5, 0, 1, 2}, + {0, 5, 7, 1, 4}, + {0, 0, 0, 2, 3}, + {0, 5, 0, 2, 4}, + {0, 5, 7, 2, 4}, + {0, 0, 7, 1, 3}, + {0, 0, 7, 2, 3} + }; + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java new file mode 100644 index 0000000000000..5fcd95891a021 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.state; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlockBuilder; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestLongDecimalWithOverflowStateSerializer +{ + private static final LongDecimalWithOverflowStateFactory STATE_FACTORY = new LongDecimalWithOverflowStateFactory(); + + @Test(dataProvider = "input") + public void testSerde(long low, long high, long overflow, int expectedLength) + { + LongDecimalWithOverflowState state = STATE_FACTORY.createSingleState(); + state.getDecimalArray()[0] = high; + state.getDecimalArray()[1] = low; + state.setOverflow(overflow); + state.setNotNull(); + + LongDecimalWithOverflowState outState = roundTrip(state, expectedLength); + + assertTrue(outState.isNotNull()); + assertEquals(outState.getDecimalArray()[0], high); + assertEquals(outState.getDecimalArray()[1], low); + assertEquals(outState.getOverflow(), overflow); + } + + @Test + public void testNullSerde() + { + // state is created null + LongDecimalWithOverflowState state = STATE_FACTORY.createSingleState(); + + LongDecimalWithOverflowState outState = roundTrip(state, 0); + + assertFalse(outState.isNotNull()); + } + + private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState state, int expectedLength) + { + LongDecimalWithOverflowStateSerializer serializer = new LongDecimalWithOverflowStateSerializer(); + BlockBuilder out = new VariableWidthBlockBuilder(null, 1, 0); + + serializer.serialize(state, out); + + Block serialized = out.build(); + assertEquals(serialized.getSliceLength(0), expectedLength * Long.BYTES); + LongDecimalWithOverflowState 
outState = STATE_FACTORY.createSingleState(); + serializer.deserialize(serialized, 0, outState); + return outState; + } + + @DataProvider + public Object[][] input() + { + return new Object[][] { + {3, 0, 0, 1}, + {3, 5, 0, 2}, + {3, 5, 7, 3}, + {3, 0, 7, 3}, + {0, 0, 0, 1}, + {0, 5, 0, 2}, + {0, 5, 7, 3}, + {0, 0, 7, 3} + }; + } +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestAggregations.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestAggregations.java index 956481d8e434c..fcd70a9dec406 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestAggregations.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestAggregations.java @@ -1400,4 +1400,16 @@ public void testApproxMostFrequentWithStringGroupBy() assertEquals(actual1.getMaterializedRows().get(2).getFields().get(0), "c"); assertEquals(actual1.getMaterializedRows().get(2).getFields().get(1), ImmutableMap.of("C", 2L)); } + + @Test + public void testLongDecimalAggregations() + { + assertQuery(""" + SELECT avg(value_big), sum(value_big), avg(value_small), sum(value_small) + FROM ( + SELECT orderkey as id, CAST(power(2, 65) as DECIMAL(38, 0)) as value_big, CAST(1 as DECIMAL(38, 0)) as value_small + FROM orders + LIMIT 10) + GROUP BY id"""); + } }