Skip to content

Commit

Permalink
Fix bug where ingestion failed for input document containing list of …
Browse files Browse the repository at this point in the history
…nested objects

Signed-off-by: Yizhe Liu <[email protected]>
  • Loading branch information
yizheliu-amazon committed Dec 22, 2024
1 parent 22ba5d3 commit 5b13778
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,23 +419,34 @@ private void putNLPResultToSourceMapForMapType(
if (sourceValue instanceof Map) {
for (Map.Entry<String, Object> inputNestedMapEntry : ((Map<String, Object>) sourceValue).entrySet()) {
if (sourceAndMetadataMap.get(processorKey) instanceof List) {
// build nlp output for list of nested objects
Iterator<Object> inputNestedMapValueIt = ((List<Object>) inputNestedMapEntry.getValue()).iterator();
for (Map<String, Object> nestedElement : (List<Map<String, Object>>) sourceAndMetadataMap.get(processorKey)) {
// Only fill in when value is not null
if (inputNestedMapValueIt.hasNext() && inputNestedMapValueIt.next() != null) {
nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++));
if (inputNestedMapEntry.getValue() instanceof List) {
// build nlp output for object in sourceValue which is list type
Iterator<Object> inputNestedMapValueIt = ((List<Object>) inputNestedMapEntry.getValue()).iterator();
for (Map<String, Object> nestedElement : (List<Map<String, Object>>) sourceAndMetadataMap.get(processorKey)) {
// Only fill in when value is not null
if (inputNestedMapValueIt.hasNext() && inputNestedMapValueIt.next() != null) {
nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++));
}
}
} else if (inputNestedMapEntry.getValue() instanceof Map) {
// build nlp output for object in sourceValue which is map type
List<Map<String, Object>> nestedElementList = (List<Map<String, Object>>) sourceAndMetadataMap.get(processorKey);

IntStream.range(0, nestedElementList.size()).forEach(nestedElementIndex -> {
Map<String, Object> nestedElement = nestedElementList.get(nestedElementIndex);
putNLPResultToSingleSourceMapInList(
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
results,
indexWrapper,
nestedElement,
nestedElementIndex
);
});
}
} else {
Pair<String, Object> processedNestedKey = processNestedKey(inputNestedMapEntry);
Map<String, Object> sourceMap;
if (sourceAndMetadataMap.get(processorKey) == null) {
sourceMap = new HashMap<>();
sourceAndMetadataMap.put(processorKey, sourceMap);
} else {
sourceMap = (Map<String, Object>) sourceAndMetadataMap.get(processorKey);
}
Map<String, Object> sourceMap = getSourceMapBySourceAndMetadataMap(processorKey, sourceAndMetadataMap);
putNLPResultToSourceMapForMapType(
processedNestedKey.getKey(),
processedNestedKey.getValue(),
Expand All @@ -456,6 +467,58 @@ private void putNLPResultToSourceMapForMapType(
}
}

/**
* Put nlp result to single source element, which is in a list field of source document
* Such source element is in map type
*
* @param processorKey
* @param sourceValue
* @param results
* @param indexWrapper
* @param sourceAndMetadataMap
* @param nestedElementIndex index of the element in the list field of source document
*/
@SuppressWarnings("unchecked")
private void putNLPResultToSingleSourceMapInList(
String processorKey,
Object sourceValue,
List<?> results,
IndexWrapper indexWrapper,
Map<String, Object> sourceAndMetadataMap,
int nestedElementIndex
) {
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
if (sourceValue instanceof Map) {
for (Map.Entry<String, Object> inputNestedMapEntry : ((Map<String, Object>) sourceValue).entrySet()) {
Pair<String, Object> processedNestedKey = processNestedKey(inputNestedMapEntry);
Map<String, Object> sourceMap = getSourceMapBySourceAndMetadataMap(processorKey, sourceAndMetadataMap);
putNLPResultToSingleSourceMapInList(
processedNestedKey.getKey(),
processedNestedKey.getValue(),
results,
indexWrapper,
sourceMap,
nestedElementIndex
);
}
} else {
if (sourceValue instanceof List && ((List<Object>) sourceValue).get(nestedElementIndex) != null) {
sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION);
}
}
}

@SuppressWarnings("unchecked")
private Map<String, Object> getSourceMapBySourceAndMetadataMap(String processorKey, Map<String, Object> sourceAndMetadataMap) {
Map<String, Object> sourceMap = new HashMap<>();
if (sourceAndMetadataMap.get(processorKey) == null) {
sourceAndMetadataMap.put(processorKey, sourceMap);
} else {
sourceMap = (Map<String, Object>) sourceAndMetadataMap.get(processorKey);
}
return sourceMap;
}

private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.Before;
Expand Down Expand Up @@ -486,6 +487,76 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi
}
}

@SneakyThrows
@SuppressWarnings("unchecked")
public void testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWithoutDestinationStructure_theSuccessful() {
/*
modeling following document:
parent: [
child_level_1:
child_level_1_text_field: "text_value",
child_level_1:
child_level_1_text_field: "text_value",
]
*/
Map<String, Object> child1Level2 = buildObjMapWithSingleField(CHILD_1_TEXT_FIELD, TEXT_VALUE_1);
Map<String, Object> child1Level1 = buildObjMapWithSingleField(CHILD_FIELD_LEVEL_1, child1Level2);
Map<String, Object> child2Level2 = buildObjMapWithSingleField(CHILD_1_TEXT_FIELD, TEXT_VALUE_1);
Map<String, Object> child2Level1 = buildObjMapWithSingleField(CHILD_FIELD_LEVEL_1, child2Level2);
Map<String, Object> sourceAndMetadata = Map.of(
PARENT_FIELD,
Arrays.asList(child1Level1, child2Level1),
IndexFieldMapper.NAME,
"my_index"
);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());

Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(
TextEmbeddingProcessor.FIELD_MAP_FIELD,
Map.of(
PARENT_FIELD,
Map.of(CHILD_FIELD_LEVEL_1, Map.of(CHILD_1_TEXT_FIELD, String.join(".", CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD)))
)
);
TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(
registry,
PROCESSOR_TAG,
DESCRIPTION,
config
);

List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(2, 100, 0.0f, 1.0f);
doAnswer(invocation -> {
ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
listener.onResponse(modelTensorList);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {});
assertNotNull(ingestDocument);
assertNotNull(ingestDocument.getSourceAndMetadata().get(PARENT_FIELD));

List<Map<String, Object>> parentAfterProcessor = (List<Map<String, Object>>) ingestDocument.getSourceAndMetadata()
.get(PARENT_FIELD);

for (Map<String, Object> childActual : parentAfterProcessor) {
Map<String, Object> childLevel1Actual = (Map<String, Object>) childActual.get(CHILD_FIELD_LEVEL_1);
assertEquals(2, childLevel1Actual.size());
assertEquals(TEXT_VALUE_1, childLevel1Actual.get(CHILD_1_TEXT_FIELD));
assertNotNull(childLevel1Actual.get(CHILD_FIELD_LEVEL_2));
Map<String, Object> childLevel2Actual = (Map<String, Object>) childLevel1Actual.get(CHILD_FIELD_LEVEL_2);
List<Float> vectors = (List<Float>) childLevel2Actual.get(CHILD_LEVEL_2_KNN_FIELD);
assertEquals(100, vectors.size());
for (Float vector : vectors) {
assertTrue(vector >= 0.0f && vector <= 1.0f);
}
}

}

@SneakyThrows
public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand Down Expand Up @@ -769,6 +840,68 @@ public void testBuildVectorOutput_withNestedList_Level2_successful() {
assertNotNull(nestedObj.get(1).get("vectorField"));
}

@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withNestedListLevel2_withNestedFields_successful() {
Map<String, Object> config = createNestedList2LevelConfiguration();
IngestDocument ingestDocument = create2LevelNestedListWithNestedFieldsIngestDocument();
TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config);
Map<String, Object> knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument);
List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(2, 2, 0.0f, 1.0f);
textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata());
List<Map<String, Object>> nestedObj = (List<Map<String, Object>>) ingestDocument.getSourceAndMetadata().get("nestedField");

for (Map<String, Object> singleNestedObjMap : nestedObj) {
Map<String, Object> singleNestedObj = (Map<String, Object>) singleNestedObjMap.get("nestedField");
assertTrue(singleNestedObj.containsKey("vectorField"));
assertNotNull(singleNestedObj.get("vectorField"));
}
}

@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withNestedListLevel2_withPartialNullNestedFields_successful() {
Map<String, Object> config = createNestedList2LevelConfiguration();
IngestDocument ingestDocument = create2LevelNestedListWithNestedFieldsIngestDocument();
/**
* Ingest doc with below fields
* "nestedField": {
* "nestedField": [
* {
* "nestedField": {
* "textField": null,
* }
* },
* {
* "nestedField": {
* "textField": "This is another text field",
* }
* }
* ]
* }
*/
List<Map<String, Object>> nestedList = (List<Map<String, Object>>) ingestDocument.getSourceAndMetadata().get("nestedField");
Map<String, Object> objWithNullText = buildObjMapWithSingleField("textField", null);
Map<String, Object> nestedObjWithNullText = buildObjMapWithSingleField("nestedField", objWithNullText);
nestedList.set(0, nestedObjWithNullText);
TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config);
Map<String, Object> knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument);
List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(2, 2, 0.0f, 1.0f);
textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata());
List<Map<String, Object>> nestedObj = (List<Map<String, Object>>) ingestDocument.getSourceAndMetadata().get("nestedField");

IntStream.range(0, nestedObj.size()).forEachOrdered(index -> {
Map<String, Object> singleNestedObjMap = nestedObj.get(index);
if (index == 0) {
Map<String, Object> singleNestedObj = (Map<String, Object>) singleNestedObjMap.get("nestedField");
assertFalse(singleNestedObj.containsKey("vectorField"));
assertTrue(singleNestedObj.containsKey("textField"));
} else {
Map<String, Object> singleNestedObj = (Map<String, Object>) singleNestedObjMap.get("nestedField");
assertTrue(singleNestedObj.containsKey("vectorField"));
assertNotNull(singleNestedObj.get("vectorField"));
}
});
}

@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_Level2_successful() {
Map<String, Object> config = createNestedList2LevelConfiguration();
Expand Down Expand Up @@ -1043,55 +1176,55 @@ private IngestDocument createNestedMapIngestDocument() {
}

private Map<String, Object> createNestedListConfiguration() {
Map<String, Object> nestedConfig = new HashMap<>();
nestedConfig.put("textField", "vectorField");
Map<String, Object> result = new HashMap<>();
result.put("nestedField", nestedConfig);
return result;
Map<String, Object> nestedConfig = buildObjMapWithSingleField("textField", "vectorField");
return buildObjMapWithSingleField("nestedField", nestedConfig);
}

private Map<String, Object> createNestedList2LevelConfiguration() {
Map<String, Object> nestedConfig = new HashMap<>();
nestedConfig.put("textField", "vectorField");
Map<String, Object> nestConfigLevel1 = new HashMap<>();
nestConfigLevel1.put("nestedField", nestedConfig);
Map<String, Object> result = new HashMap<>();
result.put("nestedField", nestConfigLevel1);
return result;
Map<String, Object> nestedConfig = buildObjMapWithSingleField("textField", "vectorField");
Map<String, Object> nestConfigLevel1 = buildObjMapWithSingleField("nestedField", nestedConfig);
return buildObjMapWithSingleField("nestedField", nestConfigLevel1);
}

private IngestDocument createNestedListIngestDocument() {
HashMap<String, Object> nestedObj1 = new HashMap<>();
nestedObj1.put("textField", "This is a text field");
HashMap<String, Object> nestedObj2 = new HashMap<>();
nestedObj2.put("textField", "This is another text field");
HashMap<String, Object> nestedList = new HashMap<>();
nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2));
Map<String, Object> nestedObj1 = buildObjMapWithSingleField("textField", "This is a text field");
Map<String, Object> nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field");
Map<String, Object> nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2));
return new IngestDocument(nestedList, new HashMap<>());
}

private IngestDocument createNestedListWithNotEmbeddingFieldIngestDocument() {
HashMap<String, Object> nestedObj1 = new HashMap<>();
nestedObj1.put("textFieldNotForEmbedding", "This is a text field");
HashMap<String, Object> nestedObj2 = new HashMap<>();
nestedObj2.put("textField", "This is another text field");
HashMap<String, Object> nestedList = new HashMap<>();
nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2));
Map<String, Object> nestedObj1 = buildObjMapWithSingleField("textFieldNotForEmbedding", "This is a text field");
Map<String, Object> nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field");
Map<String, Object> nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2));
return new IngestDocument(nestedList, new HashMap<>());
}

private IngestDocument create2LevelNestedListIngestDocument() {
HashMap<String, Object> nestedObj1 = new HashMap<>();
nestedObj1.put("textField", "This is a text field");
HashMap<String, Object> nestedObj2 = new HashMap<>();
nestedObj2.put("textField", "This is another text field");
HashMap<String, Object> nestedList = new HashMap<>();
nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2));
HashMap<String, Object> nestedList1 = new HashMap<>();
nestedList1.put("nestedField", nestedList);
Map<String, Object> nestedObj1 = buildObjMapWithSingleField("textField", "This is a text field");
Map<String, Object> nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field");
Map<String, Object> nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2));
Map<String, Object> nestedList1 = buildObjMapWithSingleField("nestedField", nestedList);
return new IngestDocument(nestedList1, new HashMap<>());
}

private IngestDocument create2LevelNestedListWithNestedFieldsIngestDocument() {
Map<String, Object> nestedObj1Level2 = buildObjMapWithSingleField("textField", "This is a text field");
Map<String, Object> nestedObj1Level1 = buildObjMapWithSingleField("nestedField", nestedObj1Level2);

Map<String, Object> nestedObj2Level2 = buildObjMapWithSingleField("textField", "This is another text field");
Map<String, Object> nestedObj2Level1 = buildObjMapWithSingleField("nestedField", nestedObj2Level2);

Map<String, Object> nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1Level1, nestedObj2Level1));
return new IngestDocument(nestedList, new HashMap<>());
}

private Map<String, Object> buildObjMapWithSingleField(String fieldName, Object fieldValue) {
Map<String, Object> objMap = new HashMap<>();
objMap.put(fieldName, fieldValue);
return objMap;
}

private IngestDocument create2LevelNestedListWithNotEmbeddingFieldIngestDocument() {
HashMap<String, Object> nestedObj1 = new HashMap<>();
nestedObj1.put("textFieldNotForEmbedding", "This is a text field");
Expand Down

0 comments on commit 5b13778

Please sign in to comment.