From 22f7c58fcaf8ef7ad811f513bf0b17e7eef24fec Mon Sep 17 00:00:00 2001 From: Yizhe Liu <“yizheliu@amazon.com”> Date: Sat, 21 Dec 2024 14:26:47 -0800 Subject: [PATCH] Fix bug where ingestion failed for input document with list of nested objects Signed-off-by: Yizhe Liu --- CHANGELOG.md | 1 + .../processor/InferenceProcessor.java | 89 ++++++-- .../TextEmbeddingProcessorTests.java | 197 +++++++++++++++--- 3 files changed, 242 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5345d416f..3653e18ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) ### Bug Fixes - Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) +- Fix bug where ingestion failed for input documents containing a list of nested objects ([#1039](https://github.com/opensearch-project/neural-search/pull/1039)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index ae996251d..9966dd802 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -419,23 +419,34 @@ private void putNLPResultToSourceMapForMapType( if (sourceValue instanceof Map) { for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { if (sourceAndMetadataMap.get(processorKey) instanceof List) { - // build nlp output for list of nested objects - Iterator inputNestedMapValueIt = ((List) inputNestedMapEntry.getValue()).iterator(); - for (Map nestedElement : (List>) sourceAndMetadataMap.get(processorKey)) { - // Only fill in when
value is not null - if (inputNestedMapValueIt.hasNext() && inputNestedMapValueIt.next() != null) { - nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++)); + if (inputNestedMapEntry.getValue() instanceof List) { + // build nlp output for object in sourceValue which is list type + Iterator inputNestedMapValueIt = ((List) inputNestedMapEntry.getValue()).iterator(); + for (Map nestedElement : (List>) sourceAndMetadataMap.get(processorKey)) { + // Only fill in when value is not null + if (inputNestedMapValueIt.hasNext() && inputNestedMapValueIt.next() != null) { + nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++)); + } } + } else if (inputNestedMapEntry.getValue() instanceof Map) { + // build nlp output for object in sourceValue which is map type + List> nestedElementList = (List>) sourceAndMetadataMap.get(processorKey); + + IntStream.range(0, nestedElementList.size()).forEach(nestedElementIndex -> { + Map nestedElement = nestedElementList.get(nestedElementIndex); + putNLPResultToSingleSourceMapInList( + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + results, + indexWrapper, + nestedElement, + nestedElementIndex + ); + }); } } else { Pair processedNestedKey = processNestedKey(inputNestedMapEntry); - Map sourceMap; - if (sourceAndMetadataMap.get(processorKey) == null) { - sourceMap = new HashMap<>(); - sourceAndMetadataMap.put(processorKey, sourceMap); - } else { - sourceMap = (Map) sourceAndMetadataMap.get(processorKey); - } + Map sourceMap = getSourceMapBySourceAndMetadataMap(processorKey, sourceAndMetadataMap); putNLPResultToSourceMapForMapType( processedNestedKey.getKey(), processedNestedKey.getValue(), @@ -456,6 +467,58 @@ private void putNLPResultToSourceMapForMapType( } } + /** + * Put nlp result to single source element, which is in a list field of source document + * Such source element is in map type + * + * @param processorKey + * @param sourceValue + * @param results + * @param 
indexWrapper + * @param sourceAndMetadataMap + * @param nestedElementIndex index of the element in the list field of source document + */ + @SuppressWarnings("unchecked") + private void putNLPResultToSingleSourceMapInList( + String processorKey, + Object sourceValue, + List results, + IndexWrapper indexWrapper, + Map sourceAndMetadataMap, + int nestedElementIndex + ) { + if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; + if (sourceValue instanceof Map) { + for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { + Pair processedNestedKey = processNestedKey(inputNestedMapEntry); + Map sourceMap = getSourceMapBySourceAndMetadataMap(processorKey, sourceAndMetadataMap); + putNLPResultToSingleSourceMapInList( + processedNestedKey.getKey(), + processedNestedKey.getValue(), + results, + indexWrapper, + sourceMap, + nestedElementIndex + ); + } + } else { + if (sourceValue instanceof List && ((List) sourceValue).get(nestedElementIndex) != null) { + sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION); + } + } + } + + @SuppressWarnings("unchecked") + private Map getSourceMapBySourceAndMetadataMap(String processorKey, Map sourceAndMetadataMap) { + Map sourceMap = new HashMap<>(); + if (sourceAndMetadataMap.get(processorKey) == null) { + sourceAndMetadataMap.put(processorKey, sourceMap); + } else { + sourceMap = (Map) sourceAndMetadataMap.get(processorKey); + } + return sourceMap; + } + private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { List> keyToResult = new ArrayList<>(); IntStream.range(0, sourceValue.size()) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 97e85e46e..5e1db8f10 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ 
b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -27,6 +27,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; +import java.util.stream.IntStream; import org.apache.commons.lang3.tuple.Pair; import org.junit.Before; @@ -486,6 +487,76 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi } } + @SneakyThrows + @SuppressWarnings("unchecked") + public void testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWithoutDestinationStructure_theSuccessful() { + /* + modeling following document: + parent: [ + child_level_1: + child_level_1_text_field: "text_value", + child_level_1: + child_level_1_text_field: "text_value", + ] + */ + Map child1Level2 = buildObjMapWithSingleField(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); + Map child1Level1 = buildObjMapWithSingleField(CHILD_FIELD_LEVEL_1, child1Level2); + Map child2Level2 = buildObjMapWithSingleField(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); + Map child2Level1 = buildObjMapWithSingleField(CHILD_FIELD_LEVEL_1, child2Level2); + Map sourceAndMetadata = Map.of( + PARENT_FIELD, + Arrays.asList(child1Level1, child2Level1), + IndexFieldMapper.NAME, + "my_index" + ); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + Map.of( + PARENT_FIELD, + Map.of(CHILD_FIELD_LEVEL_1, Map.of(CHILD_1_TEXT_FIELD, String.join(".", CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD))) + ) + ); + TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( + registry, + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + List> modelTensorList = createRandomOneDimensionalMockVector(2, 100, 0.0f, 1.0f); + doAnswer(invocation -> { + ActionListener>> listener = 
invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + assertNotNull(ingestDocument); + assertNotNull(ingestDocument.getSourceAndMetadata().get(PARENT_FIELD)); + + List> parentAfterProcessor = (List>) ingestDocument.getSourceAndMetadata() + .get(PARENT_FIELD); + + for (Map childActual : parentAfterProcessor) { + Map childLevel1Actual = (Map) childActual.get(CHILD_FIELD_LEVEL_1); + assertEquals(2, childLevel1Actual.size()); + assertEquals(TEXT_VALUE_1, childLevel1Actual.get(CHILD_1_TEXT_FIELD)); + assertNotNull(childLevel1Actual.get(CHILD_FIELD_LEVEL_2)); + Map childLevel2Actual = (Map) childLevel1Actual.get(CHILD_FIELD_LEVEL_2); + List vectors = (List) childLevel2Actual.get(CHILD_LEVEL_2_KNN_FIELD); + assertEquals(100, vectors.size()); + for (Float vector : vectors) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + } + + } + @SneakyThrows public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { Map sourceAndMetadata = new HashMap<>(); @@ -769,6 +840,68 @@ public void testBuildVectorOutput_withNestedList_Level2_successful() { assertNotNull(nestedObj.get(1).get("vectorField")); } + @SuppressWarnings("unchecked") + public void testBuildVectorOutput_withNestedListLevel2_withNestedFields_successful() { + Map config = createNestedList2LevelConfiguration(); + IngestDocument ingestDocument = create2LevelNestedListWithNestedFieldsIngestDocument(); + TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = createRandomOneDimensionalMockVector(2, 2, 0.0f, 1.0f); + textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + List> nestedObj = (List>) 
ingestDocument.getSourceAndMetadata().get("nestedField"); + + for (Map singleNestedObjMap : nestedObj) { + Map singleNestedObj = (Map) singleNestedObjMap.get("nestedField"); + assertTrue(singleNestedObj.containsKey("vectorField")); + assertNotNull(singleNestedObj.get("vectorField")); + } + } + + @SuppressWarnings("unchecked") + public void testBuildVectorOutput_withNestedListLevel2_withPartialNullNestedFields_successful() { + Map config = createNestedList2LevelConfiguration(); + IngestDocument ingestDocument = create2LevelNestedListWithNestedFieldsIngestDocument(); + /** + * Ingest doc with below fields + * "nestedField": { + * "nestedField": [ + * { + * "nestedField": { + * "textField": null, + * } + * }, + * { + * "nestedField": { + * "textField": "This is another text field", + * } + * } + * ] + * } + */ + List> nestedList = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); + Map objWithNullText = buildObjMapWithSingleField("textField", null); + Map nestedObjWithNullText = buildObjMapWithSingleField("nestedField", objWithNullText); + nestedList.set(0, nestedObjWithNullText); + TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = createRandomOneDimensionalMockVector(2, 2, 0.0f, 1.0f); + textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + List> nestedObj = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); + + IntStream.range(0, nestedObj.size()).forEachOrdered(index -> { + Map singleNestedObjMap = nestedObj.get(index); + if (index == 0) { + Map singleNestedObj = (Map) singleNestedObjMap.get("nestedField"); + assertFalse(singleNestedObj.containsKey("vectorField")); + assertTrue(singleNestedObj.containsKey("textField")); + } else { + Map singleNestedObj = (Map) singleNestedObjMap.get("nestedField"); + 
assertTrue(singleNestedObj.containsKey("vectorField")); + assertNotNull(singleNestedObj.get("vectorField")); + } + }); + } + @SuppressWarnings("unchecked") public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_Level2_successful() { Map config = createNestedList2LevelConfiguration(); @@ -1043,55 +1176,55 @@ private IngestDocument createNestedMapIngestDocument() { } private Map createNestedListConfiguration() { - Map nestedConfig = new HashMap<>(); - nestedConfig.put("textField", "vectorField"); - Map result = new HashMap<>(); - result.put("nestedField", nestedConfig); - return result; + Map nestedConfig = buildObjMapWithSingleField("textField", "vectorField"); + return buildObjMapWithSingleField("nestedField", nestedConfig); } private Map createNestedList2LevelConfiguration() { - Map nestedConfig = new HashMap<>(); - nestedConfig.put("textField", "vectorField"); - Map nestConfigLevel1 = new HashMap<>(); - nestConfigLevel1.put("nestedField", nestedConfig); - Map result = new HashMap<>(); - result.put("nestedField", nestConfigLevel1); - return result; + Map nestedConfig = buildObjMapWithSingleField("textField", "vectorField"); + Map nestConfigLevel1 = buildObjMapWithSingleField("nestedField", nestedConfig); + return buildObjMapWithSingleField("nestedField", nestConfigLevel1); } private IngestDocument createNestedListIngestDocument() { - HashMap nestedObj1 = new HashMap<>(); - nestedObj1.put("textField", "This is a text field"); - HashMap nestedObj2 = new HashMap<>(); - nestedObj2.put("textField", "This is another text field"); - HashMap nestedList = new HashMap<>(); - nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + Map nestedObj1 = buildObjMapWithSingleField("textField", "This is a text field"); + Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); + Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); return new IngestDocument(nestedList, new 
HashMap<>()); } private IngestDocument createNestedListWithNotEmbeddingFieldIngestDocument() { - HashMap nestedObj1 = new HashMap<>(); - nestedObj1.put("textFieldNotForEmbedding", "This is a text field"); - HashMap nestedObj2 = new HashMap<>(); - nestedObj2.put("textField", "This is another text field"); - HashMap nestedList = new HashMap<>(); - nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + Map nestedObj1 = buildObjMapWithSingleField("textFieldNotForEmbedding", "This is a text field"); + Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); + Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); return new IngestDocument(nestedList, new HashMap<>()); } private IngestDocument create2LevelNestedListIngestDocument() { - HashMap nestedObj1 = new HashMap<>(); - nestedObj1.put("textField", "This is a text field"); - HashMap nestedObj2 = new HashMap<>(); - nestedObj2.put("textField", "This is another text field"); - HashMap nestedList = new HashMap<>(); - nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); - HashMap nestedList1 = new HashMap<>(); - nestedList1.put("nestedField", nestedList); + Map nestedObj1 = buildObjMapWithSingleField("textField", "This is a text field"); + Map nestedObj2 = buildObjMapWithSingleField("textField", "This is another text field"); + Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + Map nestedList1 = buildObjMapWithSingleField("nestedField", nestedList); return new IngestDocument(nestedList1, new HashMap<>()); } + private IngestDocument create2LevelNestedListWithNestedFieldsIngestDocument() { + Map nestedObj1Level2 = buildObjMapWithSingleField("textField", "This is a text field"); + Map nestedObj1Level1 = buildObjMapWithSingleField("nestedField", nestedObj1Level2); + + Map nestedObj2Level2 = buildObjMapWithSingleField("textField", "This is another text field"); + Map 
nestedObj2Level1 = buildObjMapWithSingleField("nestedField", nestedObj2Level2); + + Map nestedList = buildObjMapWithSingleField("nestedField", Arrays.asList(nestedObj1Level1, nestedObj2Level1)); + return new IngestDocument(nestedList, new HashMap<>()); + } + + private Map buildObjMapWithSingleField(String fieldName, Object fieldValue) { + Map objMap = new HashMap<>(); + objMap.put(fieldName, fieldValue); + return objMap; + } + private IngestDocument create2LevelNestedListWithNotEmbeddingFieldIngestDocument() { HashMap nestedObj1 = new HashMap<>(); nestedObj1.put("textFieldNotForEmbedding", "This is a text field");