From 05ac13cdfbcacfc660ebe5d37ecefbbccff71df3 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 1 Dec 2021 11:35:44 -0800 Subject: [PATCH] Improve decoding performance for long decimals Avoid using intermediate BigIntegers when decoding long decimals from Parquet and RCBinary --- .../main/java/io/trino/spi/type/Int128.java | 55 +++++++++++ .../java/io/trino/spi/type/TestInt128.java | 92 +++++++++++++++++++ .../io/trino/parquet/ParquetTypeUtils.java | 7 -- .../TupleDomainParquetPredicate.java | 5 +- .../reader/LongDecimalColumnReader.java | 4 +- .../trino/rcfile/binary/DecimalEncoding.java | 11 +-- .../trino/plugin/hive/util/DecimalUtils.java | 8 +- 7 files changed, 156 insertions(+), 26 deletions(-) create mode 100644 core/trino-spi/src/test/java/io/trino/spi/type/TestInt128.java diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/Int128.java b/core/trino-spi/src/main/java/io/trino/spi/type/Int128.java index e04d2fe4a3ab1..b5c6e2de5d017 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/Int128.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/Int128.java @@ -41,6 +41,61 @@ private Int128(long high, long low) this.low = low; } + /** + * Decode an Int128 from the two's complement big-endian representation. + * + * @param bytes the two's complement big-endian encoding of the number. It must contain at least 1 byte. + * It may contain more than 16 bytes if the leading bytes are not significant (either zeros or -1) + * @throws ArithmeticException if the bytes represent a number outside of the range [-2^127, 2^127 - 1] + */ + public static Int128 fromBigEndian(byte[] bytes) + { + if (bytes.length >= 16) { + int offset = bytes.length - Long.BYTES; + long low = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, offset); + + offset -= Long.BYTES; + long high = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, offset); + + for (int i = 0; i < offset; i++) { + if (bytes[i] != (high >> 63)) { + throw new ArithmeticException("Overflow"); + } + } + + return Int128.valueOf(high, low); + } + else if (bytes.length > 8) { + // read the last 8 bytes into low + int offset = bytes.length - Long.BYTES; + long low = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, offset); + + // At this point, we're guaranteed to have between 9 and 15 bytes available. + // Read 8 bytes into high, starting at offset 0. There will be some over-read + // of bytes belonging to low, so adjust by shifting them out + long high = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, 0); + offset -= Long.BYTES; + high >>= (-offset * Byte.SIZE); + + return Int128.valueOf(high, low); + } + else if (bytes.length == 8) { + long low = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, 0); + long high = (low >> 63); + + return Int128.valueOf(high, low); + } + else { + long high = (bytes[0] >> 7); + long low = high; + for (int i = 0; i < bytes.length; i++) { + low = (low << 8) | (bytes[i] & 0xFF); + } + + return Int128.valueOf(high, low); + } + } + public static Int128 valueOf(long[] value) { if (value.length != 2) { diff --git a/core/trino-spi/src/test/java/io/trino/spi/type/TestInt128.java b/core/trino-spi/src/test/java/io/trino/spi/type/TestInt128.java new file mode 100644 index 0000000000000..02833e80285dc --- /dev/null +++ b/core/trino-spi/src/test/java/io/trino/spi/type/TestInt128.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.type; + +import org.testng.annotations.Test; + +import java.math.BigInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestInt128 +{ + @Test + public void testFromBigEndian() + { + byte[] bytes; + + // less than 8 bytes + bytes = new byte[] {0x1}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0x0000000000000000L, 0x0000000000000001L)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + bytes = new byte[] {(byte) 0xFF}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0xFFFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + // 8 bytes + bytes = new byte[] {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0x0000000000000000L, 0x01_02_03_04_05_06_07_08L)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + bytes = new byte[] {(byte) 0x80, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0xFFFFFFFFFFFFFFFFL, 0x80_02_03_04_05_06_07_08L)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + // more than 8 bytes, less than 16 bytes + bytes = new byte[] {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0x000000000000_01_02L, 0x03_04_05_06_07_08_09_0AL)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + bytes = new byte[] {(byte) 0x80, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0xFFFFFFFFFFFF_80_02L, 0x03_04_05_06_07_08_09_0AL)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + // 16 bytes + bytes = new byte[] {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x55}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0x01_02_03_04_05_06_07_08L, 0x09_0A_0B_0C_0D_0E_0F_55L)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + bytes = new byte[] {(byte) 0x80, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x55}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0x80_02_03_04_05_06_07_08L, 0x09_0A_0B_0C_0D_0E_0F_55L)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + // more than 16 bytes + bytes = new byte[] {0x0, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x55}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0x01_02_03_04_05_06_07_08L, 0x09_0A_0B_0C_0D_0E_0F_55L)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + bytes = new byte[] {(byte) 0xFF, (byte) 0xFF, (byte) 0x80, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x55}; + assertThat(Int128.fromBigEndian(bytes)) + .isEqualTo(Int128.valueOf(0x80_02_03_04_05_06_07_08L, 0x09_0A_0B_0C_0D_0E_0F_55L)) + .isEqualTo(Int128.valueOf(new BigInteger(bytes))); + + // overflow + assertThatThrownBy(() -> Int128.fromBigEndian(new byte[] {0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0})) + .isInstanceOf(ArithmeticException.class); + + assertThatThrownBy(() -> Int128.fromBigEndian(new byte[] {(byte) 0xFE, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0})) + .isInstanceOf(ArithmeticException.class); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java index bcf67e6cbcd14..d8818c5599833 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java @@ -14,7 +14,6 @@ package io.trino.parquet; import io.trino.spi.type.DecimalType; -import io.trino.spi.type.Int128; import org.apache.parquet.column.Encoding; import org.apache.parquet.io.ColumnIO; import org.apache.parquet.io.ColumnIOFactory; @@ -28,7 +27,6 @@ import javax.annotation.Nullable; -import java.math.BigInteger; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -262,9 +260,4 @@ public static long getShortDecimalValue(byte[] bytes, int startOffset, int lengt return value; } - - public static Int128 getLongDecimalValue(byte[] bytes) - { - return Int128.valueOf(new BigInteger(bytes)); - } } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java index 32e2594ee75b6..2900471088273 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java @@ -59,7 +59,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.parquet.ParquetTimestampUtils.decodeInt64Timestamp; import static io.trino.parquet.ParquetTimestampUtils.decodeInt96Timestamp; -import static io.trino.parquet.ParquetTypeUtils.getLongDecimalValue; import static io.trino.parquet.ParquetTypeUtils.getShortDecimalValue; import static io.trino.parquet.predicate.PredicateUtils.isStatisticsOverflow; import static io.trino.plugin.base.type.TrinoTimestampEncoderFactory.createTimestampEncoder; @@ -297,8 +296,8 @@ private static Domain getDomain( } else { for (int i = 0; i < minimums.size(); i++) { - Int128 min = getLongDecimalValue(((Binary) minimums.get(i)).getBytes()); - Int128 max = getLongDecimalValue(((Binary) maximums.get(i)).getBytes()); + Int128 min = Int128.fromBigEndian(((Binary) minimums.get(i)).getBytes()); + Int128 max = Int128.fromBigEndian(((Binary) maximums.get(i)).getBytes()); ranges.add(Range.range(type, min, true, max, true)); } diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/LongDecimalColumnReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/LongDecimalColumnReader.java index efec8d87e674a..81b74784fe7f9 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/LongDecimalColumnReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/LongDecimalColumnReader.java @@ -21,8 +21,6 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import java.math.BigInteger; - import static io.trino.spi.type.DecimalConversions.longToLongCast; import static io.trino.spi.type.DecimalConversions.longToShortCast; import static java.lang.String.format; @@ -49,7 +47,7 @@ protected void readValue(BlockBuilder blockBuilder, Type trinoType) DecimalType trinoDecimalType = (DecimalType) trinoType; Binary binary = valuesReader.readBytes(); - Int128 value = Int128.valueOf(new BigInteger(binary.getBytes())); + Int128 value = Int128.fromBigEndian(binary.getBytes()); if (trinoDecimalType.isShort()) { trinoType.writeLong(blockBuilder, longToShortCast( diff --git a/lib/trino-rcfile/src/main/java/io/trino/rcfile/binary/DecimalEncoding.java b/lib/trino-rcfile/src/main/java/io/trino/rcfile/binary/DecimalEncoding.java index 856dc15e15141..2d46c34095b30 100644 --- a/lib/trino-rcfile/src/main/java/io/trino/rcfile/binary/DecimalEncoding.java +++ b/lib/trino-rcfile/src/main/java/io/trino/rcfile/binary/DecimalEncoding.java @@ -21,8 +21,8 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.DecimalType; -import io.trino.spi.type.Decimals; import io.trino.spi.type.Int128; +import io.trino.spi.type.Int128Math; import io.trino.spi.type.Type; import java.math.BigInteger; @@ -179,13 +179,8 @@ private Int128 parseSlice(Slice slice, int offset) resultSlice.setBytes(BYTES_IN_LONG_DECIMAL - length, slice, offset, length); - // todo get rid of BigInteger - BigInteger decimal = new BigInteger(resultBytes); - if (scale != type.getScale()) { - decimal = Decimals.rescale(decimal, scale, type.getScale()); - } - - return Int128.valueOf(decimal); + Int128 result = Int128.fromBigEndian(resultBytes); + return Int128Math.rescale(result, type.getScale() - scale); } private void writeLong(SliceOutput output, long value) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/DecimalUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/DecimalUtils.java index 08daf904d37d9..22c1c0efb5e97 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/DecimalUtils.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/DecimalUtils.java @@ -14,10 +14,9 @@ package io.trino.plugin.hive.util; import io.trino.spi.type.Int128; +import io.trino.spi.type.Int128Math; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; -import java.math.BigInteger; - import static io.trino.spi.type.Decimals.rescale; public final class DecimalUtils @@ -50,8 +49,7 @@ public static long getShortDecimalValue(byte[] bytes) public static Int128 getLongDecimalValue(HiveDecimalWritable writable, int columnScale) { - BigInteger value = new BigInteger(writable.getInternalStorage()); - value = rescale(value, writable.getScale(), columnScale); - return Int128.valueOf(value); + Int128 value = Int128.fromBigEndian(writable.getInternalStorage()); + return Int128Math.rescale(value, columnScale - writable.getScale()); } }