Skip to content

Commit

Permalink
Remove Type equalTo and hash uses in Hive writers
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Oct 7, 2020
1 parent 63bc7ef commit f551d80
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 114 deletions.
67 changes: 6 additions & 61 deletions presto-orc/src/main/java/io/prestosql/orc/OrcWriteValidation.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,13 @@
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.ColumnarMap;
import io.prestosql.spi.block.ColumnarRow;
import io.prestosql.spi.type.AbstractLongType;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.CharType;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.LongTimestamp;
import io.prestosql.spi.type.LongTimestampWithTimeZone;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.TimestampType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import org.openjdk.jol.info.ClassLayout;
Expand Down Expand Up @@ -464,9 +462,7 @@ public List<Long> getColumnHashes()

public static class WriteChecksumBuilder
{
private static final long NULL_HASH_CODE = 0x6e3efbd56c16a0cbL;

private final List<Type> types;
private final List<ValidationHash> validationHashes;
private long totalRowCount;
private final List<XxHash64> columnHashes;
private final XxHash64 stripeHash = new XxHash64();
Expand All @@ -476,7 +472,9 @@ public static class WriteChecksumBuilder

private WriteChecksumBuilder(List<Type> types)
{
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.validationHashes = requireNonNull(types, "types is null").stream()
.map(ValidationHash::createValidationHash)
.collect(toImmutableList());

ImmutableList.Builder<XxHash64> columnHashes = ImmutableList.builder();
for (Type ignored : types) {
Expand All @@ -503,71 +501,18 @@ public void addPage(Page page)
checkArgument(page.getChannelCount() == columnHashes.size(), "invalid page");

for (int channel = 0; channel < columnHashes.size(); channel++) {
Type type = types.get(channel);
ValidationHash validationHash = validationHashes.get(channel);
Block block = page.getBlock(channel);
XxHash64 xxHash64 = columnHashes.get(channel);
for (int position = 0; position < block.getPositionCount(); position++) {
long hash = hashPositionSkipNullMapKeys(type, block, position);
long hash = validationHash.hash(block, position);
longSlice.setLong(0, hash);
xxHash64.update(longBuffer);
}
}
totalRowCount += page.getPositionCount();
}

/**
 * Computes the validation hash for a single position of {@code block}.
 * <p>
 * Null positions hash to {@code NULL_HASH_CODE}. Map, array and row values are
 * hashed recursively using the type parameters of {@code type}; map entries whose
 * key position is null are left out of the map hash. Any other type is delegated
 * to {@code type.hash}, except for the timestamp adjustment described below.
 *
 * @param type the type of the values stored in {@code block}
 * @param block the block holding the value
 * @param position the position within {@code block} to hash
 * @return the hash of the value at {@code position}
 */
private static long hashPositionSkipNullMapKeys(Type type, Block block, int position)
{
    if (block.isNull(position)) {
        return NULL_HASH_CODE;
    }

    if (type instanceof MapType) {
        Type keys = type.getTypeParameters().get(0);
        Type values = type.getTypeParameters().get(1);
        // Even positions of the entry block are read as keys, the following odd position as the value
        Block entries = (Block) type.getObject(block, position);
        long sum = 0;
        for (int keyPosition = 0; keyPosition < entries.getPositionCount(); keyPosition += 2) {
            if (entries.isNull(keyPosition)) {
                // entries with a null key do not contribute to the hash
                continue;
            }
            // plain addition keeps the map hash independent of entry order
            sum += hashPositionSkipNullMapKeys(keys, entries, keyPosition);
            sum += hashPositionSkipNullMapKeys(values, entries, keyPosition + 1);
        }
        return sum;
    }

    if (type instanceof ArrayType) {
        Type element = type.getTypeParameters().get(0);
        Block elements = (Block) type.getObject(block, position);
        long combined = 0;
        for (int index = 0; index < elements.getPositionCount(); index++) {
            combined = 31 * combined + hashPositionSkipNullMapKeys(element, elements, index);
        }
        return combined;
    }

    if (type instanceof RowType) {
        Block fields = (Block) type.getObject(block, position);
        long combined = 0;
        for (int field = 0; field < fields.getPositionCount(); field++) {
            Type fieldType = type.getTypeParameters().get(field);
            combined = 31 * combined + hashPositionSkipNullMapKeys(fieldType, fields, field);
        }
        return combined;
    }

    if (type instanceof TimestampType) {
        // A flaw in ORC encoding makes it impossible to represent timestamp
        // between 1969-12-31 23:59:59.000, exclusive, and 1970-01-01 00:00:00.000, exclusive.
        // Therefore, such data won't round trip. The data read back is expected to be 1 second later than the original value.
        long epochMillis = floorDiv(TIMESTAMP_MILLIS.getLong(block, position), MICROSECONDS_PER_MILLISECOND);
        boolean insideUnrepresentableSecond = epochMillis > -1000 && epochMillis < 0;
        if (insideUnrepresentableSecond) {
            return AbstractLongType.hash(epochMillis + 1000);
        }
        // timestamps outside that window fall through to the regular type hash
    }

    return type.hash(block, position);
}

public WriteChecksum build()
{
return new WriteChecksum(
Expand Down
165 changes: 165 additions & 0 deletions presto-orc/src/main/java/io/prestosql/orc/ValidationHash.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.orc;

import io.prestosql.spi.block.Block;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.AbstractLongType;
import io.prestosql.spi.type.StandardTypes;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeOperators;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;

import static com.google.common.base.Throwables.throwIfUnchecked;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.type.StandardTypes.ARRAY;
import static io.prestosql.spi.type.StandardTypes.MAP;
import static io.prestosql.spi.type.StandardTypes.ROW;
import static io.prestosql.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static java.lang.invoke.MethodHandles.lookup;
import static java.util.Objects.requireNonNull;

class ValidationHash
{
// This value is a large arbitrary prime
private static final long NULL_HASH_CODE = 0x6e3efbd56c16a0cbL;

private static final MethodHandle MAP_HASH;
private static final MethodHandle ARRAY_HASH;
private static final MethodHandle ROW_HASH;
private static final MethodHandle TIMESTAMP_HASH;

static {
try {
MAP_HASH = lookup().findStatic(
ValidationHash.class,
"mapSkipNullKeysHash",
MethodType.methodType(long.class, Type.class, ValidationHash.class, ValidationHash.class, Block.class, int.class));
ARRAY_HASH = lookup().findStatic(
ValidationHash.class,
"arrayHash",
MethodType.methodType(long.class, Type.class, ValidationHash.class, Block.class, int.class));
ROW_HASH = lookup().findStatic(
ValidationHash.class,
"rowHash",
MethodType.methodType(long.class, Type.class, ValidationHash[].class, Block.class, int.class));
TIMESTAMP_HASH = lookup().findStatic(
ValidationHash.class,
"timestampHash",
MethodType.methodType(long.class, Block.class, int.class));
}
catch (Exception e) {
throw new RuntimeException(e);
}
}

// This should really come from the environment, but there is not good way to get a value here
private static final TypeOperators VALIDATION_TYPE_OPERATORS_CACHE = new TypeOperators();

public static ValidationHash createValidationHash(Type type)
{
requireNonNull(type, "type is null");
if (type.getTypeSignature().getBase().equals(MAP)) {
ValidationHash keyHash = createValidationHash(type.getTypeParameters().get(0));
ValidationHash valueHash = createValidationHash(type.getTypeParameters().get(1));
return new ValidationHash(MAP_HASH.bindTo(type).bindTo(keyHash).bindTo(valueHash));
}

if (type.getTypeSignature().getBase().equals(ARRAY)) {
ValidationHash elementHash = createValidationHash(type.getTypeParameters().get(0));
return new ValidationHash(ARRAY_HASH.bindTo(type).bindTo(elementHash));
}

if (type.getTypeSignature().getBase().equals(ROW)) {
ValidationHash[] fieldHashes = type.getTypeParameters().stream()
.map(ValidationHash::createValidationHash)
.toArray(ValidationHash[]::new);
return new ValidationHash(ROW_HASH.bindTo(type).bindTo(fieldHashes));
}

if (type.getTypeSignature().getBase().equals(StandardTypes.TIMESTAMP)) {
return new ValidationHash(TIMESTAMP_HASH);
}

return new ValidationHash(VALIDATION_TYPE_OPERATORS_CACHE.getHashCodeOperator(type, InvocationConvention.simpleConvention(FAIL_ON_NULL, BLOCK_POSITION)));
}

private final MethodHandle hashCodeOperator;

private ValidationHash(MethodHandle hashCodeOperator)
{
this.hashCodeOperator = requireNonNull(hashCodeOperator, "hashCodeOperator is null");
}

public long hash(Block block, int position)
{
if (block.isNull(position)) {
return NULL_HASH_CODE;
}
try {
return (long) hashCodeOperator.invokeExact(block, position);
}
catch (Throwable throwable) {
throwIfUnchecked(throwable);
throw new RuntimeException(throwable);
}
}

private static long mapSkipNullKeysHash(Type type, ValidationHash keyHash, ValidationHash valueHash, Block block, int position)
{
Block mapBlock = (Block) type.getObject(block, position);
long hash = 0;
for (int i = 0; i < mapBlock.getPositionCount(); i += 2) {
if (!mapBlock.isNull(i)) {
hash += keyHash.hash(mapBlock, i) ^ valueHash.hash(mapBlock, i + 1);
}
}
return hash;
}

private static long arrayHash(Type type, ValidationHash elementHash, Block block, int position)
{
Block array = (Block) type.getObject(block, position);
long hash = 0;
for (int i = 0; i < array.getPositionCount(); i++) {
hash = 31 * hash + elementHash.hash(array, i);
}
return hash;
}

private static long rowHash(Type type, ValidationHash[] fieldHashes, Block block, int position)
{
Block row = (Block) type.getObject(block, position);
long hash = 0;
for (int i = 0; i < row.getPositionCount(); i++) {
hash = 31 * hash + fieldHashes[i].hash(row, i);
}
return hash;
}

private static long timestampHash(Block block, int position)
{
// A flaw in ORC encoding makes it impossible to represent timestamp
// between 1969-12-31 23:59:59.000, exclusive, and 1970-01-01 00:00:00.000, exclusive.
// Therefore, such data won't round trip. The data read back is expected to be 1 second later than the original value.
long millis = TIMESTAMP_MILLIS.getLong(block, position);
if (millis > -1000 && millis < 0) {
millis += 1000;
}
return AbstractLongType.hash(millis);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
import io.airlift.slice.XxHash64;
import io.prestosql.spi.Page;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.Type;

import java.util.HashMap;
Expand All @@ -31,6 +28,7 @@
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

Expand Down Expand Up @@ -114,10 +112,7 @@ public List<Long> getColumnHashes()

public static class WriteChecksumBuilder
{
// This value is a large arbitrary prime
private static final long NULL_HASH_CODE = 0x6e3efbd56c16a0cbL;

private final List<Type> types;
private final List<ValidationHash> validationHashes;
private long totalRowCount;
private final List<XxHash64> columnHashes;
private final XxHash64 rowGroupHash = new XxHash64();
Expand All @@ -127,7 +122,9 @@ public static class WriteChecksumBuilder

private WriteChecksumBuilder(List<Type> types)
{
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.validationHashes = requireNonNull(types, "types is null").stream()
.map(ValidationHash::createValidationHash)
.collect(toImmutableList());

ImmutableList.Builder<XxHash64> columnHashes = ImmutableList.builder();
for (Type ignored : types) {
Expand Down Expand Up @@ -167,60 +164,17 @@ public void addPage(Page page)

totalRowCount += page.getPositionCount();
for (int channel = 0; channel < columnHashes.size(); channel++) {
Type type = types.get(channel);
ValidationHash validationHash = validationHashes.get(channel);
Block block = page.getBlock(channel);
XxHash64 xxHash64 = columnHashes.get(channel);
for (int position = 0; position < block.getPositionCount(); position++) {
long hash = hashPositionSkipNullMapKeys(type, block, position);
long hash = validationHash.hash(block, position);
longSlice.setLong(0, hash);
xxHash64.update(longBuffer);
}
}
}

/**
 * Computes the validation hash for a single position of {@code block}.
 * <p>
 * Null positions hash to {@code NULL_HASH_CODE}. Map, array and row values are
 * hashed recursively using the type parameters of {@code type}; map entries whose
 * key position is null are left out of the map hash. Any other type is delegated
 * to {@code type.hash}.
 *
 * @param type the type of the values stored in {@code block}
 * @param block the block holding the value
 * @param position the position within {@code block} to hash
 * @return the hash of the value at {@code position}
 */
private static long hashPositionSkipNullMapKeys(Type type, Block block, int position)
{
    if (block.isNull(position)) {
        return NULL_HASH_CODE;
    }

    if (type instanceof MapType) {
        Type keys = type.getTypeParameters().get(0);
        Type values = type.getTypeParameters().get(1);
        // Even positions of the entry block are read as keys, the following odd position as the value
        Block entries = (Block) type.getObject(block, position);
        long sum = 0;
        for (int keyPosition = 0; keyPosition < entries.getPositionCount(); keyPosition += 2) {
            if (entries.isNull(keyPosition)) {
                // entries with a null key do not contribute to the hash
                continue;
            }
            // plain addition keeps the map hash independent of entry order
            sum += hashPositionSkipNullMapKeys(keys, entries, keyPosition);
            sum += hashPositionSkipNullMapKeys(values, entries, keyPosition + 1);
        }
        return sum;
    }

    if (type instanceof ArrayType) {
        Type element = type.getTypeParameters().get(0);
        Block elements = (Block) type.getObject(block, position);
        long combined = 0;
        for (int index = 0; index < elements.getPositionCount(); index++) {
            combined = 31 * combined + hashPositionSkipNullMapKeys(element, elements, index);
        }
        return combined;
    }

    if (type instanceof RowType) {
        Block fields = (Block) type.getObject(block, position);
        long combined = 0;
        for (int field = 0; field < fields.getPositionCount(); field++) {
            Type fieldType = type.getTypeParameters().get(field);
            combined = 31 * combined + hashPositionSkipNullMapKeys(fieldType, fields, field);
        }
        return combined;
    }

    return type.hash(block, position);
}

public WriteChecksum build()
{
return new WriteChecksum(
Expand Down
Loading

0 comments on commit f551d80

Please sign in to comment.