Commit 8081e0c: Add row IDs to batch reader

elharo authored and NikhilCollooru committed May 2, 2024
1 parent b0c575e commit 8081e0c
Showing 16 changed files with 166 additions and 25 deletions.
@@ -38,5 +38,6 @@ Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation);
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIdPartitionComponent);
 }
@@ -96,9 +96,11 @@ public HivePageSource(
             if (columnMapping.getCoercionFrom().isPresent()) {
                 coercers[columnIndex] = createCoercer(typeManager, columnMapping.getCoercionFrom().get(), columnMapping.getHiveColumnHandle().getHiveType());
             }
-            else if (isRowIdColumnHandle(columnMapping.getHiveColumnHandle()) && rowIdPartitionComponent.isPresent()) {
+            else if (isRowIdColumnHandle(columnMapping.getHiveColumnHandle())) {
+                // If there's no row ID partition component, then path + row numbers will be supplied for $row_id
+                byte[] component = rowIdPartitionComponent.orElse(new byte[0]);
                 String rowGroupId = Paths.get(path).getFileName().toString();
-                coercers[columnIndex] = new RowIDCoercer(rowIdPartitionComponent.get(), rowGroupId);
+                coercers[columnIndex] = new RowIDCoercer(component, rowGroupId);
             }

             if (columnMapping.getKind() == PREFILLED) {
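RowIDCoercer itself is not part of this diff, so its encoding is not visible here. As a rough standalone illustration of the idea — combining a per-row number with the row group (file) name and the optional partition component — consider the sketch below. The class name, field layout, and byte order are assumptions for illustration only; the only details taken from the diff are the constructor inputs (a byte[] partition component and a row group ID string) and the fallback to an empty component.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

// Hypothetical stand-in for com.facebook.presto.hive.RowIDCoercer, which in the
// real code maps a Block of row numbers to a Block of row IDs. This sketch shows
// one plausible per-row encoding, not the actual Presto layout.
public class RowIdSketch
{
    private final byte[] partitionComponent; // empty when no component was supplied
    private final byte[] rowGroupId;         // the split's file name

    public RowIdSketch(byte[] partitionComponent, String rowGroupId)
    {
        this.partitionComponent = partitionComponent;
        this.rowGroupId = rowGroupId.getBytes(StandardCharsets.UTF_8);
    }

    // Assumed layout: 8-byte little-endian row number, then row group ID bytes,
    // then partition component bytes.
    public byte[] rowId(long rowNumber)
    {
        ByteBuffer buffer = ByteBuffer.allocate(8 + rowGroupId.length + partitionComponent.length)
                .order(ByteOrder.LITTLE_ENDIAN);
        buffer.putLong(rowNumber);
        buffer.put(rowGroupId);
        buffer.put(partitionComponent);
        return buffer.array();
    }
}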
@@ -183,6 +183,7 @@ public ConnectorPageSource createPageSource(
             return createAggregatedPageSource(aggregatedPageSourceFactories, configuration, session, hiveSplit, hiveLayout, selectedColumns, fileContext, encryptionInformation);
         }
         if (hiveLayout.isPushdownFilterEnabled()) {
+            Optional<byte[]> rowIDPartitionComponent = hiveSplit.getRowIdPartitionComponent();
             Optional<ConnectorPageSource> selectivePageSource = createSelectivePageSource(
                     selectivePageSourceFactories,
                     configuration,
@@ -371,12 +372,14 @@ private static Optional<ConnectorPageSource> createSelectivePageSource(
                 .orElse(layout.getDomainPredicate());

         for (HiveSelectivePageSourceFactory pageSourceFactory : selectivePageSourceFactories) {
+            List<HiveColumnHandle> columnHandles = toColumnHandles(columnMappings, true);
+            Optional<byte[]> rowIDPartitionComponent = split.getRowIdPartitionComponent();
             Optional<? extends ConnectorPageSource> pageSource = pageSourceFactory.createPageSource(
                     configuration,
                     session,
                     split.getFileSplit(),
                     split.getStorage(),
-                    toColumnHandles(columnMappings, true),
+                    columnHandles,
                     prefilledValues,
                     coercers,
                     bucketAdaptation,
@@ -387,7 +390,7 @@ private static Optional<ConnectorPageSource> createSelectivePageSource(
                     fileContext,
                     encryptionInformation,
                     layout.isAppendRowNumberEnabled(),
-                    split.getRowIdPartitionComponent());
+                    rowIDPartitionComponent);
             if (pageSource.isPresent()) {
                 return Optional.of(pageSource.get());
             }
@@ -497,7 +500,8 @@ public static Optional<ConnectorPageSource> createHivePageSource(
                 effectivePredicate,
                 hiveStorageTimeZone,
                 hiveFileContext,
-                encryptionInformation);
+                encryptionInformation,
+                rowIdPartitionComponent);
         if (pageSource.isPresent()) {
             HivePageSource hivePageSource = new HivePageSource(
                     columnMappings,
@@ -213,6 +213,14 @@ public final class HiveUtil
     private static final String USE_RECORD_READER_FROM_INPUT_FORMAT_ANNOTATION = "UseRecordReaderFromInputFormat";
     private static final String USE_FILE_SPLITS_FROM_INPUT_FORMAT_ANNOTATION = "UseFileSplitsFromInputFormat";

+    public static void checkRowIDPartitionComponent(List<HiveColumnHandle> columns, Optional<byte[]> rowIdPartitionComponent)
+    {
+        boolean supplyRowIDs = columns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
+        if (supplyRowIDs) {
+            checkArgument(rowIdPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
+        }
+    }
+
     static {
         DateTimeParser[] timestampWithoutTimeZoneParser = {
                 DateTimeFormat.forPattern("yyyy-M-d").getParser(),
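This helper makes a missing partition component fail fast whenever a $row_id column is actually requested. Note the asymmetry with HivePageSource above, which tolerates an absent component and falls back to path + row numbers; the ORC factories below opt into the strict check. A standalone analogue of the precondition, with a plain string predicate standing in for HiveColumnHandle.isRowIdColumnHandle (a simplification for illustration):

import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;

public class RowIdPreconditionSketch
{
    // Simplified analogue of HiveUtil.checkRowIDPartitionComponent: a generic
    // predicate replaces HiveColumnHandle.isRowIdColumnHandle.
    static <T> void checkRowIdPartitionComponent(List<T> columns, Predicate<T> isRowId, Optional<byte[]> component)
    {
        boolean supplyRowIds = columns.stream().anyMatch(isRowId);
        if (supplyRowIds && !component.isPresent()) {
            // Guava's checkArgument throws IllegalArgumentException the same way
            throw new IllegalArgumentException("rowIDPartitionComponent required when supplying row IDs");
        }
    }

    public static void main(String[] args)
    {
        // Fine: no $row_id column requested, so an absent component is acceptable.
        checkRowIdPartitionComponent(List.of("name", "price"), "$row_id"::equals, Optional.empty());
        // Throws IllegalArgumentException: $row_id requested without a component.
        checkRowIdPartitionComponent(List.of("name", "$row_id"), "$row_id"::equals, Optional.empty());
    }
}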
@@ -98,7 +98,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
             return Optional.empty();
@@ -132,6 +133,7 @@ public Optional<? extends ConnectorPageSource> createPageSource(
                         .build(),
                 encryptionInformation,
                 dwrfEncryptionProvider,
-                session));
+                session,
+                rowIDPartitionComponent));
     }
 }
@@ -48,6 +48,7 @@
 import java.util.Optional;

 import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA;
+import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
 import static com.facebook.presto.hive.orc.OrcSelectivePageSourceFactory.createOrcPageSource;
 import static com.facebook.presto.orc.OrcEncoding.DWRF;
 import static java.util.Objects.requireNonNull;
@@ -118,6 +119,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             throw new PrestoException(HIVE_BAD_DATA, "ORC file is empty: " + fileSplit.getPath());
         }

+        checkRowIDPartitionComponent(columns, rowIDPartitionComponent);
+
         return Optional.of(createOrcPageSource(
                 session,
                 DWRF,
@@ -24,6 +24,7 @@
 import com.facebook.presto.common.type.TypeManager;
 import com.facebook.presto.hive.FileFormatDataSourceStats;
 import com.facebook.presto.hive.HiveColumnHandle;
+import com.facebook.presto.hive.RowIDCoercer;
 import com.facebook.presto.orc.OrcAggregatedMemoryContext;
 import com.facebook.presto.orc.OrcBatchRecordReader;
 import com.facebook.presto.orc.OrcCorruptionException;
@@ -58,6 +59,7 @@ public class OrcBatchPageSource

     private final Block[] constantBlocks;
     private final int[] hiveColumnIndexes;
+    private final boolean[] rowIDColumnIndexes;

     private int batchId;
     private long completedPositions;
@@ -71,16 +73,20 @@ public class OrcBatchPageSource

     private final List<Boolean> isRowPositionList;

+    private final RowIDCoercer coercer;
+
     public OrcBatchPageSource(
             OrcBatchRecordReader recordReader,
             OrcDataSource orcDataSource,
             List<HiveColumnHandle> columns,
             TypeManager typeManager,
             OrcAggregatedMemoryContext systemMemoryContext,
             FileFormatDataSourceStats stats,
-            RuntimeStats runtimeStats)
+            RuntimeStats runtimeStats,
+            byte[] rowIDPartitionComponent,
+            String rowGroupId)
     {
-        this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false));
+        this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false), rowIDPartitionComponent, rowGroupId);
     }

     /**
@@ -97,7 +103,9 @@ public OrcBatchPageSource(
             OrcAggregatedMemoryContext systemMemoryContext,
             FileFormatDataSourceStats stats,
             RuntimeStats runtimeStats,
-            List<Boolean> isRowPositionList)
+            List<Boolean> isRowPositionList,
+            byte[] rowIDPartitionComponent,
+            String rowGroupId)
     {
         this.recordReader = requireNonNull(recordReader, "recordReader is null");
         this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
@@ -107,9 +115,12 @@ public OrcBatchPageSource(
         this.stats = requireNonNull(stats, "stats is null");
         this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null");
         this.isRowPositionList = requireNonNull(isRowPositionList, "isRowPositionList is null");
+        // TODO don't create this if there's no rowID column
+        this.coercer = new RowIDCoercer(rowIDPartitionComponent, rowGroupId);

         this.constantBlocks = new Block[size];
         this.hiveColumnIndexes = new int[size];
+        this.rowIDColumnIndexes = new boolean[size];

         ImmutableList.Builder<String> namesBuilder = ImmutableList.builder();
         ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
@@ -124,6 +135,7 @@ public OrcBatchPageSource(
             typesBuilder.add(type);

             hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex();
+            rowIDColumnIndexes[columnIndex] = HiveColumnHandle.isRowIdColumnHandle(column);

             if (!recordReader.isColumnPresent(column.getHiveColumnIndex())) {
                 constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_BATCH_SIZE);
@@ -183,6 +195,11 @@ public Page getNextPage()
                 if (isRowPositionColumn(fieldId)) {
                     blocks[fieldId] = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
                 }
+                else if (isRowIDColumn(fieldId)) {
+                    Block rowNumbers = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
+                    Block rowIDs = coercer.apply(rowNumbers);
+                    blocks[fieldId] = rowIDs;
+                }
                 else {
                     if (constantBlocks[fieldId] != null) {
                         blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize);
@@ -260,6 +277,12 @@ private boolean isRowPositionColumn(int column)
         return isRowPositionList.get(column);
     }

+    private boolean isRowIDColumn(int column)
+    {
+        return this.rowIDColumnIndexes[column];
+    }
+
+    // TODO verify these are row numbers and rename?
     private static Block getRowPosColumnBlock(long baseIndex, int size)
     {
         long[] rowPositions = new long[size];
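For a $row_id column, getNextPage first materializes the same row-number block used for row-position columns and then runs it through the coercer. getRowPosColumnBlock is only partially visible above; a minimal sketch of the pattern it implements, using a plain long[] in place of Presto's Block type:

// Sketch of the row-number generation feeding the RowIDCoercer: batch-local
// offsets are added to the reader's current file position (baseIndex).
static long[] rowNumbers(long baseIndex, int size)
{
    long[] rowPositions = new long[size];
    for (int position = 0; position < size; position++) {
        rowPositions[position] = baseIndex + position;
    }
    return rowPositions;
}

In the page source these values are wrapped in a Block; a bare row-position column returns that block directly, while a $row_id column passes it through coercer.apply first.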
@@ -63,6 +63,7 @@
 import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcTinyStripeThreshold;
 import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcBloomFiltersEnabled;
 import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcZstdJniDecompressionEnabled;
+import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
 import static com.facebook.presto.hive.HiveUtil.getPhysicalHiveColumnHandles;
 import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcDataSource;
 import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcReader;
@@ -134,7 +135,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
             return Optional.empty();
@@ -169,7 +171,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
                         .build(),
                 encryptionInformation,
                 NO_ENCRYPTION,
-                session));
+                session,
+                rowIDPartitionComponent));
     }

     public static ConnectorPageSource createOrcPageSource(
@@ -191,9 +194,11 @@ public static ConnectorPageSource createOrcPageSource(
             OrcReaderOptions orcReaderOptions,
             Optional<EncryptionInformation> encryptionInformation,
             DwrfEncryptionProvider dwrfEncryptionProvider,
-            ConnectorSession session)
+            ConnectorSession session,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         checkArgument(domainCompactionThreshold >= 1, "domainCompactionThreshold must be at least 1");
+        checkRowIDPartitionComponent(columns, rowIDPartitionComponent);

         OrcDataSource orcDataSource = getOrcDataSource(session, fileSplit, hdfsEnvironment, configuration, hiveFileContext, stats);
         Path path = new Path(fileSplit.getPath());
@@ -235,14 +240,18 @@ public static ConnectorPageSource createOrcPageSource(
                     systemMemoryUsage,
                     INITIAL_BATCH_SIZE);

+            byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
+            String rowGroupID = path.getName();
             return new OrcBatchPageSource(
                     recordReader,
                     reader.getOrcDataSource(),
                     physicalColumns,
                     typeManager,
                     systemMemoryUsage,
                     stats,
-                    hiveFileContext.getStats());
+                    hiveFileContext.getStats(),
+                    partitionID,
+                    rowGroupID);
         }
         catch (Exception e) {
             try {
@@ -283,9 +283,9 @@ public static ConnectorPageSource createOrcPageSource(
         Path path = new Path(fileSplit.getPath());

         boolean supplyRowIDs = selectedColumns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
-        String rowGroupId = path.getName();
         checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
         byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
+        String rowGroupId = path.getName();

         DataSize maxMergeDistance = getOrcMaxMergeDistance(session);
         DataSize tinyStripeThreshold = getOrcTinyStripeThreshold(session);
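Two spellings of the same derivation appear in this commit: HivePageSource uses Paths.get(path).getFileName().toString(), while the ORC factories use Hadoop's path.getName(); both reduce to the final name component of the split's file path. A quick java.nio illustration (the path string is a made-up example, not from the diff):

import java.nio.file.Paths;

public class RowGroupIdSketch
{
    public static void main(String[] args)
    {
        String path = "/warehouse/orders/part-00000.orc";
        // Matches what Hadoop's Path.getName() would return for the same path
        String rowGroupId = Paths.get(path).getFileName().toString();
        System.out.println(rowGroupId); // prints part-00000.orc
    }
}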
@@ -72,7 +72,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         if (!PageInputFormat.class.getSimpleName().equals(storage.getStorageFormat().getInputFormat())) {
             return Optional.empty();
@@ -512,7 +512,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIdPartitionComponent)
     {
         if (!PARQUET_SERDE_CLASS_NAMES.contains(storage.getStorageFormat().getSerDe())) {
             return Optional.empty();
@@ -106,7 +106,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         RcFileEncoding rcFileEncoding;
         if (LazyBinaryColumnarSerDe.class.getName().equals(storage.getStorageFormat().getSerDe())) {
(Diff truncated: the remaining changed files were not loaded in this view.)