Remove usage of Hive types in HiveWriteUtils
electrum committed Sep 1, 2023
1 parent 6eb97e2 commit 13268e9
Showing 6 changed files with 121 additions and 142 deletions.
6 changes: 6 additions & 0 deletions plugin/trino-hive-hadoop2/pom.xml
@@ -186,6 +186,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino.hive</groupId>
<artifactId>hive-apache</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
11 changes: 6 additions & 5 deletions plugin/trino-hive/pom.xml
@@ -196,11 +196,6 @@
<artifactId>hadoop-apache</artifactId>
</dependency>

<dependency>
<groupId>io.trino.hive</groupId>
<artifactId>hive-apache</artifactId>
</dependency>

<dependency>
<groupId>io.trino.hive</groupId>
<artifactId>hive-thrift</artifactId>
@@ -443,6 +438,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino.hive</groupId>
<artifactId>hive-apache</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino.tpch</groupId>
<artifactId>tpch</artifactId>
HiveWriteUtils.java
@@ -21,7 +21,6 @@
import io.trino.hdfs.rubix.CachingTrinoS3FileSystem;
import io.trino.hdfs.s3.TrinoS3FileSystem;
import io.trino.plugin.hive.HiveReadOnlyException;
import io.trino.plugin.hive.HiveTimestampPrecision;
import io.trino.plugin.hive.HiveType;
import io.trino.plugin.hive.metastore.Database;
import io.trino.plugin.hive.metastore.Partition;
@@ -40,38 +39,29 @@
import io.trino.spi.block.Block;
import io.trino.spi.connector.SchemaNotFoundException;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.viewfs.ViewFileSystem;
import org.apache.hadoop.hdfs.DistributedFileSystem;
import org.apache.hadoop.hive.common.type.Date;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.common.type.Timestamp;
import org.apache.hadoop.io.Text;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeFormatterBuilder;
import java.time.temporal.ChronoField;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.io.BaseEncoding.base16;
import static io.trino.hdfs.FileSystemUtils.getRawFileSystem;
import static io.trino.hdfs.s3.HiveS3Module.EMR_FS_CLASS_NAME;
@@ -90,26 +80,30 @@
import static io.trino.spi.type.Chars.padSpaces;
import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.Decimals.readBigDecimal;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND;
import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static java.lang.Math.floorDiv;
import static java.lang.Math.floorMod;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
import static java.util.UUID.randomUUID;

public final class HiveWriteUtils
{
private static final DateTimeFormatter HIVE_DATE_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd");
private static final DateTimeFormatter HIVE_TIMESTAMP_FORMATTER = new DateTimeFormatterBuilder()
.append(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"))
.optionalStart().appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true).optionalEnd()
.toFormatter();

private HiveWriteUtils()
{
}
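
The two formatters above replace the Hive Date/Timestamp classes with plain java.time formatting. A minimal sketch of the expected output follows; only the formatter definitions are taken from the diff, the class name and values are illustrative:

import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeFormatterBuilder;
import java.time.temporal.ChronoField;

class HiveFormatterSketch
{
    public static void main(String[] args)
    {
        // Same definitions as HIVE_DATE_FORMATTER and HIVE_TIMESTAMP_FORMATTER above
        DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
        DateTimeFormatter timestampFormatter = new DateTimeFormatterBuilder()
                .append(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"))
                .optionalStart().appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true).optionalEnd()
                .toFormatter();

        // Epoch day 19601 is 2023-09-01
        System.out.println(LocalDate.ofEpochDay(19601).format(dateFormatter));
        // -> 2023-09-01

        // A zero fraction is omitted entirely (minimum fraction width is 0)
        System.out.println(LocalDateTime.of(2023, 9, 1, 12, 34, 56).format(timestampFormatter));
        // -> 2023-09-01 12:34:56

        // A non-zero fraction is printed without trailing zero padding
        System.out.println(LocalDateTime.of(2023, 9, 1, 12, 34, 56, 123_000_000).format(timestampFormatter));
        // -> 2023-09-01 12:34:56.123
    }
}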
@@ -118,102 +112,66 @@ public static List<String> createPartitionValues(List<Type> partitionColumnTypes
{
ImmutableList.Builder<String> partitionValues = ImmutableList.builder();
for (int field = 0; field < partitionColumns.getChannelCount(); field++) {
Object value = getField(partitionColumnTypes.get(field), partitionColumns.getBlock(field), position);
if (value == null) {
partitionValues.add(HIVE_DEFAULT_DYNAMIC_PARTITION);
}
else {
String valueString = value.toString();
if (!CharMatcher.inRange((char) 0x20, (char) 0x7E).matchesAllOf(valueString)) {
throw new TrinoException(HIVE_INVALID_PARTITION_VALUE,
"Hive partition keys can only contain printable ASCII characters (0x20 - 0x7E). Invalid value: " +
base16().withSeparator(" ", 2).encode(valueString.getBytes(UTF_8)));
}
partitionValues.add(valueString);
String value = toPartitionValue(partitionColumnTypes.get(field), partitionColumns.getBlock(field), position);
if (!CharMatcher.inRange((char) 0x20, (char) 0x7E).matchesAllOf(value)) {
String encoded = base16().withSeparator(" ", 2).encode(value.getBytes(UTF_8));
throw new TrinoException(HIVE_INVALID_PARTITION_VALUE, "Hive partition keys can only contain printable ASCII characters (0x20 - 0x7E). Invalid value: " + encoded);
}
partitionValues.add(value);
}
return partitionValues.build();
}

private static Object getField(Type type, Block block, int position)
private static String toPartitionValue(Type type, Block block, int position)
{
// see HiveUtil#isValidPartitionType
if (block.isNull(position)) {
return null;
return HIVE_DEFAULT_DYNAMIC_PARTITION;
}
if (BOOLEAN.equals(type)) {
return BOOLEAN.getBoolean(block, position);
return String.valueOf(BOOLEAN.getBoolean(block, position));
}
if (BIGINT.equals(type)) {
return BIGINT.getLong(block, position);
return String.valueOf(BIGINT.getLong(block, position));
}
if (INTEGER.equals(type)) {
return INTEGER.getInt(block, position);
return String.valueOf(INTEGER.getInt(block, position));
}
if (SMALLINT.equals(type)) {
return SMALLINT.getShort(block, position);
return String.valueOf(SMALLINT.getShort(block, position));
}
if (TINYINT.equals(type)) {
return TINYINT.getByte(block, position);
return String.valueOf(TINYINT.getByte(block, position));
}
if (REAL.equals(type)) {
return REAL.getFloat(block, position);
return String.valueOf(REAL.getFloat(block, position));
}
if (DOUBLE.equals(type)) {
return DOUBLE.getDouble(block, position);
return String.valueOf(DOUBLE.getDouble(block, position));
}
if (type instanceof VarcharType varcharType) {
return new Text(varcharType.getSlice(block, position).getBytes());
return varcharType.getSlice(block, position).toStringUtf8();
}
if (type instanceof CharType charType) {
return new Text(padSpaces(charType.getSlice(block, position), charType).toStringUtf8());
}
if (VARBINARY.equals(type)) {
return VARBINARY.getSlice(block, position).getBytes();
return padSpaces(charType.getSlice(block, position), charType).toStringUtf8();
}
if (DATE.equals(type)) {
return Date.ofEpochDay(DATE.getInt(block, position));
return LocalDate.ofEpochDay(DATE.getInt(block, position)).format(HIVE_DATE_FORMATTER);
}
if (type instanceof TimestampType timestampType) {
return getHiveTimestamp(timestampType, block, position);
if (TIMESTAMP_MILLIS.equals(type)) {
long epochMicros = type.getLong(block, position);
long epochSeconds = floorDiv(epochMicros, MICROSECONDS_PER_SECOND);
int nanosOfSecond = floorMod(epochMicros, MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND;
return LocalDateTime.ofEpochSecond(epochSeconds, nanosOfSecond, ZoneOffset.UTC).format(HIVE_TIMESTAMP_FORMATTER);
}
if (type instanceof TimestampWithTimeZoneType) {
checkArgument(type.equals(TIMESTAMP_TZ_MILLIS));
return getHiveTimestampTz(block, position);
if (TIMESTAMP_TZ_MILLIS.equals(type)) {
long epochMillis = unpackMillisUtc(type.getLong(block, position));
return LocalDateTime.ofInstant(Instant.ofEpochMilli(epochMillis), ZoneOffset.UTC).format(HIVE_TIMESTAMP_FORMATTER);
}
if (type instanceof DecimalType decimalType) {
return getHiveDecimal(decimalType, block, position);
}
if (type instanceof ArrayType arrayType) {
Type elementType = arrayType.getElementType();
Block arrayBlock = block.getObject(position, Block.class);
List<Object> list = new ArrayList<>(arrayBlock.getPositionCount());
for (int i = 0; i < arrayBlock.getPositionCount(); i++) {
list.add(getField(elementType, arrayBlock, i));
}
return unmodifiableList(list);
}
if (type instanceof MapType mapType) {
Type keyType = mapType.getKeyType();
Type valueType = mapType.getValueType();
Block mapBlock = block.getObject(position, Block.class);
Map<Object, Object> map = new HashMap<>();
for (int i = 0; i < mapBlock.getPositionCount(); i += 2) {
map.put(getField(keyType, mapBlock, i),
getField(valueType, mapBlock, i + 1));
}
return unmodifiableMap(map);
}
if (type instanceof RowType rowType) {
List<Type> fieldTypes = rowType.getTypeParameters();
Block rowBlock = block.getObject(position, Block.class);
verify(fieldTypes.size() == rowBlock.getPositionCount(), "expected row value field count does not match type field count");
List<Object> row = new ArrayList<>(rowBlock.getPositionCount());
for (int i = 0; i < rowBlock.getPositionCount(); i++) {
row.add(getField(fieldTypes.get(i), rowBlock, i));
}
return unmodifiableList(row);
return readBigDecimal(decimalType, block, position).stripTrailingZeros().toPlainString();
}
throw new TrinoException(NOT_SUPPORTED, "unsupported type: " + type);
throw new TrinoException(NOT_SUPPORTED, "Unsupported type for partition: " + type);
}

public static void checkTableIsWritable(Table table, boolean writesToNonManagedTablesEnabled)
@@ -500,44 +458,4 @@ private static boolean isWritablePrimitiveType(PrimitiveCategory primitiveCatego
}
return false;
}

private static HiveDecimal getHiveDecimal(DecimalType decimalType, Block block, int position)
{
BigInteger unscaledValue;
if (decimalType.isShort()) {
unscaledValue = BigInteger.valueOf(decimalType.getLong(block, position));
}
else {
unscaledValue = ((Int128) decimalType.getObject(block, position)).toBigInteger();
}
return HiveDecimal.create(unscaledValue, decimalType.getScale());
}

private static Timestamp getHiveTimestamp(TimestampType type, Block block, int position)
{
verify(type.getPrecision() <= HiveTimestampPrecision.MAX.getPrecision(), "Timestamp precision too high for Hive");

long epochMicros;
int nanosOfMicro;
if (type.isShort()) {
epochMicros = type.getLong(block, position);
nanosOfMicro = 0;
}
else {
LongTimestamp timestamp = (LongTimestamp) type.getObject(block, position);
epochMicros = timestamp.getEpochMicros();
nanosOfMicro = timestamp.getPicosOfMicro() / PICOSECONDS_PER_NANOSECOND;
}

long epochSeconds = floorDiv(epochMicros, MICROSECONDS_PER_SECOND);
int microsOfSecond = floorMod(epochMicros, MICROSECONDS_PER_SECOND);
int nanosOfSecond = microsOfSecond * NANOSECONDS_PER_MICROSECOND + nanosOfMicro;
return Timestamp.ofEpochSecond(epochSeconds, nanosOfSecond);
}

private static Timestamp getHiveTimestampTz(Block block, int position)
{
long epochMillis = unpackMillisUtc(block.getLong(position, 0));
return Timestamp.ofEpochMilli(epochMillis);
}
}
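
With these changes, createPartitionValues goes straight from Trino blocks to partition value strings. A hedged usage sketch, assuming the Trino SPI page-building calls also used in the test below; the class name and values are illustrative, not part of the commit:

import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.type.Type;

import java.time.LocalDate;
import java.util.List;

import static io.trino.plugin.hive.util.HiveWriteUtils.createPartitionValues;
import static io.trino.spi.type.DateType.DATE;

class PartitionValueSketch
{
    static List<String> datePartitionValue()
    {
        // Build a single-position page with one DATE partition column
        List<Type> types = List.of(DATE);
        PageBuilder pageBuilder = new PageBuilder(types);
        pageBuilder.declarePosition();
        DATE.writeLong(pageBuilder.getBlockBuilder(0), LocalDate.of(2023, 9, 1).toEpochDay());
        Page page = pageBuilder.build();

        // Expected to be ["2023-09-01"], rendered by HIVE_DATE_FORMATTER
        // instead of org.apache.hadoop.hive.common.type.Date
        return createPartitionValues(types, page, 0);
    }
}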
TestHiveWriteUtils.java
@@ -14,14 +14,29 @@
package io.trino.plugin.hive.util;

import io.trino.hdfs.HdfsContext;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.SqlDecimal;
import io.trino.spi.type.Type;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.testng.annotations.Test;

import java.util.List;

import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT;
import static io.trino.plugin.hive.util.HiveWriteUtils.createPartitionValues;
import static io.trino.plugin.hive.util.HiveWriteUtils.isS3FileSystem;
import static io.trino.plugin.hive.util.HiveWriteUtils.isViewFileSystem;
import static io.trino.spi.type.DecimalType.createDecimalType;
import static io.trino.spi.type.Decimals.writeBigDecimal;
import static io.trino.spi.type.Decimals.writeShortDecimal;
import static io.trino.spi.type.SqlDecimal.decimal;
import static io.trino.testing.TestingConnectorSession.SESSION;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static org.assertj.core.api.Assertions.assertThat;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

@@ -49,4 +64,49 @@ public void testIsViewFileSystem()
assertTrue(isViewFileSystem(CONTEXT, HDFS_ENVIRONMENT, viewfsPath));
assertFalse(isViewFileSystem(CONTEXT, HDFS_ENVIRONMENT, nonViewfsPath));
}

@Test
public void testCreatePartitionValuesDecimal()
{
assertCreatePartitionValuesDecimal(10, 0, "12345", "12345");
assertCreatePartitionValuesDecimal(10, 2, "123.45", "123.45");
assertCreatePartitionValuesDecimal(10, 2, "12345.00", "12345");
assertCreatePartitionValuesDecimal(5, 0, "12345", "12345");
assertCreatePartitionValuesDecimal(38, 2, "12345.00", "12345");
assertCreatePartitionValuesDecimal(38, 20, "12345.00000000000000000000", "12345");
assertCreatePartitionValuesDecimal(38, 20, "12345.67898000000000000000", "12345.67898");
}

private static void assertCreatePartitionValuesDecimal(int precision, int scale, String decimalValue, String expectedValue)
{
DecimalType decimalType = createDecimalType(precision, scale);
List<Type> types = List.of(decimalType);
SqlDecimal decimal = decimal(decimalValue, decimalType);

// verify the test values are as expected
assertThat(decimal.toString()).isEqualTo(decimalValue);
assertThat(decimal.toBigDecimal().toString()).isEqualTo(decimalValue);

PageBuilder pageBuilder = new PageBuilder(types);
pageBuilder.declarePosition();
writeDecimal(decimalType, decimal, pageBuilder.getBlockBuilder(0));
Page page = pageBuilder.build();

// verify the expected value against HiveDecimal
assertThat(HiveDecimal.create(decimal.toBigDecimal()).toString())
.isEqualTo(expectedValue);

assertThat(createPartitionValues(types, page, 0))
.isEqualTo(List.of(expectedValue));
}

private static void writeDecimal(DecimalType decimalType, SqlDecimal decimal, BlockBuilder blockBuilder)
{
if (decimalType.isShort()) {
writeShortDecimal(blockBuilder, decimal.toBigDecimal().unscaledValue().longValue());
}
else {
writeBigDecimal(decimalType, blockBuilder, decimal.toBigDecimal());
}
}
}
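
The expected strings in the test above mirror Hive's decimal rendering without using HiveDecimal in production code: trailing fractional zeros are stripped and plain notation is used. A small sketch of the same normalization using only java.math.BigDecimal, with illustrative values:

import java.math.BigDecimal;

class DecimalPartitionValueSketch
{
    public static void main(String[] args)
    {
        // stripTrailingZeros() drops the fractional zeros; toPlainString() avoids scientific notation
        System.out.println(new BigDecimal("12345.00").stripTrailingZeros().toPlainString());
        // -> 12345 (not 1.2345E+4)
        System.out.println(new BigDecimal("123.45").stripTrailingZeros().toPlainString());
        // -> 123.45
        System.out.println(new BigDecimal("12345.67898000000000000000").stripTrailingZeros().toPlainString());
        // -> 12345.67898
    }
}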
7 changes: 0 additions & 7 deletions plugin/trino-hudi/pom.xml
@@ -93,13 +93,6 @@
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-hive</artifactId>
<exclusions>
<!-- TODO: remove when removed from trino-hive -->
<exclusion>
<groupId>io.trino.hive</groupId>
<artifactId>hive-apache</artifactId>
</exclusion>
</exclusions>
</dependency>

<dependency>