diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/ParquetOutputCodecTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/ParquetOutputCodecTest.java index d6b4160888..059b908aa4 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/ParquetOutputCodecTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/ParquetOutputCodecTest.java @@ -6,15 +6,18 @@ import org.apache.avro.Schema; import org.apache.avro.SchemaBuilder; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; import org.apache.parquet.ParquetReadOptions; import org.apache.parquet.column.page.PageReadStore; import org.apache.parquet.example.data.Group; +import org.apache.parquet.schema.GroupType; import org.apache.parquet.example.data.simple.SimpleGroup; import org.apache.parquet.example.data.simple.convert.GroupRecordConverter; import org.apache.parquet.hadoop.ParquetFileReader; import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.HadoopInputFile; import org.apache.parquet.io.ColumnIOFactory; -import org.apache.parquet.io.LocalInputFile; import org.apache.parquet.io.MessageColumnIO; import org.apache.parquet.io.RecordReader; import org.apache.parquet.schema.MessageType; @@ -39,12 +42,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.ByteArrayInputStream; -import java.io.File; -import java.io.IOException; +import java.io.FileInputStream; import java.io.InputStream; +import java.io.IOException; +import java.io.File; import java.nio.file.Files; -import java.nio.file.Path; import java.nio.file.StandardCopyOption; import java.util.ArrayList; import java.util.Collections; @@ -59,6 +61,7 @@ import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -114,11 +117,12 @@ void test_happy_case(final int numberOfRecords) throws Exception { parquetOutputCodec.writeEvent(event, outputStream); } parquetOutputCodec.complete(outputStream); - List> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes())); + List> actualRecords = createParquetRecordsList(new FileInputStream(tempFile)); int index = 0; + assertThat(inputMaps.size(), equalTo(actualRecords.size())); for (final Map actualMap : actualRecords) { assertThat(actualMap, notNullValue()); - Map expectedMap = generateRecords(numberOfRecords).get(index); + Map expectedMap = inputMaps.get(index); assertThat(expectedMap, Matchers.equalTo(actualMap)); index++; } @@ -141,14 +145,16 @@ void test_happy_case_nullable_records(final int numberOfRecords) throws Exceptio parquetOutputCodec.writeEvent(event, outputStream); } parquetOutputCodec.complete(outputStream); - List> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes())); + List> actualRecords = createParquetRecordsList(new FileInputStream(tempFile)); int index = 0; + assertThat(inputMaps.size(), equalTo(actualRecords.size())); for (final Map actualMap : actualRecords) { assertThat(actualMap, notNullValue()); - Map expectedMap = generateRecords(numberOfRecords).get(index); + Map expectedMap = inputMaps.get(index); assertThat(expectedMap, Matchers.equalTo(actualMap)); index++; } + outputStream.close(); tempFile.delete(); } @@ -167,11 +173,12 @@ void test_happy_case_nullable_records_with_empty_maps(final int numberOfRecords) parquetOutputCodec.writeEvent(event, outputStream); } parquetOutputCodec.complete(outputStream); - List> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes())); + List> actualRecords = createParquetRecordsList(new FileInputStream(tempFile)); int index = 0; + assertThat(inputMaps.size(), equalTo(actualRecords.size())); for (final Map actualMap : actualRecords) { assertThat(actualMap, notNullValue()); - Map expectedMap = generateRecords(numberOfRecords).get(index); + Map expectedMap = inputMaps.get(index); assertThat(expectedMap, Matchers.equalTo(actualMap)); index++; } @@ -193,6 +200,9 @@ void writeEvent_includes_record_when_field_does_not_exist_in_user_supplied_schem final Event eventWithInvalidField = mock(Event.class); final String invalidFieldName = UUID.randomUUID().toString(); Map mapWithInvalid = generateRecords(1).get(0); + Map mapWithoutInvalid = mapWithInvalid.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); mapWithInvalid.put(invalidFieldName, UUID.randomUUID().toString()); when(eventWithInvalidField.toMap()).thenReturn(mapWithInvalid); final ParquetOutputCodec objectUnderTest = createObjectUnderTest(); @@ -204,12 +214,12 @@ void writeEvent_includes_record_when_field_does_not_exist_in_user_supplied_schem objectUnderTest.writeEvent(eventWithInvalidField, outputStream); objectUnderTest.complete(outputStream); - List> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes())); + List> actualRecords = createParquetRecordsList(new FileInputStream(tempFile)); int index = 0; for (final Map actualMap : actualRecords) { assertThat(actualMap, notNullValue()); - Map expectedMap = generateRecords(1).get(index); - assertThat(expectedMap, Matchers.equalTo(actualMap)); + assertThat(mapWithInvalid, not(Matchers.equalTo(actualMap))); + assertThat(mapWithoutInvalid, Matchers.equalTo(actualMap)); index++; } } @@ -550,12 +560,34 @@ private static Schema createStandardInnerSchemaForNestedRecord( return assembler.endRecord(); } - private List> createParquetRecordsList(final InputStream inputStream) throws IOException { + private List extractStringList(SimpleGroup group, String fieldName) { + int fieldIndex = group.getType().getFieldIndex(fieldName); + int repetitionCount = group.getGroup(fieldIndex, 0).getFieldRepetitionCount(0); + List resultList = new ArrayList<>(); + for (int i = 0; i < repetitionCount; i++) { + resultList.add(group.getGroup(fieldIndex, 0).getString(0, i)); + } + return resultList; + } + + private Map extractNestedGroup(SimpleGroup group, String fieldName) { + + Map resultMap = new HashMap<>(); + int fieldIndex = group.getType().getFieldIndex(fieldName); + int f1 = group.getGroup(fieldIndex, 0).getType().getFieldIndex("firstFieldInNestedRecord"); + resultMap.put("firstFieldInNestedRecord", group.getGroup(fieldIndex, 0).getString(f1,0)); + int f2 = group.getGroup(fieldIndex, 0).getType().getFieldIndex("secondFieldInNestedRecord"); + resultMap.put("secondFieldInNestedRecord", group.getGroup(fieldIndex, 0).getInteger(f2,0)); + + return resultMap; + } + + private List> createParquetRecordsList(final InputStream inputStream) throws IOException, RuntimeException { final File tempFile = new File(tempDirectory, FILE_NAME); Files.copy(inputStream, tempFile.toPath(), StandardCopyOption.REPLACE_EXISTING); List> actualRecordList = new ArrayList<>(); - try (final ParquetFileReader parquetFileReader = new ParquetFileReader(new LocalInputFile(Path.of(tempFile.toURI())), ParquetReadOptions.builder().build())) { + try (ParquetFileReader parquetFileReader = new ParquetFileReader(HadoopInputFile.fromPath(new Path(tempFile.toURI()), new Configuration()), ParquetReadOptions.builder().build())) { final ParquetMetadata footer = parquetFileReader.getFooter(); final MessageType schema = createdParquetSchema(footer); PageReadStore pages; @@ -566,15 +598,34 @@ private List> createParquetRecordsList(final InputStream inp final RecordReader recordReader = columnIO.getRecordReader(pages, new GroupRecordConverter(schema)); for (int row = 0; row < rows; row++) { final Map eventData = new HashMap<>(); - int fieldIndex = 0; final SimpleGroup simpleGroup = (SimpleGroup) recordReader.read(); + final GroupType groupType = simpleGroup.getType(); + + for (Type field : schema.getFields()) { - try { - eventData.put(field.getName(), simpleGroup.getValueToString(fieldIndex, 0)); - } catch (Exception parquetException) { - LOG.error("Failed to parse Parquet", parquetException); + Object value; + int fieldIndex = groupType.getFieldIndex(field.getName()); + if (simpleGroup.getFieldRepetitionCount(fieldIndex) == 0) { + continue; + } + switch (field.getName()) { + case "name": value = simpleGroup.getString(fieldIndex, 0); + break; + case "age": value = simpleGroup.getInteger(fieldIndex, 0); + break; + case "myLong": value = simpleGroup.getLong(fieldIndex, 0); + break; + case "myFloat": value = simpleGroup.getFloat(fieldIndex, 0); + break; + case "myDouble": value = simpleGroup.getDouble(fieldIndex, 0); + break; + case "myArray": value = extractStringList(simpleGroup, "myArray"); + break; + case "nestedRecord": value = extractNestedGroup(simpleGroup, "nestedRecord"); + break; + default: throw new IllegalArgumentException("Unknown field"); } - fieldIndex++; + eventData.put(field.getName(), value); } actualRecordList.add((HashMap) eventData); } @@ -590,4 +641,4 @@ private List> createParquetRecordsList(final InputStream inp private MessageType createdParquetSchema(ParquetMetadata parquetMetadata) { return parquetMetadata.getFileMetaData().getSchema(); } -} \ No newline at end of file +}