WIP Avro and Test Changes
rmarrowstone committed Aug 20, 2024
1 parent ddb2f0a commit b96dc9e
Showing 7 changed files with 161 additions and 14 deletions.
@@ -57,9 +57,23 @@ public AvroFileReader(
long offset,
OptionalLong length)
throws IOException, AvroTypeException
{
this(inputFile, schema, schema, avroTypeBlockHandler, offset, length);
}

public AvroFileReader(
TrinoInputFile inputFile,
Schema writerSchema,
Schema readerSchema,
AvroTypeBlockHandler avroTypeBlockHandler,
long offset,
OptionalLong length)
throws IOException, AvroTypeException
{
requireNonNull(inputFile, "inputFile is null");
requireNonNull(schema, "schema is null");
requireNonNull(readerSchema, "reader schema is null");
requireNonNull(writerSchema, "writer schema is null");

requireNonNull(avroTypeBlockHandler, "avroTypeBlockHandler is null");
long fileSize = inputFile.length();

@@ -69,7 +83,7 @@ public AvroFileReader(
end = length.stream().map(l -> l + offset).findFirst();
end.ifPresent(endLong -> verify(endLong <= fileSize, "offset plus length is greater than data size"));
input = new TrinoDataInputStream(inputFile.newStream());
dataReader = new AvroPageDataReader(schema, avroTypeBlockHandler);
dataReader = new AvroPageDataReader(writerSchema, readerSchema, avroTypeBlockHandler);
try {
fileReader = new DataFileReader<>(new TrinoDataInputStreamAsAvroSeekableInput(input, fileSize), dataReader);
fileReader.sync(offset);
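
For orientation, a rough usage sketch of the new two-schema constructor shown above. Only the constructor signature comes from this commit; writerSchemaJson, readerSchemaJson, inputFile, and avroTypeBlockHandler are placeholders assumed to be supplied by the caller.

    // Sketch only: the writer schema is what the file was encoded with,
    // the reader schema is the shape the caller wants back.
    // writerSchemaJson, readerSchemaJson, inputFile, and avroTypeBlockHandler
    // are assumed to exist in the surrounding code.
    Schema writerSchema = new Schema.Parser().parse(writerSchemaJson);
    Schema readerSchema = new Schema.Parser().parse(readerSchemaJson);

    AvroFileReader reader = new AvroFileReader(
            inputFile,              // TrinoInputFile pointing at the Avro data file
            writerSchema,
            readerSchema,
            avroTypeBlockHandler,   // e.g. a HiveAvroTypeBlockHandler
            0,                      // offset: start at the beginning of the file
            OptionalLong.empty());  // empty length: read through to the end of the file
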
@@ -37,11 +37,17 @@ public class AvroPageDataReader
private RowBlockBuildingDecoder rowBlockBuildingDecoder;
private final AvroTypeBlockHandler typeManager;

public AvroPageDataReader(Schema readerSchema, AvroTypeBlockHandler typeManager)
public AvroPageDataReader(Schema schema, AvroTypeBlockHandler typeManager)
throws AvroTypeException
{
this(schema, schema, typeManager);
}

public AvroPageDataReader(Schema writerSchema, Schema readerSchema, AvroTypeBlockHandler typeManager)
throws AvroTypeException
{
this.readerSchema = requireNonNull(readerSchema, "readerSchema is null");
writerSchema = this.readerSchema;
this.writerSchema = requireNonNull(writerSchema, "writerSchema is null");
this.typeManager = requireNonNull(typeManager, "typeManager is null");
verifyNoCircularReferences(readerSchema);
try {
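
The writer/reader split in AvroPageDataReader mirrors standard Avro schema resolution: records are decoded with the schema they were written with, then resolved against the schema the reader expects, filling newly added fields from their defaults. Below is a self-contained illustration using plain Avro, independent of Trino; the schemas and the file name users.avro are made up for the example.

    import org.apache.avro.Schema;
    import org.apache.avro.file.DataFileReader;
    import org.apache.avro.generic.GenericDatumReader;
    import org.apache.avro.generic.GenericRecord;

    import java.io.File;
    import java.io.IOException;

    public class SchemaResolutionExample
    {
        public static void main(String[] args)
                throws IOException
        {
            // Schema the file was written with: a single "id" field.
            Schema writerSchema = new Schema.Parser().parse(
                    "{\"type\":\"record\",\"name\":\"User\",\"fields\":["
                            + "{\"name\":\"id\",\"type\":\"long\"}]}");

            // Schema the reader wants: an "email" field added later, with a default.
            Schema readerSchema = new Schema.Parser().parse(
                    "{\"type\":\"record\",\"name\":\"User\",\"fields\":["
                            + "{\"name\":\"id\",\"type\":\"long\"},"
                            + "{\"name\":\"email\",\"type\":[\"null\",\"string\"],\"default\":null}]}");

            GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(writerSchema, readerSchema);
            try (DataFileReader<GenericRecord> fileReader = new DataFileReader<>(new File("users.avro"), datumReader)) {
                while (fileReader.hasNext()) {
                    GenericRecord record = fileReader.next();
                    // "email" resolves to its default (null) for rows written before the field existed
                    System.out.println(record.get("id") + " " + record.get("email"));
                }
            }
        }
    }
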
@@ -50,6 +50,19 @@ public AvroPageSource(
avroFileReader = new AvroFileReader(inputFile, schema, avroTypeManager, offset, OptionalLong.of(length));
}

public AvroPageSource(
TrinoInputFile inputFile,
Schema writerSchema,
Schema readerSchema,
AvroTypeBlockHandler avroTypeManager,
long offset,
long length)
throws IOException, AvroTypeException
{
fileName = requireNonNull(inputFile, "inputFile is null").location().fileName();
avroFileReader = new AvroFileReader(inputFile, writerSchema, readerSchema, avroTypeManager, offset, OptionalLong.of(length));
}

@Override
public long getCompletedBytes()
{
@@ -175,7 +175,9 @@ public Optional<ReaderPageSource> createPageSource(
}

try {
return Optional.of(new ReaderPageSource(new AvroPageSource(inputFile, maskedSchema, new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), start, length), readerProjections));
return Optional.of(
new ReaderPageSource(
new AvroPageSource(inputFile, maskedSchema, new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), start, length), readerProjections));
}
catch (IOException e) {
throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, e);
@@ -5733,7 +5733,7 @@ private String testReadWithPartitionSchemaMismatchAddedColumns(Session session,
public void testSubfieldReordering()
{
// Validate for formats for which subfield access is name based
List<HiveStorageFormat> formats = ImmutableList.of(HiveStorageFormat.ORC, HiveStorageFormat.PARQUET, HiveStorageFormat.AVRO);
List<HiveStorageFormat> formats = ImmutableList.of(HiveStorageFormat.AVRO);
String tableName = "evolve_test_" + randomNameSuffix();

for (HiveStorageFormat format : formats) {
@@ -1386,7 +1386,7 @@ private static void createTestFileTrino(
hiveFileWriter.commit();
}

private static void writeValue(Type type, BlockBuilder builder, Object object)
static void writeValue(Type type, BlockBuilder builder, Object object)
{
requireNonNull(builder, "builder is null");

@@ -22,13 +22,21 @@
import io.trino.filesystem.TrinoOutputFile;
import io.trino.filesystem.memory.MemoryFileSystemFactory;
import io.trino.metastore.HiveType;
import io.trino.metastore.StorageFormat;
import io.trino.plugin.hive.avro.AvroFileWriterFactory;
import io.trino.plugin.hive.avro.AvroPageSourceFactory;
import io.trino.plugin.hive.line.OpenXJsonFileWriterFactory;
import io.trino.plugin.hive.line.OpenXJsonPageSourceFactory;
import io.trino.plugin.hive.util.HiveTypeTranslator;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.connector.ConnectorPageSource;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.testing.MaterializedResult;
import org.junit.jupiter.api.Assertions;
@@ -46,16 +54,19 @@

import static io.trino.hive.thrift.metastore.hive_metastoreConstants.FILE_INPUT_FORMAT;
import static io.trino.plugin.hive.HivePageSourceProvider.ColumnMapping.buildColumnMappings;
import static io.trino.plugin.hive.HiveStorageFormat.AVRO;
import static io.trino.plugin.hive.HiveStorageFormat.OPENX_JSON;
import static io.trino.plugin.hive.HiveTestUtils.getHiveSession;
import static io.trino.plugin.hive.HiveTestUtils.projectedColumn;
import static io.trino.plugin.hive.HiveTestUtils.toHiveBaseColumnHandle;
import static io.trino.plugin.hive.TestHiveFileFormats.writeValue;
import static io.trino.plugin.hive.acid.AcidTransaction.NO_ACID_TRANSACTION;
import static io.trino.plugin.hive.util.SerdeConstants.LIST_COLUMNS;
import static io.trino.plugin.hive.util.SerdeConstants.LIST_COLUMN_TYPES;
import static io.trino.plugin.hive.util.SerdeConstants.SERIALIZATION_LIB;
import static io.trino.spi.type.RowType.field;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
import static java.util.stream.Collectors.toList;

/**
* This test proves that non-dereferenced fields are pruned from nested RowTypes.
@@ -194,23 +205,72 @@ public void testProjectionsFromDifferentPartsOfSameBase()
List.of(false, 31));
}

@Test
public void testWriteThenRead()
throws IOException
{
HiveColumnHandle someOtherColumn = toHiveBaseColumnHandle("something_else", VarcharType.VARCHAR, 1);
List<HiveColumnHandle> writeColumns = List.of(tableColumns.get(0), someOtherColumn);

assertRoundTrip(
writeColumns,
List.of(
List.of(List.of(true, "bar", 31), "spam")),
writeColumns,
List.of(
someOtherColumn,
projectedColumn(tableColumns.get(0), "basic_int"),
projectedColumn(tableColumns.get(0), "basic_bool")),
List.of("spam", 31, true));
}

private void assertValues(List<HiveColumnHandle> projectedColumns, String text, List<Object> expected)
throws IOException
{
TrinoFileSystemFactory fileSystemFactory = new MemoryFileSystemFactory();
Location location = Location.of("memory:///test.ion");
Location location = Location.of("memory:///test");

final ConnectorSession session = getHiveSession(new HiveConfig());

writeTextFile(text, location, fileSystemFactory.create(session));
HivePageSourceFactory pageSourceFactory = new OpenXJsonPageSourceFactory(fileSystemFactory, new HiveConfig());

try (ConnectorPageSource pageSource = createPageSource(fileSystemFactory, location, tableColumns, projectedColumns, session)) {
try (ConnectorPageSource pageSource = createPageSource(pageSourceFactory, OPENX_JSON, fileSystemFactory, location, tableColumns, projectedColumns, session)) {
final MaterializedResult result = MaterializedResult.materializeSourceDataStream(session, pageSource, projectedColumns.stream().map(HiveColumnHandle::getType).toList());
Assertions.assertEquals(1, result.getRowCount());
Assertions.assertEquals(expected, result.getMaterializedRows().getFirst().getFields());
}
}

private void assertRoundTrip(
List<HiveColumnHandle> writeColumns,
List<Object> writeValues,
List<HiveColumnHandle> readColumns,
List<HiveColumnHandle> projections,
List<Object> expected)
throws IOException
{
TrinoFileSystemFactory fileSystemFactory = new MemoryFileSystemFactory();
Location location = Location.of("memory:///test");

final ConnectorSession session = getHiveSession(new HiveConfig());

writeObjectsToFile(
new AvroFileWriterFactory(fileSystemFactory, TESTING_TYPE_MANAGER, new NodeVersion("test_version")),
AVRO,
writeValues,
writeColumns,
location,
session);

HivePageSourceFactory pageSourceFactory = new AvroPageSourceFactory(fileSystemFactory);
try (ConnectorPageSource pageSource = createPageSource(pageSourceFactory, AVRO, fileSystemFactory, location, readColumns, projections, session)) {
final MaterializedResult result = MaterializedResult.materializeSourceDataStream(session, pageSource, projections.stream().map(HiveColumnHandle::getType).toList());
Assertions.assertEquals(1, result.getRowCount());
Assertions.assertEquals(expected, result.getMaterializedRows().getFirst().getFields());
}
}

private int writeTextFile(String text, Location location, TrinoFileSystem fileSystem)
throws IOException
{
@@ -225,19 +285,71 @@ private int writeTextFile(String text, Location location, TrinoFileSystem fileSystem)
return written;
}

private void writeObjectsToFile(
HiveFileWriterFactory fileWriterFactory,
HiveStorageFormat storageFormat,
List<Object> objects,
List<HiveColumnHandle> columns,
Location location,
ConnectorSession session) {

columns = columns.stream()
.filter(c -> c.getColumnType().equals(HiveColumnHandle.ColumnType.REGULAR))
.toList();
List<Type> types = columns.stream()
.map(HiveColumnHandle::getType)
.collect(toList());

PageBuilder pageBuilder = new PageBuilder(types);
for (Object row : objects) {
pageBuilder.declarePosition();
for (int col = 0; col < columns.size(); col++) {
Type type = types.get(col);
Object value = ((List<?>)row).get(col);

writeValue(type, pageBuilder.getBlockBuilder(col), value);
}
}
Page page = pageBuilder.build();

Map<String, String> tableProperties = ImmutableMap.<String, String>builder()
.put(LIST_COLUMNS, columns.stream().map(HiveColumnHandle::getName).collect(Collectors.joining(",")))
.put(LIST_COLUMN_TYPES, columns.stream().map(HiveColumnHandle::getType).map(HiveTypeTranslator::toHiveType).map(HiveType::toString).collect(Collectors.joining(",")))
.buildOrThrow();


Optional<FileWriter> fileWriter = fileWriterFactory.createFileWriter(
location,
columns.stream()
.map(HiveColumnHandle::getName)
.toList(),
storageFormat.toStorageFormat(),
HiveCompressionCodec.NONE,
tableProperties,
session,
OptionalInt.empty(),
NO_ACID_TRANSACTION,
false,
WriterKind.INSERT);

FileWriter hiveFileWriter = fileWriter.orElseThrow(() -> new IllegalArgumentException("fileWriterFactory"));
hiveFileWriter.appendRows(page);
hiveFileWriter.commit();
}

/**
* todo: this is very similar to what's in TestOrcPredicates, factor out.
*/
private static ConnectorPageSource createPageSource(
HivePageSourceFactory pageSourceFactory,
HiveStorageFormat storageFormat,
TrinoFileSystemFactory fileSystemFactory,
Location location,
List<HiveColumnHandle> tableColumns,
List<HiveColumnHandle> projectedColumns,
ConnectorSession session)
throws IOException
{
OpenXJsonPageSourceFactory factory = new OpenXJsonPageSourceFactory(fileSystemFactory, new HiveConfig());

long length = fileSystemFactory.create(session).newInputFile(location).length();

List<HivePageSourceProvider.ColumnMapping> columnMappings = buildColumnMappings(
@@ -252,14 +364,14 @@ private static ConnectorPageSource createPageSource(
Instant.now().toEpochMilli());

final Map<String, String> tableProperties = ImmutableMap.<String, String>builder()
.put(FILE_INPUT_FORMAT, OPENX_JSON.getInputFormat())
.put(SERIALIZATION_LIB, OPENX_JSON.getSerde())
.put(FILE_INPUT_FORMAT, storageFormat.getInputFormat())
.put(SERIALIZATION_LIB, storageFormat.getSerde())
.put(LIST_COLUMNS, tableColumns.stream().map(HiveColumnHandle::getName).collect(Collectors.joining(",")))
.put(LIST_COLUMN_TYPES, tableColumns.stream().map(HiveColumnHandle::getHiveType).map(HiveType::toString).collect(Collectors.joining(",")))
.buildOrThrow();

return HivePageSourceProvider.createHivePageSource(
ImmutableSet.of(factory),
ImmutableSet.of(pageSourceFactory),
session,
location,
OptionalInt.empty(),
