diff --git a/CHANGELOG.md b/CHANGELOG.md index bbdc9c8d9..4cabfe7d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.15...2.x) ### Features ### Enhancements -* Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/) +- Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/) +- Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811)) ### Bug Fixes - Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 9465b250f..d9f9c7048 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -6,6 +6,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -21,6 +22,8 @@ import lombok.AllArgsConstructor; import lombok.Getter; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.common.collect.Tuple; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.cluster.service.ClusterService; @@ -120,7 +123,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { validateEmbeddingFieldsValue(ingestDocument); - Map processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocument); + Map processMap = buildMapWithTargetKeys(ingestDocument); List inferenceList = createInferenceList(processMap); if (inferenceList.size() == 0) { handler.accept(ingestDocument, null); @@ -228,7 +231,7 @@ private List getDataForInference(List i List inferenceList = null; try { validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument()); - processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument()); + processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument()); inferenceList = createInferenceList(processMap); } catch (Exception e) { ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e); @@ -276,15 +279,17 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List } @VisibleForTesting - Map buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestDocument) { + Map buildMapWithTargetKeys(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); Map mapWithProcessorKeys = new LinkedHashMap<>(); for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String originalKey = fieldMapEntry.getKey(); - Object targetKey = fieldMapEntry.getValue(); + Pair processedNestedKey = processNestedKey(fieldMapEntry); + String originalKey = processedNestedKey.getKey(); + Object targetKey = processedNestedKey.getValue(); + if (targetKey instanceof Map) { Map treeRes = new LinkedHashMap<>(); - buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); + buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes); mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); } else { mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); @@ -293,20 +298,19 @@ Map buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestD return mapWithProcessorKeys; } - private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes - ) { - if (processorKey == null || sourceAndMetadataMap == null) return; + @VisibleForTesting + void buildNestedMap(String parentKey, Object processorKey, Map sourceAndMetadataMap, Map treeRes) { + if (Objects.isNull(processorKey) || Objects.isNull(sourceAndMetadataMap)) { + return; + } if (processorKey instanceof Map) { Map next = new LinkedHashMap<>(); if (sourceAndMetadataMap.get(parentKey) instanceof Map) { for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), + Pair processedNestedKey = processNestedKey(nestedFieldMapEntry); + buildNestedMap( + processedNestedKey.getKey(), + processedNestedKey.getValue(), (Map) sourceAndMetadataMap.get(parentKey), next ); @@ -317,21 +321,46 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType( List listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList()); Map map = new LinkedHashMap<>(); map.put(nestedFieldMapEntry.getKey(), listOfStrings); - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - map, - next - ); + buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next); } } - treeRes.put(parentKey, next); + treeRes.merge(parentKey, next, (v1, v2) -> { + if (v1 instanceof Collection && v2 instanceof Collection) { + ((Collection) v1).addAll((Collection) v2); + return v1; + } else if (v1 instanceof Map && v2 instanceof Map) { + ((Map) v1).putAll((Map) v2); + return v1; + } else { + return v2; + } + }); } else { String key = String.valueOf(processorKey); treeRes.put(key, sourceAndMetadataMap.get(parentKey)); } } + /** + * Process the nested key, such as "a.b.c" to "a", "b.c" + * @param nestedFieldMapEntry + * @return A pair of the original key and the target key + */ + @VisibleForTesting + protected Pair processNestedKey(final Map.Entry nestedFieldMapEntry) { + String originalKey = nestedFieldMapEntry.getKey(); + Object targetKey = nestedFieldMapEntry.getValue(); + int nestedDotIndex = originalKey.indexOf('.'); + if (nestedDotIndex != -1) { + Map newTargetKey = new LinkedHashMap<>(); + newTargetKey.put(originalKey.substring(nestedDotIndex + 1), targetKey); + targetKey = newTargetKey; + + originalKey = originalKey.substring(0, nestedDotIndex); + } + return new ImmutablePair<>(originalKey, targetKey); + } + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java index 866a2ab29..caac962e7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor; import com.google.common.collect.ImmutableList; +import org.apache.commons.lang.math.RandomUtils; import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; @@ -58,4 +59,17 @@ protected List> createMockVectorResult() { modelTensorList.add(number7); return modelTensorList; } + + protected List> createRandomOneDimensionalMockVector(int numOfVectors, int vectorDimension, float min, float max) { + List> result = new ArrayList<>(); + for (int i = 0; i < numOfVectors; i++) { + List numbers = new ArrayList<>(); + for (int j = 0; j < vectorDimension; j++) { + Float nextFloat = RandomUtils.nextFloat() * (max - min) + min; + numbers.add(nextFloat); + } + result.add(numbers); + } + return result; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index f963c48fc..98b5a25b6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -16,21 +16,31 @@ import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.lucene.search.join.ScoreMode; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.collect.ImmutableList; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { private static final String INDEX_NAME = "text_embedding_index"; private static final String PIPELINE_NAME = "pipeline-hybrid"; + protected static final String QUERY_TEXT = "hello"; + protected static final String LEVEL_1_FIELD = "nested_passages"; + protected static final String LEVEL_2_FIELD = "level_2"; + protected static final String LEVEL_3_FIELD_TEXT = "level_3_text"; + protected static final String LEVEL_3_FIELD_EMBEDDING = "level_3_embedding"; private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI())); private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI())); + private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI())); private final String BULK_ITEM_TEMPLATE = Files.readString( Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI()) ); @@ -77,6 +87,66 @@ public void testTextEmbeddingProcessor_batch() throws Exception { } } + public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws Exception { + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING); + createTextEmbeddingIndex(); + ingestDocument(INGEST_DOC3, "3"); + + Map sourceMap = (Map) getDocById(INDEX_NAME, "3").get("_source"); + assertNotNull(sourceMap); + assertTrue(sourceMap.containsKey(LEVEL_1_FIELD)); + Map nestedPassages = (Map) sourceMap.get(LEVEL_1_FIELD); + assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD)); + Map level2 = (Map) nestedPassages.get(LEVEL_2_FIELD); + assertEquals(QUERY_TEXT, level2.get(LEVEL_3_FIELD_TEXT)); + assertTrue(level2.containsKey(LEVEL_3_FIELD_EMBEDDING)); + List embeddings = (List) level2.get(LEVEL_3_FIELD_EMBEDDING); + assertEquals(768, embeddings.size()); + for (Double embedding : embeddings) { + assertTrue(embedding >= 0.0 && embedding <= 1.0); + } + + NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder( + LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_EMBEDDING, + QUERY_TEXT, + "", + modelId, + 10, + null, + null, + null, + null, + null + ); + QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery( + LEVEL_1_FIELD + "." + LEVEL_2_FIELD, + neuralQueryBuilderQuery, + ScoreMode.Total + ); + QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total); + + Map searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 1); + assertNotNull(searchResponseAsMap); + + Map hits = (Map) searchResponseAsMap.get("hits"); + assertNotNull(hits); + + assertEquals(1.0, hits.get("max_score")); + List> listOfHits = (List>) hits.get("hits"); + assertNotNull(listOfHits); + assertEquals(1, listOfHits.size()); + Map hitsInner = listOfHits.get(0); + assertEquals("3", hitsInner.get("_id")); + assertEquals(1.0, hitsInner.get("_score")); + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + private String uploadTextEmbeddingModel() throws Exception { String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); return registerModelGroupAndUploadModel(requestBody); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index bff578ad7..9a5e8aa76 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -18,14 +18,17 @@ import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Arrays; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; +import org.apache.commons.lang3.tuple.Pair; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; @@ -50,6 +53,11 @@ public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase { + protected static final String PARENT_FIELD = "parent"; + protected static final String CHILD_FIELD_LEVEL_1 = "child_level1"; + protected static final String CHILD_FIELD_LEVEL_2 = "child_level2"; + protected static final String CHILD_LEVEL_2_TEXT_FIELD_VALUE = "text_field_value"; + protected static final String CHILD_LEVEL_2_KNN_FIELD = "test3_knn"; @Mock private MLCommonsClientAccessor mlCommonsClientAccessor; @@ -77,7 +85,7 @@ private TextEmbeddingProcessor createInstanceWithLevel2MapConfig() { config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put( TextEmbeddingProcessor.FIELD_MAP_FIELD, - ImmutableMap.of("key1", ImmutableMap.of("test1", "test1_knn"), "key2", ImmutableMap.of("test3", "test3_knn")) + ImmutableMap.of("key1", ImmutableMap.of("test1", "test1_knn"), "key2", ImmutableMap.of("test3", CHILD_LEVEL_2_KNN_FIELD)) ); return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @@ -285,6 +293,94 @@ public void testExecute_withMapTypeInput_successful() { } + @SneakyThrows + public void testNestedFieldInMapping_withMapTypeInput_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_TEXT_FIELD_VALUE); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + sourceAndMetadata.put(PARENT_FIELD, childLevel1); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of( + String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_FIELD_LEVEL_2)), + CHILD_LEVEL_2_KNN_FIELD + ) + ); + TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + + List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + assertNotNull(ingestDocument); + assertNotNull(ingestDocument.getSourceAndMetadata().get(PARENT_FIELD)); + Map childLevel1AfterProcessor = (Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map childLevel2AfterProcessor = (Map) childLevel1AfterProcessor.get(CHILD_FIELD_LEVEL_1); + assertEquals(CHILD_LEVEL_2_TEXT_FIELD_VALUE, childLevel2AfterProcessor.get(CHILD_FIELD_LEVEL_2)); + assertNotNull(childLevel2AfterProcessor.get(CHILD_LEVEL_2_KNN_FIELD)); + List vectors = (List) childLevel2AfterProcessor.get(CHILD_LEVEL_2_KNN_FIELD); + assertEquals(100, vectors.size()); + for (Float vector : vectors) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + } + + @SneakyThrows + public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_TEXT_FIELD_VALUE); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + sourceAndMetadata.put(PARENT_FIELD, childLevel1); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of( + String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1)), + Map.of(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD) + ) + ); + TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + + List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + assertNotNull(ingestDocument); + assertNotNull(ingestDocument.getSourceAndMetadata().get(PARENT_FIELD)); + Map childLevel1AfterProcessor = (Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map childLevel2AfterProcessor = (Map) childLevel1AfterProcessor.get(CHILD_FIELD_LEVEL_1); + assertEquals(CHILD_LEVEL_2_TEXT_FIELD_VALUE, childLevel2AfterProcessor.get(CHILD_FIELD_LEVEL_2)); + assertNotNull(childLevel2AfterProcessor.get(CHILD_LEVEL_2_KNN_FIELD)); + List vectors = (List) childLevel2AfterProcessor.get(CHILD_LEVEL_2_KNN_FIELD); + assertEquals(100, vectors.size()); + for (Float vector : vectors) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + } + public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { Map map1 = ImmutableMap.of("test1", "test2"); Map map2 = ImmutableMap.of("test3", 209.3D); @@ -396,7 +492,7 @@ public void testProcessResponse_successful() throws Exception { IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); List> modelTensorList = createMockVectorResult(); processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); @@ -409,7 +505,7 @@ public void testBuildVectorOutput_withPlainStringValue_successful() { IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); // To assert the order is not changed between config map and generated map. List configValueList = new LinkedList<>(config.values()); @@ -435,23 +531,51 @@ public void testBuildVectorOutput_withNestedMap_successful() { Map config = createNestedMapConfiguration(); IngestDocument ingestDocument = createNestedMapIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); - List> modelTensorList = createMockVectorResult(); + Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); + List> modelTensorList = createRandomOneDimensionalMockVector(2, 100, 0.0f, 1.0f); processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + /** + * "favorites": { + * "favorite": { + * "movie": "matrix", + * "actor": "Charlie Chaplin", + * "games" : { + * "adventure": { + * "action": "overwatch", + * "rpg": "elden ring" + * } + * } + * } + * } + */ Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); assertNotNull(favoritesMap); - Map favoriteGames = (Map) favoritesMap.get("favorite.games"); + Map favorites = (Map) favoritesMap.get("favorite"); + assertNotNull(favorites); + + Map favoriteGames = (Map) favorites.get("games"); assertNotNull(favoriteGames); Map adventure = (Map) favoriteGames.get("adventure"); - Object actionGamesKnn = adventure.get("with.action.knn"); - assertNotNull(actionGamesKnn); + List adventureKnnVector = (List) adventure.get("with_action_knn"); + assertNotNull(adventureKnnVector); + assertEquals(100, adventureKnnVector.size()); + for (float vector : adventureKnnVector) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } + + List favoriteKnnVector = (List) favorites.get("favorite_movie_knn"); + assertNotNull(favoriteKnnVector); + assertEquals(100, favoriteKnnVector.size()); + for (float vector : favoriteKnnVector) { + assertTrue(vector >= 0.0f && vector <= 1.0f); + } } public void testBuildVectorOutput_withNestedList_successful() { Map config = createNestedListConfiguration(); IngestDocument ingestDocument = createNestedListIngestDocument(); TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); + Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); List> modelTensorList = createMockVectorResult(); textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); List> nestedObj = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); @@ -465,7 +589,7 @@ public void testBuildVectorOutput_withNestedList_Level2_successful() { Map config = createNestedList2LevelConfiguration(); IngestDocument ingestDocument = create2LevelNestedListIngestDocument(); TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); + Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeys(ingestDocument); List> modelTensorList = createMockVectorResult(); textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); Map nestedLevel1 = (Map) ingestDocument.getSourceAndMetadata().get("nestedField"); @@ -480,7 +604,7 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); List> modelTensorList = createMockVectorResult(); processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); @@ -556,6 +680,100 @@ public void test_batchExecute_exception() { } } + public void testParsingNestedField_whenNestedFieldsConfigured_thenSuccessful() { + Map config = createNestedMapConfiguration(); + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); + /** + * Assert that mapping + * "favorites": { + * "favorite.movie": "favorite_movie_knn", + * "favorite.games": { + * "adventure.action": "with_action_knn" + * } + * } + * has been transformed to structure: + * "favorites": { + * "favorite": { + * "movie": "favorite_movie_knn", + * "games": { + * "adventure": { + * "action": "with_action_knn" + * } + * } + * } + * } + */ + assertMapWithNestedFields( + processor.processNestedKey( + config.entrySet().stream().filter(entry -> entry.getKey().equals("favorites")).findAny().orElseThrow() + ), + List.of("favorites"), + Optional.empty() + ); + + Map favorites = (Map) config.get("favorites"); + + assertMapWithNestedFields( + processor.processNestedKey( + favorites.entrySet().stream().filter(entry -> entry.getKey().equals("favorite.games")).findAny().orElseThrow() + ), + List.of("favorite", "games"), + Optional.of("favorite_movie_knn") + ); + + assertMapWithNestedFields( + processor.processNestedKey( + favorites.entrySet().stream().filter(entry -> entry.getKey().equals("favorite.movie")).findAny().orElseThrow() + ), + List.of("favorite", "movie"), + Optional.empty() + ); + + Map adventureActionMap = (Map) favorites.get("favorite.games"); + assertMapWithNestedFields( + processor.processNestedKey( + adventureActionMap.entrySet().stream().filter(entry -> entry.getKey().equals("adventure.action")).findAny().orElseThrow() + ), + List.of("adventure", "action"), + Optional.of("with_action_knn") + ); + } + + public void testBuildingOfNestedMap_whenHasNestedMapping_thenSuccessful() { + /** + * assert based on following structure: + * "nestedField": { + * "nestedField": { + * "textField": "vectorField" + * } + * } + */ + Map config = createNestedList2LevelConfiguration(); + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); + Map resultAsTree = new LinkedHashMap<>(); + processor.buildNestedMap("nestedField", config.get("nestedField"), config, resultAsTree); + assertNotNull(resultAsTree); + Map actualMapLevel1 = (Map) resultAsTree.get("nestedField"); + assertEquals(1, actualMapLevel1.size()); + assertEquals(Map.of("vectorField", "vectorField"), actualMapLevel1.get("nestedField")); + } + + private void assertMapWithNestedFields(Pair actual, List expectedKeys, Optional expectedFinalValue) { + assertNotNull(actual); + assertEquals(expectedKeys.get(0), actual.getKey()); + Map actualValue = (Map) actual.getValue(); + for (int i = 1; i < expectedKeys.size(); i++) { + assertTrue(actualValue.containsKey(expectedKeys.get(i))); + if (actualValue.get(expectedKeys.get(i)) instanceof Map) { + actualValue = (Map) actualValue.get(expectedKeys.get(i)); + } else if (expectedFinalValue.isPresent()) { + assertEquals(expectedFinalValue.get(), actualValue.get(expectedKeys.get(i))); + } else { + break; + } + } + } + @SneakyThrows private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map fieldMap) { Map registry = new HashMap<>(); @@ -576,20 +794,21 @@ private Map createPlainStringConfiguration() { return config; } + /** + * Create following mapping + * "favorites": { + * "favorite.movie": "favorite_movie_knn", + * "favorite.games": { + * "adventure.action": "with_action_knn" + * } + * } + */ private Map createNestedMapConfiguration() { Map adventureGames = new HashMap<>(); - adventureGames.put("with.action", "with.action.knn"); - adventureGames.put("with.reaction", "with.reaction.knn"); - Map puzzleGames = new HashMap<>(); - puzzleGames.put("maze", "maze.knn"); - puzzleGames.put("card", "card.knn"); - Map favoriteGames = new HashMap<>(); - favoriteGames.put("adventure", adventureGames); - favoriteGames.put("puzzle", puzzleGames); + adventureGames.put("adventure.action", "with_action_knn"); Map favorite = new HashMap<>(); - favorite.put("favorite.movie", "favorite.movie.knn"); - favorite.put("favorite.games", favoriteGames); - favorite.put("favorite.songs", "favorite.songs.knn"); + favorite.put("favorite.movie", "favorite_movie_knn"); + favorite.put("favorite.games", adventureGames); Map result = new HashMap<>(); result.put("favorites", favorite); return result; @@ -606,23 +825,33 @@ private IngestDocument createPlainIngestDocument() { return new IngestDocument(result, new HashMap<>()); } + /** + * Create following document + * "favorites": { + * "favorite": { + * "movie": "matrix", + * "actor": "Charlie Chaplin", + * "games" : { + * "adventure": { + * "action": "overwatch", + * "rpg": "elden ring" + * } + * } + * } + * } + */ private IngestDocument createNestedMapIngestDocument() { Map adventureGames = new HashMap<>(); - List actionGames = new ArrayList<>(); - actionGames.add("jojo world"); - actionGames.add(null); - adventureGames.put("with.action", actionGames); - adventureGames.put("with.reaction", "overwatch"); - Map puzzleGames = new HashMap<>(); - puzzleGames.put("maze", "zelda"); - puzzleGames.put("card", "hearthstone"); - Map favoriteGames = new HashMap<>(); - favoriteGames.put("adventure", adventureGames); - favoriteGames.put("puzzle", puzzleGames); + adventureGames.put("action", "overwatch"); + adventureGames.put("rpg", "elden ring"); + Map favGames = new HashMap<>(); + favGames.put("adventure", adventureGames); + Map favorites = new HashMap<>(); + favorites.put("movie", "matrix"); + favorites.put("games", favGames); + favorites.put("actor", "Charlie Chaplin"); Map favorite = new HashMap<>(); - favorite.put("favorite.movie", "favorite.movie.knn"); - favorite.put("favorite.games", favoriteGames); - favorite.put("favorite.songs", "In The Name Of Father"); + favorite.put("favorite", favorites); Map result = new HashMap<>(); result.put("favorites", favorite); return new IngestDocument(result, new HashMap<>()); diff --git a/src/test/resources/processor/IndexMappings.json b/src/test/resources/processor/IndexMappings.json index ffa5cea64..79eb34ce4 100644 --- a/src/test/resources/processor/IndexMappings.json +++ b/src/test/resources/processor/IndexMappings.json @@ -102,6 +102,27 @@ "m": 24 } } + }, + "level_2": { + "type": "nested", + "properties": { + "level_3_text": { + "type": "text" + }, + "level_3_embedding": { + "type": "knn_vector", + "dimension": 768, + "method": { + "name": "hnsw", + "space_type": "l2", + "engine": "lucene", + "parameters": { + "ef_construction": 128, + "m": 24 + } + } + } + } } } } diff --git a/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMapping.json b/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMapping.json new file mode 100644 index 000000000..13bae8776 --- /dev/null +++ b/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMapping.json @@ -0,0 +1,19 @@ +{ + "description": "text embedding pipeline for hybrid", + "processors": [ + { + "text_embedding": { + "model_id": "%s", + "field_map": { + "title": "title_knn", + "favor_list": "favor_list_knn", + "favorites": { + "game": "game_knn", + "movie": "movie_knn" + }, + "nested_passages.level_2.level_3_text": "level_3_embedding" + } + } + } + ] +} diff --git a/src/test/resources/processor/ingest_doc3.json b/src/test/resources/processor/ingest_doc3.json new file mode 100644 index 000000000..8eae12fe2 --- /dev/null +++ b/src/test/resources/processor/ingest_doc3.json @@ -0,0 +1,20 @@ +{ + "title": "This is a good day", + "description": "daily logging", + "favor_list": [ + "test", + "hello", + "mock" + ], + "favorites": { + "game": "overwatch", + "movie": null + }, + "nested_passages": + { + "level_2": + { + "level_3_text": "hello" + } + } +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index d85a70a1d..9e2da699c 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -81,7 +81,9 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ProcessorType.SPARSE_ENCODING, "processor/SparseEncodingPipelineConfiguration.json", ProcessorType.TEXT_IMAGE_EMBEDDING, - "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json" + "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json", + ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, + "processor/PipelineConfigurationWithNestedFieldsMapping.json" ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; @@ -1344,6 +1346,7 @@ protected Object validateDocCountAndInfo( */ protected enum ProcessorType { TEXT_EMBEDDING, + TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, TEXT_IMAGE_EMBEDDING, SPARSE_ENCODING }