[Backport 2.8] Parquet codec tests fix #4742

Merged 1 commit on Jul 17, 2024
@@ -11,6 +11,7 @@
import org.apache.parquet.ParquetReadOptions;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.example.data.simple.SimpleGroup;
import org.apache.parquet.example.data.simple.convert.GroupRecordConverter;
import org.apache.parquet.hadoop.ParquetFileReader;
@@ -41,10 +42,10 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.IOException;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
@@ -60,6 +61,7 @@

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -115,11 +117,12 @@ void test_happy_case(final int numberOfRecords) throws Exception {
parquetOutputCodec.writeEvent(event, outputStream);
}
parquetOutputCodec.complete(outputStream);
List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
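// Read the records back from the temp Parquet file to verify what the codec wrote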
int index = 0;
assertThat(inputMaps.size(), equalTo(actualRecords.size()));
for (final Map<String, Object> actualMap : actualRecords) {
assertThat(actualMap, notNullValue());
Map expectedMap = generateRecords(numberOfRecords).get(index);
Map expectedMap = inputMaps.get(index);
assertThat(expectedMap, Matchers.equalTo(actualMap));
index++;
}
@@ -142,14 +145,16 @@ void test_happy_case_nullable_records(final int numberOfRecords) throws Exceptio
parquetOutputCodec.writeEvent(event, outputStream);
}
parquetOutputCodec.complete(outputStream);
List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
int index = 0;
assertThat(inputMaps.size(), equalTo(actualRecords.size()));
for (final Map<String, Object> actualMap : actualRecords) {
assertThat(actualMap, notNullValue());
Map expectedMap = generateRecords(numberOfRecords).get(index);
Map expectedMap = inputMaps.get(index);
assertThat(expectedMap, Matchers.equalTo(actualMap));
index++;
}
outputStream.close();
tempFile.delete();
}

@@ -168,11 +173,12 @@ void test_happy_case_nullable_records_with_empty_maps(final int numberOfRecords)
parquetOutputCodec.writeEvent(event, outputStream);
}
parquetOutputCodec.complete(outputStream);
List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
int index = 0;
assertThat(inputMaps.size(), equalTo(actualRecords.size()));
for (final Map<String, Object> actualMap : actualRecords) {
assertThat(actualMap, notNullValue());
Map expectedMap = generateRecords(numberOfRecords).get(index);
Map expectedMap = inputMaps.get(index);
assertThat(expectedMap, Matchers.equalTo(actualMap));
index++;
}
@@ -194,6 +200,9 @@ void writeEvent_includes_record_when_field_does_not_exist_in_user_supplied_schem
final Event eventWithInvalidField = mock(Event.class);
final String invalidFieldName = UUID.randomUUID().toString();
Map<String, Object> mapWithInvalid = generateRecords(1).get(0);
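// Copy of the record without the extra field; the written output is expected to match this map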
Map<String, Object> mapWithoutInvalid = mapWithInvalid.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
mapWithInvalid.put(invalidFieldName, UUID.randomUUID().toString());
when(eventWithInvalidField.toMap()).thenReturn(mapWithInvalid);
final ParquetOutputCodec objectUnderTest = createObjectUnderTest();
@@ -205,12 +214,12 @@ void writeEvent_includes_record_when_field_does_not_exist_in_user_supplied_schem
objectUnderTest.writeEvent(eventWithInvalidField, outputStream);

objectUnderTest.complete(outputStream);
List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
int index = 0;
for (final Map<String, Object> actualMap : actualRecords) {
assertThat(actualMap, notNullValue());
Map expectedMap = generateRecords(1).get(index);
assertThat(expectedMap, Matchers.equalTo(actualMap));
assertThat(mapWithInvalid, not(Matchers.equalTo(actualMap)));
assertThat(mapWithoutInvalid, Matchers.equalTo(actualMap));
index++;
}
}
@@ -551,7 +560,29 @@ private static Schema createStandardInnerSchemaForNestedRecord(
return assembler.endRecord();
}

private List<Map<String, Object>> createParquetRecordsList(final InputStream inputStream) throws IOException {
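// Reads each repeated element of a string-list field out of the Parquet group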
private List<String> extractStringList(SimpleGroup group, String fieldName) {
int fieldIndex = group.getType().getFieldIndex(fieldName);
int repetitionCount = group.getGroup(fieldIndex, 0).getFieldRepetitionCount(0);
List<String> resultList = new ArrayList<>();
for (int i = 0; i < repetitionCount; i++) {
resultList.add(group.getGroup(fieldIndex, 0).getString(0, i));
}
return resultList;
}

private Map<String, Object> extractNestedGroup(SimpleGroup group, String fieldName) {
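// Rebuilds the nested record as a map of its two known sub-fields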

Map<String, Object> resultMap = new HashMap<>();
int fieldIndex = group.getType().getFieldIndex(fieldName);
int f1 = group.getGroup(fieldIndex, 0).getType().getFieldIndex("firstFieldInNestedRecord");
resultMap.put("firstFieldInNestedRecord", group.getGroup(fieldIndex, 0).getString(f1,0));
int f2 = group.getGroup(fieldIndex, 0).getType().getFieldIndex("secondFieldInNestedRecord");
resultMap.put("secondFieldInNestedRecord", group.getGroup(fieldIndex, 0).getInteger(f2,0));

return resultMap;
}

private List<Map<String, Object>> createParquetRecordsList(final InputStream inputStream) throws IOException, RuntimeException {
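// Copies the codec output to a temp file, then reads it back row by row into maps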

final File tempFile = new File(tempDirectory, FILE_NAME);
Files.copy(inputStream, tempFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
@@ -567,15 +598,34 @@ private List<Map<String, Object>> createParquetRecordsList(final InputStream inp
final RecordReader<Group> recordReader = columnIO.getRecordReader(pages, new GroupRecordConverter(schema));
for (int row = 0; row < rows; row++) {
final Map<String, Object> eventData = new HashMap<>();
int fieldIndex = 0;
final SimpleGroup simpleGroup = (SimpleGroup) recordReader.read();
final GroupType groupType = simpleGroup.getType();


for (Type field : schema.getFields()) {
try {
eventData.put(field.getName(), simpleGroup.getValueToString(fieldIndex, 0));
} catch (Exception parquetException) {
LOG.error("Failed to parse Parquet", parquetException);
Object value;
int fieldIndex = groupType.getFieldIndex(field.getName());
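// A repetition count of 0 means the field was not written for this row (null value), so skip it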
if (simpleGroup.getFieldRepetitionCount(fieldIndex) == 0) {
continue;
}
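// Use the typed accessor that matches each known field in the test schema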
switch (field.getName()) {
case "name": value = simpleGroup.getString(fieldIndex, 0);
break;
case "age": value = simpleGroup.getInteger(fieldIndex, 0);
break;
case "myLong": value = simpleGroup.getLong(fieldIndex, 0);
break;
case "myFloat": value = simpleGroup.getFloat(fieldIndex, 0);
break;
case "myDouble": value = simpleGroup.getDouble(fieldIndex, 0);
break;
case "myArray": value = extractStringList(simpleGroup, "myArray");
break;
case "nestedRecord": value = extractNestedGroup(simpleGroup, "nestedRecord");
break;
default: throw new IllegalArgumentException("Unknown field");
}
fieldIndex++;
eventData.put(field.getName(), value);
}
actualRecordList.add((HashMap) eventData);
}
@@ -591,4 +641,4 @@ private List<Map<String, Object>> createParquetRecordsList(final InputStream inp
private MessageType createdParquetSchema(ParquetMetadata parquetMetadata) {
return parquetMetadata.getFileMetaData().getSchema();
}
}
}