From 4a62da19a7db2c13ce9a2292c2fef9182dc92d52 Mon Sep 17 00:00:00 2001
From: Krishna Kondaka <41027584+kkondaka@users.noreply.github.com>
Date: Tue, 2 Jul 2024 09:26:16 -0700
Subject: [PATCH] Parquet codec tests fix (#4698)

Parquet codec tests fix

Signed-off-by: Krishna Kondaka
(cherry picked from commit a466013db00aafe995830c57f2d892d839bbf22c)
---
 .../codec/parquet/ParquetOutputCodecTest.java | 90 ++++++++++++++-----
 1 file changed, 70 insertions(+), 20 deletions(-)

diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/ParquetOutputCodecTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/ParquetOutputCodecTest.java
index b441a7a6e3..059b908aa4 100644
--- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/ParquetOutputCodecTest.java
+++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/ParquetOutputCodecTest.java
@@ -11,6 +11,7 @@
 import org.apache.parquet.ParquetReadOptions;
 import org.apache.parquet.column.page.PageReadStore;
 import org.apache.parquet.example.data.Group;
+import org.apache.parquet.schema.GroupType;
 import org.apache.parquet.example.data.simple.SimpleGroup;
 import org.apache.parquet.example.data.simple.convert.GroupRecordConverter;
 import org.apache.parquet.hadoop.ParquetFileReader;
@@ -41,10 +42,10 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.ByteArrayInputStream;
-import java.io.File;
-import java.io.IOException;
+import java.io.FileInputStream;
 import java.io.InputStream;
+import java.io.IOException;
+import java.io.File;
 import java.nio.file.Files;
 import java.nio.file.StandardCopyOption;
 import java.util.ArrayList;
@@ -60,6 +61,7 @@
 
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.CoreMatchers.notNullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -115,11 +117,12 @@ void test_happy_case(final int numberOfRecords) throws Exception {
             parquetOutputCodec.writeEvent(event, outputStream);
         }
         parquetOutputCodec.complete(outputStream);
-        List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
+        List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
         int index = 0;
+        assertThat(inputMaps.size(), equalTo(actualRecords.size()));
         for (final Map<String, Object> actualMap : actualRecords) {
             assertThat(actualMap, notNullValue());
-            Map<String, Object> expectedMap = generateRecords(numberOfRecords).get(index);
+            Map<String, Object> expectedMap = inputMaps.get(index);
             assertThat(expectedMap, Matchers.equalTo(actualMap));
             index++;
         }
@@ -142,14 +145,16 @@ void test_happy_case_nullable_records(final int numberOfRecords) throws Exceptio
             parquetOutputCodec.writeEvent(event, outputStream);
         }
         parquetOutputCodec.complete(outputStream);
-        List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
+        List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
         int index = 0;
+        assertThat(inputMaps.size(), equalTo(actualRecords.size()));
         for (final Map<String, Object> actualMap : actualRecords) {
             assertThat(actualMap, notNullValue());
-            Map<String, Object> expectedMap = generateRecords(numberOfRecords).get(index);
+            Map<String, Object> expectedMap = inputMaps.get(index);
             assertThat(expectedMap, Matchers.equalTo(actualMap));
             index++;
         }
+
         outputStream.close();
         tempFile.delete();
     }
@@ -168,11 +173,12 @@ void test_happy_case_nullable_records_with_empty_maps(final int numberOfRecords)
             parquetOutputCodec.writeEvent(event, outputStream);
         }
         parquetOutputCodec.complete(outputStream);
-        List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
+        List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
         int index = 0;
+        assertThat(inputMaps.size(), equalTo(actualRecords.size()));
         for (final Map<String, Object> actualMap : actualRecords) {
             assertThat(actualMap, notNullValue());
-            Map<String, Object> expectedMap = generateRecords(numberOfRecords).get(index);
+            Map<String, Object> expectedMap = inputMaps.get(index);
             assertThat(expectedMap, Matchers.equalTo(actualMap));
             index++;
         }
@@ -194,6 +200,9 @@ void writeEvent_includes_record_when_field_does_not_exist_in_user_supplied_schem
         final Event eventWithInvalidField = mock(Event.class);
         final String invalidFieldName = UUID.randomUUID().toString();
         Map<String, Object> mapWithInvalid = generateRecords(1).get(0);
+        Map<String, Object> mapWithoutInvalid = mapWithInvalid.entrySet()
+                .stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
         mapWithInvalid.put(invalidFieldName, UUID.randomUUID().toString());
         when(eventWithInvalidField.toMap()).thenReturn(mapWithInvalid);
         final ParquetOutputCodec objectUnderTest = createObjectUnderTest();
@@ -205,12 +214,12 @@
 
         objectUnderTest.writeEvent(eventWithInvalidField, outputStream);
         objectUnderTest.complete(outputStream);
-        List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
+        List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
         int index = 0;
         for (final Map<String, Object> actualMap : actualRecords) {
             assertThat(actualMap, notNullValue());
-            Map<String, Object> expectedMap = generateRecords(1).get(index);
-            assertThat(expectedMap, Matchers.equalTo(actualMap));
+            assertThat(mapWithInvalid, not(Matchers.equalTo(actualMap)));
+            assertThat(mapWithoutInvalid, Matchers.equalTo(actualMap));
             index++;
         }
     }
@@ -551,7 +560,29 @@ private static Schema createStandardInnerSchemaForNestedRecord(
         return assembler.endRecord();
     }
 
-    private List<Map<String, Object>> createParquetRecordsList(final InputStream inputStream) throws IOException {
+    private List<String> extractStringList(SimpleGroup group, String fieldName) {
+        int fieldIndex = group.getType().getFieldIndex(fieldName);
+        int repetitionCount = group.getGroup(fieldIndex, 0).getFieldRepetitionCount(0);
+        List<String> resultList = new ArrayList<>();
+        for (int i = 0; i < repetitionCount; i++) {
+            resultList.add(group.getGroup(fieldIndex, 0).getString(0, i));
+        }
+        return resultList;
+    }
+
+    private Map<String, Object> extractNestedGroup(SimpleGroup group, String fieldName) {
+
+        Map<String, Object> resultMap = new HashMap<>();
+        int fieldIndex = group.getType().getFieldIndex(fieldName);
+        int f1 = group.getGroup(fieldIndex, 0).getType().getFieldIndex("firstFieldInNestedRecord");
+        resultMap.put("firstFieldInNestedRecord", group.getGroup(fieldIndex, 0).getString(f1,0));
+        int f2 = group.getGroup(fieldIndex, 0).getType().getFieldIndex("secondFieldInNestedRecord");
+        resultMap.put("secondFieldInNestedRecord", group.getGroup(fieldIndex, 0).getInteger(f2,0));
+
+        return resultMap;
+    }
+
+    private List<Map<String, Object>> createParquetRecordsList(final InputStream inputStream) throws IOException, RuntimeException {
         final File tempFile = new File(tempDirectory, FILE_NAME);
         Files.copy(inputStream, tempFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
 
@@ -567,15 +598,34 @@ private List<Map<String, Object>> createParquetRecordsList(final InputStream inp
         final RecordReader<Group> recordReader = columnIO.getRecordReader(pages, new GroupRecordConverter(schema));
         for (int row = 0; row < rows; row++) {
             final Map<String, Object> eventData = new HashMap<>();
-            int fieldIndex = 0;
             final SimpleGroup simpleGroup = (SimpleGroup) recordReader.read();
+            final GroupType groupType = simpleGroup.getType();
+
+
             for (Type field : schema.getFields()) {
-                try {
-                    eventData.put(field.getName(), simpleGroup.getValueToString(fieldIndex, 0));
-                } catch (Exception parquetException) {
-                    LOG.error("Failed to parse Parquet", parquetException);
+                Object value;
+                int fieldIndex = groupType.getFieldIndex(field.getName());
+                if (simpleGroup.getFieldRepetitionCount(fieldIndex) == 0) {
+                    continue;
+                }
+                switch (field.getName()) {
+                    case "name": value = simpleGroup.getString(fieldIndex, 0);
+                        break;
+                    case "age": value = simpleGroup.getInteger(fieldIndex, 0);
+                        break;
+                    case "myLong": value = simpleGroup.getLong(fieldIndex, 0);
+                        break;
+                    case "myFloat": value = simpleGroup.getFloat(fieldIndex, 0);
+                        break;
+                    case "myDouble": value = simpleGroup.getDouble(fieldIndex, 0);
+                        break;
+                    case "myArray": value = extractStringList(simpleGroup, "myArray");
+                        break;
+                    case "nestedRecord": value = extractNestedGroup(simpleGroup, "nestedRecord");
+                        break;
+                    default: throw new IllegalArgumentException("Unknown field");
                 }
-                fieldIndex++;
+                eventData.put(field.getName(), value);
             }
             actualRecordList.add((HashMap) eventData);
         }
@@ -591,4 +641,4 @@ private List<Map<String, Object>> createParquetRecordsList(final InputStream inp
     private MessageType createdParquetSchema(ParquetMetadata parquetMetadata) {
         return parquetMetadata.getFileMetaData().getSchema();
     }
-}
\ No newline at end of file
+}
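
Reference sketch (not part of the patch): the rewritten createParquetRecordsList reads records back through the Parquet example API with typed getters instead of Group.getValueToString. The standalone program below walks the same read path; the class name ParquetReadSketch, the records.parquet path, and the generic getValueToString dump standing in for the test's per-field switch are illustrative assumptions, while the ParquetFileReader, ColumnIOFactory, GroupRecordConverter, and SimpleGroup calls are the same ones the test uses.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.ParquetReadOptions;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.simple.SimpleGroup;
import org.apache.parquet.example.data.simple.convert.GroupRecordConverter;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.util.HadoopInputFile;
import org.apache.parquet.io.ColumnIOFactory;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.io.RecordReader;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;

public class ParquetReadSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical input path; the test instead copies its InputStream to a temp file.
        final HadoopInputFile inputFile = HadoopInputFile.fromPath(
                new Path("records.parquet"), new Configuration());
        try (ParquetFileReader reader = new ParquetFileReader(
                inputFile, ParquetReadOptions.builder().build())) {
            // Schema comes from the file footer, as in createdParquetSchema above.
            final MessageType schema = reader.getFooter().getFileMetaData().getSchema();
            final MessageColumnIO columnIO = new ColumnIOFactory().getColumnIO(schema);
            PageReadStore pages;
            while ((pages = reader.readNextRowGroup()) != null) {
                final long rows = pages.getRowCount();
                final RecordReader<Group> recordReader =
                        columnIO.getRecordReader(pages, new GroupRecordConverter(schema));
                for (long row = 0; row < rows; row++) {
                    final SimpleGroup group = (SimpleGroup) recordReader.read();
                    for (Type field : schema.getFields()) {
                        final int fieldIndex = group.getType().getFieldIndex(field.getName());
                        // Absent optional fields have repetition count 0; skipping them
                        // mirrors the null handling in the patched helper.
                        if (group.getFieldRepetitionCount(fieldIndex) == 0) {
                            continue;
                        }
                        // The patched test replaces this generic dump with typed getters
                        // (getString/getInteger/getLong/...) chosen per field name.
                        System.out.println(field.getName() + " = "
                                + group.getValueToString(fieldIndex, 0));
                    }
                }
            }
        }
    }
}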