Skip to content

Commit

Permalink
Improve decoding performance for long decimals
Browse files Browse the repository at this point in the history
Avoid using intermediate BigIntegers when decoding long
decimals from Parquet and RCBinary
  • Loading branch information
martint committed Dec 23, 2021
1 parent 6ecfd3a commit 05ac13c
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 26 deletions.
55 changes: 55 additions & 0 deletions core/trino-spi/src/main/java/io/trino/spi/type/Int128.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,61 @@ private Int128(long high, long low)
this.low = low;
}

/**
* Decode an Int128 from the two's complement big-endian representation.
*
* @param bytes the two's complement big-endian encoding of the number. It must contain at least 1 byte.
* It may contain more than 16 bytes if the leading bytes are not significant (either zeros or -1)
* @throws ArithmeticException if the bytes represent a number outside of the range [-2^127, 2^127 - 1]
*/
public static Int128 fromBigEndian(byte[] bytes)
{
if (bytes.length >= 16) {
int offset = bytes.length - Long.BYTES;
long low = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, offset);

offset -= Long.BYTES;
long high = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, offset);

for (int i = 0; i < offset; i++) {
if (bytes[i] != (high >> 63)) {
throw new ArithmeticException("Overflow");
}
}

return Int128.valueOf(high, low);
}
else if (bytes.length > 8) {
// read the last 8 bytes into low
int offset = bytes.length - Long.BYTES;
long low = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, offset);

// At this point, we're guaranteed to have between 9 and 15 bytes available.
// Read 8 bytes into high, starting at offset 0. There will be some over-read
// of bytes belonging to low, so adjust by shifting them out
long high = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, 0);
offset -= Long.BYTES;
high >>= (-offset * Byte.SIZE);

return Int128.valueOf(high, low);
}
else if (bytes.length == 8) {
long low = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, 0);
long high = (low >> 63);

return Int128.valueOf(high, low);
}
else {
long high = (bytes[0] >> 7);
long low = high;
for (int i = 0; i < bytes.length; i++) {
low = (low << 8) | (bytes[i] & 0xFF);
}

return Int128.valueOf(high, low);
}
}

public static Int128 valueOf(long[] value)
{
if (value.length != 2) {
Expand Down
92 changes: 92 additions & 0 deletions core/trino-spi/src/test/java/io/trino/spi/type/TestInt128.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.spi.type;

import org.testng.annotations.Test;

import java.math.BigInteger;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class TestInt128
{
@Test
public void testFromBigEndian()
{
byte[] bytes;

// less than 8 bytes
bytes = new byte[] {0x1};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0x0000000000000000L, 0x0000000000000001L))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

bytes = new byte[] {(byte) 0xFF};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0xFFFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

// 8 bytes
bytes = new byte[] {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0x0000000000000000L, 0x01_02_03_04_05_06_07_08L))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

bytes = new byte[] {(byte) 0x80, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0xFFFFFFFFFFFFFFFFL, 0x80_02_03_04_05_06_07_08L))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

// more than 8 bytes, less than 16 bytes
bytes = new byte[] {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0x000000000000_01_02L, 0x03_04_05_06_07_08_09_0AL))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

bytes = new byte[] {(byte) 0x80, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0xFFFFFFFFFFFF_80_02L, 0x03_04_05_06_07_08_09_0AL))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

// 16 bytes
bytes = new byte[] {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x55};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0x01_02_03_04_05_06_07_08L, 0x09_0A_0B_0C_0D_0E_0F_55L))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

bytes = new byte[] {(byte) 0x80, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x55};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0x80_02_03_04_05_06_07_08L, 0x09_0A_0B_0C_0D_0E_0F_55L))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

// more than 16 bytes
bytes = new byte[] {0x0, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x55};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0x01_02_03_04_05_06_07_08L, 0x09_0A_0B_0C_0D_0E_0F_55L))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

bytes = new byte[] {(byte) 0xFF, (byte) 0xFF, (byte) 0x80, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x55};
assertThat(Int128.fromBigEndian(bytes))
.isEqualTo(Int128.valueOf(0x80_02_03_04_05_06_07_08L, 0x09_0A_0B_0C_0D_0E_0F_55L))
.isEqualTo(Int128.valueOf(new BigInteger(bytes)));

// overflow
assertThatThrownBy(() -> Int128.fromBigEndian(new byte[] {0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}))
.isInstanceOf(ArithmeticException.class);

assertThatThrownBy(() -> Int128.fromBigEndian(new byte[] {(byte) 0xFE, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}))
.isInstanceOf(ArithmeticException.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.parquet;

import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Int128;
import org.apache.parquet.column.Encoding;
import org.apache.parquet.io.ColumnIO;
import org.apache.parquet.io.ColumnIOFactory;
Expand All @@ -28,7 +27,6 @@

import javax.annotation.Nullable;

import java.math.BigInteger;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -262,9 +260,4 @@ public static long getShortDecimalValue(byte[] bytes, int startOffset, int lengt

return value;
}

public static Int128 getLongDecimalValue(byte[] bytes)
{
return Int128.valueOf(new BigInteger(bytes));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.parquet.ParquetTimestampUtils.decodeInt64Timestamp;
import static io.trino.parquet.ParquetTimestampUtils.decodeInt96Timestamp;
import static io.trino.parquet.ParquetTypeUtils.getLongDecimalValue;
import static io.trino.parquet.ParquetTypeUtils.getShortDecimalValue;
import static io.trino.parquet.predicate.PredicateUtils.isStatisticsOverflow;
import static io.trino.plugin.base.type.TrinoTimestampEncoderFactory.createTimestampEncoder;
Expand Down Expand Up @@ -297,8 +296,8 @@ private static Domain getDomain(
}
else {
for (int i = 0; i < minimums.size(); i++) {
Int128 min = getLongDecimalValue(((Binary) minimums.get(i)).getBytes());
Int128 max = getLongDecimalValue(((Binary) maximums.get(i)).getBytes());
Int128 min = Int128.fromBigEndian(((Binary) minimums.get(i)).getBytes());
Int128 max = Int128.fromBigEndian(((Binary) maximums.get(i)).getBytes());

ranges.add(Range.range(type, min, true, max, true));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import org.apache.parquet.io.ParquetDecodingException;
import org.apache.parquet.io.api.Binary;

import java.math.BigInteger;

import static io.trino.spi.type.DecimalConversions.longToLongCast;
import static io.trino.spi.type.DecimalConversions.longToShortCast;
import static java.lang.String.format;
Expand All @@ -49,7 +47,7 @@ protected void readValue(BlockBuilder blockBuilder, Type trinoType)
DecimalType trinoDecimalType = (DecimalType) trinoType;

Binary binary = valuesReader.readBytes();
Int128 value = Int128.valueOf(new BigInteger(binary.getBytes()));
Int128 value = Int128.fromBigEndian(binary.getBytes());

if (trinoDecimalType.isShort()) {
trinoType.writeLong(blockBuilder, longToShortCast(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Int128Math;
import io.trino.spi.type.Type;

import java.math.BigInteger;
Expand Down Expand Up @@ -179,13 +179,8 @@ private Int128 parseSlice(Slice slice, int offset)

resultSlice.setBytes(BYTES_IN_LONG_DECIMAL - length, slice, offset, length);

// todo get rid of BigInteger
BigInteger decimal = new BigInteger(resultBytes);
if (scale != type.getScale()) {
decimal = Decimals.rescale(decimal, scale, type.getScale());
}

return Int128.valueOf(decimal);
Int128 result = Int128.fromBigEndian(resultBytes);
return Int128Math.rescale(result, type.getScale() - scale);
}

private void writeLong(SliceOutput output, long value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
package io.trino.plugin.hive.util;

import io.trino.spi.type.Int128;
import io.trino.spi.type.Int128Math;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;

import java.math.BigInteger;

import static io.trino.spi.type.Decimals.rescale;

public final class DecimalUtils
Expand Down Expand Up @@ -50,8 +49,7 @@ public static long getShortDecimalValue(byte[] bytes)

public static Int128 getLongDecimalValue(HiveDecimalWritable writable, int columnScale)
{
BigInteger value = new BigInteger(writable.getInternalStorage());
value = rescale(value, writable.getScale(), columnScale);
return Int128.valueOf(value);
Int128 value = Int128.fromBigEndian(writable.getInternalStorage());
return Int128Math.rescale(value, columnScale - writable.getScale());
}
}

0 comments on commit 05ac13c

Please sign in to comment.