From 25f57ad377cfea355a22eac7e59e6e2288139317 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 09:38:37 +0800 Subject: [PATCH] fix map type validation issue in processors (#687) * fix map type validation issue in processors Signed-off-by: zane-neo * fix test failures on main branch Signed-off-by: zane-neo * Fix potential NPE issue in chunking processor; add changee log Signed-off-by: zane-neo * Fix failure tests Signed-off-by: zane-neo * Address comments and add one more UT to cover uncovered line Signed-off-by: zane-neo * Address comments Signed-off-by: zane-neo * Add more UTs Signed-off-by: zane-neo * fix failure ITs Signed-off-by: zane-neo * Add public method with default depth parameter value Signed-off-by: zane-neo * rebase latest code Signed-off-by: zane-neo * address comments Signed-off-by: zane-neo * address comment Signed-off-by: zane-neo --------- Signed-off-by: zane-neo (cherry picked from commit 54ac672f6df792a8eb5746bfe835f474a7278227) --- CHANGELOG.md | 1 + .../neuralsearch/plugin/NeuralSearch.java | 4 +- .../processor/InferenceProcessor.java | 79 +++------ .../processor/SparseEncodingProcessor.java | 6 +- .../processor/TextChunkingProcessor.java | 72 ++------ .../processor/TextEmbeddingProcessor.java | 6 +- .../TextImageEmbeddingProcessor.java | 70 ++------ .../SparseEncodingProcessorFactory.java | 7 +- .../TextEmbeddingProcessorFactory.java | 12 +- .../util/ProcessorDocumentUtils.java | 167 ++++++++++++++++++ .../processor/InferenceProcessorTestCase.java | 2 + .../processor/InferenceProcessorTests.java | 6 +- .../SparseEncodingProcessorTests.java | 40 ++++- .../processor/TextChunkingProcessorTests.java | 46 ++--- .../TextEmbeddingProcessorTests.java | 122 +++++++++---- .../TextImageEmbeddingProcessorTests.java | 4 + .../util/ProcessorDocumentUtilsTests.java | 83 +++++++++ .../util/ProcessorDocumentUtils.json | 161 +++++++++++++++++ 18 files changed, 648 insertions(+), 240 deletions(-) create mode 100644 
src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java create mode 100644 src/test/resources/util/ProcessorDocumentUtils.json diff --git a/CHANGELOG.md b/CHANGELOG.md index a9cd52ed4..eabbf72f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Optimize max score calculation in the Query Phase of the Hybrid Search ([765](https://github.com/opensearch-project/neural-search/pull/765)) ### Bug Fixes - Total hit count fix in Hybrid Query ([756](https://github.com/opensearch-project/neural-search/pull/756)) +- Fix map type validation issue in multiple pipeline processors ([#661](https://github.com/opensearch-project/neural-search/pull/661)) ### Infrastructure - Disable memory circuit breaker for integ tests ([#770](https://github.com/opensearch-project/neural-search/pull/770)) ### Documentation diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index a8ce31e0d..f74352012 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -113,9 +113,9 @@ public Map getProcessors(Processor.Parameters paramet clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); return Map.of( TextEmbeddingProcessor.TYPE, - new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), + new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), SparseEncodingProcessor.TYPE, - new SparseEncodingProcessorFactory(clientAccessor, parameters.env), + new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), TextImageEmbeddingProcessor.TYPE, new 
TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), TextChunkingProcessor.TYPE, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 4956a445c..9465b250f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -15,7 +15,6 @@ import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -24,8 +23,9 @@ import org.apache.commons.lang3.StringUtils; import org.opensearch.common.collect.Tuple; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; -import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; @@ -35,6 +35,7 @@ import com.google.common.collect.ImmutableMap; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.util.ProcessorDocumentUtils; /** * The abstract class for text processing use cases. Users provide a field name map and a model id. 
@@ -60,6 +61,7 @@ public abstract class InferenceProcessor extends AbstractProcessor { protected final MLCommonsClientAccessor mlCommonsClientAccessor; private final Environment environment; + private final ClusterService clusterService; public InferenceProcessor( String tag, @@ -69,18 +71,19 @@ public InferenceProcessor( String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, - Environment environment + Environment environment, + ClusterService clusterService ) { super(tag, description); this.type = type; if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); validateEmbeddingConfiguration(fieldMap); - this.listTypeNestedMapKey = listTypeNestedMapKey; this.modelId = modelId; this.fieldMap = fieldMap; this.mlCommonsClientAccessor = clientAccessor; this.environment = environment; + this.clusterService = clusterService; } private void validateEmbeddingConfiguration(Map fieldMap) { @@ -117,12 +120,12 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { validateEmbeddingFieldsValue(ingestDocument); - Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(ProcessMap); + Map processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocument); + List inferenceList = createInferenceList(processMap); if (inferenceList.size() == 0) { handler.accept(ingestDocument, null); } else { - doExecute(ingestDocument, ProcessMap, inferenceList, handler); + doExecute(ingestDocument, processMap, inferenceList, handler); } } catch (Exception e) { handler.accept(null, e); @@ -225,7 +228,7 @@ private List getDataForInference(List i List inferenceList = null; try { validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument()); - processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument()); + processMap 
= buildMapWithTargetKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument()); inferenceList = createInferenceList(processMap); } catch (Exception e) { ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e); @@ -273,7 +276,7 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List } @VisibleForTesting - Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) { + Map buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); Map mapWithProcessorKeys = new LinkedHashMap<>(); for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { @@ -331,54 +334,16 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType( private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); - } - } - } - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - int maxDepth = maxDepthSupplier.get(); - if (maxDepth > 
MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); - } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier); - } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); - } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); - } - } - - @SuppressWarnings({ "rawtypes" }) - private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - for (Object value : (List) sourceValue) { - if (value instanceof Map) { - validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1); - } else if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); - } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); - } - } + String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); + ProcessorDocumentUtils.validateMapTypeValue( + FIELD_MAP_FIELD, + sourceAndMetadataMap, + fieldMap, + indexName, + clusterService, + environment, + false + ); } protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map 
processorMap, List results) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 9e2336cf6..e83bd8233 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -9,6 +9,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; @@ -33,9 +34,10 @@ public SparseEncodingProcessor( String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, - Environment environment + Environment environment, + ClusterService clusterService ) { - super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java index 4338139d9..49435746c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java @@ -17,7 +17,6 @@ import org.opensearch.env.Environment; import org.opensearch.cluster.service.ClusterService; import org.opensearch.index.analysis.AnalysisRegistry; -import org.opensearch.index.mapper.MapperService; import org.opensearch.index.IndexSettings; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; @@ -25,6 +24,7 @@ import org.opensearch.index.mapper.IndexFieldMapper; import 
org.opensearch.neuralsearch.processor.chunker.ChunkerFactory; import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker; +import org.opensearch.neuralsearch.util.ProcessorDocumentUtils; import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD; import static org.opensearch.neuralsearch.processor.chunker.Chunker.DEFAULT_MAX_CHUNK_LIMIT; @@ -164,7 +164,16 @@ private int getMaxTokenCount(final Map sourceAndMetadataMap) { @Override public IngestDocument execute(final IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - validateFieldsValue(sourceAndMetadataMap); + String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); + ProcessorDocumentUtils.validateMapTypeValue( + FIELD_MAP_FIELD, + sourceAndMetadataMap, + fieldMap, + indexName, + clusterService, + environment, + true + ); // fixed token length algorithm needs runtime parameter max_token_count for tokenization Map runtimeParameters = new HashMap<>(); int maxTokenCount = getMaxTokenCount(sourceAndMetadataMap); @@ -176,59 +185,6 @@ public IngestDocument execute(final IngestDocument ingestDocument) { return ingestDocument; } - private void validateFieldsValue(final Map sourceAndMetadataMap) { - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (Objects.nonNull(sourceValue)) { - String sourceKey = embeddingFieldsEntry.getKey(); - if (sourceValue instanceof List || sourceValue instanceof Map) { - validateNestedTypeValue(sourceKey, sourceValue, 1); - } else if (!(sourceValue instanceof String)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "field [%s] is neither string nor nested type, cannot process it", sourceKey) - ); - } - } - } - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final int 
maxDepth) { - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "map type field [%s] reached max depth limit, cannot process it", sourceKey) - ); - } else if (sourceValue instanceof List) { - validateListTypeValue(sourceKey, sourceValue, maxDepth); - } else if (sourceValue instanceof Map) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, maxDepth + 1)); - } else if (!(sourceValue instanceof String)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "map type field [%s] has non-string type, cannot process it", sourceKey) - ); - } - } - - @SuppressWarnings({ "rawtypes" }) - private void validateListTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) { - for (Object value : (List) sourceValue) { - if (value instanceof Map) { - validateNestedTypeValue(sourceKey, value, maxDepth + 1); - } else if (value == null) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "list type field [%s] has null, cannot process it", sourceKey) - ); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "list type field [%s] has non-string value, cannot process it", sourceKey) - ); - } - } - } - @SuppressWarnings("unchecked") private int getChunkStringCountFromMap(Map sourceAndMetadataMap, final Map fieldMap) { int chunkStringCount = 0; @@ -334,7 +290,13 @@ private List chunkLeafType(final Object value, final Map // leaf type means null, String or List // the result should be an empty list when the input is null List result = new ArrayList<>(); + if (value == null) { + return result; + } if (value instanceof String) { + if (StringUtils.isBlank(String.valueOf(value))) { + return result; + } result = chunkString(value.toString(), runTimeParameters); } else if (isListOfString(value)) { 
result = chunkList((List) value, runTimeParameters); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 7e765624e..f5b710530 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -9,6 +9,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; @@ -32,9 +33,10 @@ public TextEmbeddingProcessor( String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, - Environment environment + Environment environment, + ClusterService clusterService ) { - super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index 09fcf3d97..e808869f9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -12,15 +12,12 @@ import java.util.Objects; import java.util.Set; import java.util.function.BiConsumer; -import java.util.function.Supplier; import org.apache.commons.lang3.StringUtils; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.index.mapper.IndexFieldMapper; 
-import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -28,6 +25,7 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.util.ProcessorDocumentUtils; /** * This processor is used for user input data text and image embedding processing, model_id can be used to indicate which model user use, @@ -51,6 +49,7 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor { private final Map fieldMap; private final MLCommonsClientAccessor mlCommonsClientAccessor; + private final Environment environment; private final ClusterService clusterService; @@ -173,61 +172,16 @@ Map buildTextEmbeddingResult(final String knnKey, List mo private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - String mappedSourceKey = embeddingFieldsEntry.getValue(); - Object sourceValue = sourceAndMetadataMap.get(mappedSourceKey); - if (Objects.isNull(sourceValue)) { - continue; - } - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); - validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1, indexName); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + mappedSourceKey + "] has empty string value, can not process it"); - } - - } - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) 
- private void validateNestedTypeValue( - final String sourceKey, - final Object sourceValue, - final Supplier maxDepthSupplier, - final String indexName - ) { - int maxDepth = maxDepthSupplier.get(); - Settings indexSettings = clusterService.state().metadata().index(indexName).getSettings(); - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings)) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); - } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, (List) sourceValue); - } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1, indexName)); - } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); - } - } - - @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(final String sourceKey, final List sourceValue) { - for (Object value : sourceValue) { - if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); - } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); - } - } + String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); + ProcessorDocumentUtils.validateMapTypeValue( + 
FIELD_MAP_FIELD, + sourceAndMetadataMap, + fieldMap, + indexName, + clusterService, + environment, + false + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 95b2803a0..8a294458a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -12,6 +12,7 @@ import java.util.Map; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -26,10 +27,12 @@ public class SparseEncodingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; private final Environment environment; + private final ClusterService clusterService; - public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) { this.clientAccessor = clientAccessor; this.environment = environment; + this.clusterService = clusterService; } @Override @@ -42,6 +45,6 @@ public SparseEncodingProcessor create( String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); - return new SparseEncodingProcessor(processorTag, description, modelId, fieldMap, clientAccessor, environment); + return new SparseEncodingProcessor(processorTag, description, modelId, fieldMap, clientAccessor, environment, clusterService); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java 
b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index 7802cb1f6..d38bf21df 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -12,6 +12,7 @@ import java.util.Map; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -26,9 +27,16 @@ public class TextEmbeddingProcessorFactory implements Processor.Factory { private final Environment environment; - public TextEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) { + private final ClusterService clusterService; + + public TextEmbeddingProcessorFactory( + final MLCommonsClientAccessor clientAccessor, + final Environment environment, + final ClusterService clusterService + ) { this.clientAccessor = clientAccessor; this.environment = environment; + this.clusterService = clusterService; } @Override @@ -40,6 +48,6 @@ public TextEmbeddingProcessor create( ) throws Exception { String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); - return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment); + return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment, clusterService); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java b/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java new file mode 100644 index 000000000..b209dbed7 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * This class is used to accommodate the common code pieces of parsing, validating and processing the document for multiple + * pipeline processors. + */ +public class ProcessorDocumentUtils { + + /** + * Validates a map type value recursively up to a specified depth. Supports Map type, List type and String type. + * If current sourceValue is Map or List type, recursively validates its values, otherwise validates its value. + * + * @param sourceKey the key of the source map being validated, the first level is always the "field_map" key. + * @param sourceValue the source map being validated, the first level is always the sourceAndMetadataMap. + * @param fieldMap the configuration map for validation, the first level is always the value of "field_map" in the processor configuration. + * @param clusterService cluster service passed from OpenSearch core. + * @param environment environment passed from OpenSearch core. + * @param indexName the maximum allowed depth for recursion + * @param allowEmpty flag to allow empty values in map type validation. 
+ */ + public static void validateMapTypeValue( + final String sourceKey, + final Map sourceValue, + final Object fieldMap, + final String indexName, + final ClusterService clusterService, + final Environment environment, + final boolean allowEmpty + ) { + validateMapTypeValue(sourceKey, sourceValue, fieldMap, 1, indexName, clusterService, environment, allowEmpty); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static void validateMapTypeValue( + final String sourceKey, + final Map sourceValue, + final Object fieldMap, + final long depth, + final String indexName, + final ClusterService clusterService, + final Environment environment, + final boolean allowEmpty + ) { + if (Objects.isNull(sourceValue)) { // allow map type value to be null. + return; + } + validateDepth(sourceKey, depth, indexName, clusterService, environment); + if (!(fieldMap instanceof Map)) { // source value is map type means configuration has to be map type + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] configuration doesn't match actual value type, configuration type is: %s, actual value type is: %s", + sourceKey, + fieldMap.getClass().getName(), + sourceValue.getClass().getName() + ) + ); + } + // next level validation, only validate the keys in configuration. 
+ ((Map) fieldMap).forEach((key, nextFieldMap) -> { + Object nextSourceValue = sourceValue.get(key); + if (nextSourceValue != null) { + if (nextSourceValue instanceof List) { + validateListTypeValue( + key, + (List) nextSourceValue, + fieldMap, + depth + 1, + indexName, + clusterService, + environment, + allowEmpty + ); + } else if (nextSourceValue instanceof Map) { + validateMapTypeValue( + key, + (Map) nextSourceValue, + nextFieldMap, + depth + 1, + indexName, + clusterService, + environment, + allowEmpty + ); + } else if (!(nextSourceValue instanceof String)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "map type field [%s] is neither string nor nested type, cannot process it", key)); + } else if (!allowEmpty && StringUtils.isBlank((String) nextSourceValue)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "map type field [%s] has empty string value, cannot process it", key)); + } + } + }); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static void validateListTypeValue( + final String sourceKey, + final List sourceValue, + final Object fieldMap, + final long depth, + final String indexName, + final ClusterService clusterService, + final Environment environment, + final boolean allowEmpty + ) { + validateDepth(sourceKey, depth, indexName, clusterService, environment); + if (CollectionUtils.isEmpty(sourceValue)) { + return; + } + for (Object element : sourceValue) { + if (Objects.isNull(element)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "list type field [%s] has null, cannot process it", sourceKey)); + } + if (element instanceof List) { // nested list case. 
+ throw new IllegalArgumentException(String.format(Locale.ROOT, "list type field [%s] is nested list type, cannot process it", sourceKey)); + } else if (element instanceof Map) { + validateMapTypeValue( + sourceKey, + (Map) element, + ((Map) fieldMap).get(sourceKey), + depth + 1, + indexName, + clusterService, + environment, + allowEmpty + ); + } else if (!(element instanceof String)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "list type field [%s] has non string value, cannot process it", sourceKey)); + } else if (!allowEmpty && StringUtils.isBlank(element.toString())) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "list type field [%s] has empty string, cannot process it", sourceKey)); + } + } + } + + private static void validateDepth( + String sourceKey, + long depth, + String indexName, + ClusterService clusterService, + Environment environment + ) { + Settings settings = Optional.ofNullable(clusterService.state().metadata().index(indexName)) + .map(IndexMetadata::getSettings) + .orElse(environment.settings()); + long maxDepth = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(settings); + if (depth > maxDepth) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "map type field [%s] reaches max depth limit, cannot process it", sourceKey)); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java index 05a327b82..866a2ab29 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor; import com.google.common.collect.ImmutableList; +import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import 
org.opensearch.test.OpenSearchTestCase; @@ -21,6 +22,7 @@ protected List createIngestDocumentWrappers(int count) { for (int i = 0; i < count; ++i) { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); wrapperList.add(new IngestDocumentWrapper(i, new IngestDocument(sourceAndMetadata, new HashMap<>()), null)); } return wrapperList; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index 43c2ba1fb..d08f6c3f1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -7,6 +7,7 @@ import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; @@ -24,6 +25,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -33,6 +35,8 @@ public class InferenceProcessorTests extends InferenceProcessorTestCase { private MLCommonsClientAccessor clientAccessor; private Environment environment; + private ClusterService clusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS); + private static final String TAG = "tag"; private static final String TYPE = "type"; private static final String DESCRIPTION = "description"; @@ -175,7 +179,7 @@ private class TestInferenceProcessor extends InferenceProcessor { Exception exception; public 
TestInferenceProcessor(List vectors, Exception exception) { - super(TAG, DESCRIPTION, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment); + super(TAG, DESCRIPTION, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment, clusterService); this.vectors = vectors; this.exception = exception; } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 5b85ec923..7460390de 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -32,9 +33,11 @@ import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; +import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; @@ -51,10 +54,12 @@ public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { private MLCommonsClientAccessor mlCommonsClientAccessor; @Mock - private Environment env; + private Environment environment; + + private ClusterService clusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS); @InjectMocks - private SparseEncodingProcessorFactory SparseEncodingProcessorFactory; + private SparseEncodingProcessorFactory sparseEncodingProcessorFactory; private 
static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; @@ -62,7 +67,7 @@ public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { public void setup() { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); - when(env.settings()).thenReturn(settings); + when(clusterService.state().metadata().index(anyString()).getSettings()).thenReturn(settings); } @SneakyThrows @@ -71,11 +76,12 @@ private SparseEncodingProcessor createInstance() { Map config = new HashMap<>(); config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); - return SparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -96,10 +102,15 @@ public void testExecute_successful() { @SneakyThrows public void testExecute_whenInferenceTextListEmpty_SuccessWithoutAnyMap() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); - SparseEncodingProcessorFactory sparseEncodingProcessorFactory = new SparseEncodingProcessorFactory(accessor, env); + SparseEncodingProcessorFactory sparseEncodingProcessorFactory = new SparseEncodingProcessorFactory( + accessor, + environment, + clusterService + ); 
Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -115,6 +126,7 @@ public void testExecute_withListTypeInput_successful() { List list1 = ImmutableList.of("test1", "test2", "test3"); List list2 = ImmutableList.of("test4", "test5", "test6"); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", list1); sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -134,6 +146,7 @@ public void testExecute_withListTypeInput_successful() { public void testExecute_MLClientAccessorThrowFail_handlerFailure() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -150,14 +163,25 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } + @SneakyThrows public void testExecute_withMapTypeInput_successful() { - Map map1 = ImmutableMap.of("test1", "test2"); - Map map2 = ImmutableMap.of("test4", "test5"); + Map map1 = new HashMap<>(); + map1.put("test1", "test2"); + Map map2 = new HashMap<>(); + map2.put("test4", "test5"); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - SparseEncodingProcessor processor = createInstance(); + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put( + SparseEncodingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of("key1", Map.of("test1", "test1_knn"), "key2", 
Map.of("test4", "test4_knn")) + ); + SparseEncodingProcessor processor = sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorTests.java index efc9745ed..433e51ef5 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorTests.java @@ -13,7 +13,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import java.util.Set; import static java.util.Collections.singletonList; @@ -80,7 +79,11 @@ public Map> getTokeniz public void setup() { Metadata metadata = mock(Metadata.class); Environment environment = mock(Environment.class); - Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + Settings settings = Settings.builder() + .put("index.mapping.depth.limit", 20) + .put("index.analyze.max_token_count", 10000) + .put("index.number_of_shards", 1) + .build(); when(environment.settings()).thenReturn(settings); ClusterState clusterState = mock(ClusterState.class); ClusterService clusterService = mock(ClusterService.class); @@ -357,13 +360,11 @@ private Map createSourceDataInvalidNestedMap() { private Map createMaxDepthLimitExceedMap(int maxDepth) { if (maxDepth > 21) { - return null; + return Map.of(INPUT_FIELD, "mapped"); } Map resultMap = new HashMap<>(); Map innerMap = createMaxDepthLimitExceedMap(maxDepth + 1); - if (Objects.nonNull(innerMap)) { - resultMap.put(INPUT_FIELD, innerMap); - } + resultMap.put(INPUT_FIELD, innerMap); return resultMap; } @@ -597,7 +598,7 @@ public void testExecute_withFixedTokenLength_andSourceDataInvalidType_thenFail() () -> processor.execute(ingestDocument) ); assertEquals( - 
String.format(Locale.ROOT, "field [%s] is neither string nor nested type, cannot process it", INPUT_FIELD), + String.format(Locale.ROOT, "map type field [%s] is neither string nor nested type, cannot process it", INPUT_FIELD), illegalArgumentException.getMessage() ); } @@ -630,7 +631,7 @@ public void testExecute_withFixedTokenLength_andSourceDataListWithInvalidType_th () -> processor.execute(ingestDocument) ); assertEquals( - String.format(Locale.ROOT, "list type field [%s] has non-string value, cannot process it", INPUT_FIELD), + String.format(Locale.ROOT, "list type field [%s] has non string value, cannot process it", INPUT_FIELD), illegalArgumentException.getMessage() ); } @@ -855,16 +856,16 @@ public void testExecute_withFixedTokenLength_andFieldMapNestedMapMultipleField_e @SneakyThrows public void testExecute_withFixedTokenLength_andMaxDepthLimitExceedFieldMap_thenFail() { - TextChunkingProcessor processor = createFixedTokenLengthInstance(createNestedFieldMapSingleField()); - IngestDocument ingestDocument = createIngestDocumentWithNestedSourceData(createMaxDepthLimitExceedMap(0)); + Map map = createMaxDepthLimitExceedMap(0); + Map config = new HashMap<>(); + config.put(INPUT_NESTED_FIELD_KEY, map.get("body")); + TextChunkingProcessor processor = createFixedTokenLengthInstance(config); + IngestDocument ingestDocument = createIngestDocumentWithNestedSourceData(map); IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, () -> processor.execute(ingestDocument) ); - assertEquals( - String.format(Locale.ROOT, "map type field [%s] reached max depth limit, cannot process it", INPUT_NESTED_FIELD_KEY), - illegalArgumentException.getMessage() - ); + assertEquals("map type field [body] reaches max depth limit, cannot process it", illegalArgumentException.getMessage()); } @SneakyThrows @@ -876,7 +877,7 @@ public void testExecute_withFixedTokenLength_andFieldMapNestedMapSingleField_the () -> processor.execute(ingestDocument) ); 
assertEquals( - String.format(Locale.ROOT, "map type field [%s] has non-string type, cannot process it", INPUT_NESTED_FIELD_KEY), + "[body] configuration doesn't match actual value type, configuration type is: java.lang.String, actual value type is: java.util.ImmutableCollections$Map1", illegalArgumentException.getMessage() ); } @@ -906,15 +907,18 @@ public void testExecute_withFixedTokenLength_andFieldMapNestedMapSingleField_sou } @SneakyThrows - public void testExecute_withFixedTokenLength_andSourceDataListWithHybridType_thenSucceed() { + public void testExecute_withFixedTokenLength_andSourceDataListWithHybridType_thenFail() { TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); List sourceDataList = createSourceDataListWithHybridType(); IngestDocument ingestDocument = createIngestDocumentWithSourceData(sourceDataList); - IngestDocument document = processor.execute(ingestDocument); - assert document.getSourceAndMetadata().containsKey(INPUT_FIELD); - Object listResult = document.getSourceAndMetadata().get(OUTPUT_FIELD); - assert (listResult instanceof List); - assertEquals(((List) listResult).size(), 0); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> processor.execute(ingestDocument) + ); + assertEquals( + "[body] configuration doesn't match actual value type, configuration type is: java.lang.String, actual value type is: com.google.common.collect.RegularImmutableMap", + illegalArgumentException.getMessage() + ); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 752615057..bff578ad7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -6,6 +6,7 @@ import static 
org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -31,9 +32,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; +import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; @@ -51,7 +54,9 @@ public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase { private MLCommonsClientAccessor mlCommonsClientAccessor; @Mock - private Environment env; + private Environment environment; + + private ClusterService clusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS); @InjectMocks private TextEmbeddingProcessorFactory textEmbeddingProcessorFactory; @@ -62,15 +67,27 @@ public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase { public void setup() { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); - when(env.settings()).thenReturn(settings); + when(clusterService.state().metadata().index(anyString()).getSettings()).thenReturn(settings); } @SneakyThrows - private TextEmbeddingProcessor createInstance() { + private TextEmbeddingProcessor createInstanceWithLevel2MapConfig() { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + config.put( + 
TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of("key1", ImmutableMap.of("test1", "test1_knn"), "key2", ImmutableMap.of("test3", "test3_knn")) + ); + return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextEmbeddingProcessor createInstanceWithLevel1MapConfig() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1_knn", "key2", "key2_knn")); return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @@ -104,10 +121,11 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -129,7 +147,11 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); - TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory(accessor, env); + TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory( + accessor, + environment, + clusterService + ); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -144,10 +166,15 @@ public void 
testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep @SneakyThrows public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); - TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory(accessor, env); + TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory( + accessor, + environment, + clusterService + ); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -163,10 +190,11 @@ public void testExecute_withListTypeInput_successful() { List list1 = ImmutableList.of("test1", "test2", "test3"); List list2 = ImmutableList.of("test4", "test5", "test6"); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", list1); sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -182,9 +210,10 @@ public void testExecute_withListTypeInput_successful() { public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", " "); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); 
BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -194,9 +223,10 @@ public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentExcep public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() { List list1 = ImmutableList.of("", "test2", "test3"); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", list1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -206,9 +236,10 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() public void testExecute_listHasNonStringValue_throwIllegalArgumentException() { List list2 = ImmutableList.of(1, 2, 3); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -220,22 +251,26 @@ public void testExecute_listHasNull_throwIllegalArgumentException() { list.add(null); list.add("world"); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key2", list); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); 
BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_withMapTypeInput_successful() { - Map map1 = ImmutableMap.of("test1", "test2"); - Map map2 = ImmutableMap.of("test4", "test5"); + Map map1 = new HashMap<>(); + map1.put("test1", "test2"); + Map map2 = new HashMap<>(); + map2.put("test3", "test4"); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -254,10 +289,11 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { Map map1 = ImmutableMap.of("test1", "test2"); Map map2 = ImmutableMap.of("test3", 209.3D); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -267,10 +303,11 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { Map map1 = ImmutableMap.of("test1", "test2"); Map map2 = ImmutableMap.of("test3", " "); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", map1); 
sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -279,10 +316,11 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { Map ret = createMaxDepthLimitExceedMap(() -> 1); Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", "hello world"); sourceAndMetadata.put("key2", ret); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -290,10 +328,11 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { public void testExecute_MLClientAccessorThrowFail_handlerFailure() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(2); @@ -324,13 +363,14 @@ public void testExecute_hybridTypeInput_successful() throws Exception { Map sourceAndMetadata = new HashMap<>(); 
sourceAndMetadata.put("key2", map1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); IngestDocument document = processor.execute(ingestDocument); assert document.getSourceAndMetadata().containsKey("key2"); } public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", 100); sourceAndMetadata.put("key2", 100.232D); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -341,13 +381,13 @@ public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testGetType_successful() { - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); assert processor.getType().equals(TextEmbeddingProcessor.TYPE); } @@ -356,7 +396,7 @@ public void testProcessResponse_successful() throws Exception { IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); @@ -369,7 
+409,7 @@ public void testBuildVectorOutput_withPlainStringValue_successful() { IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); // To assert the order is not changed between config map and generated map. List configValueList = new LinkedList<>(config.values()); @@ -395,7 +435,7 @@ public void testBuildVectorOutput_withNestedMap_successful() { Map config = createNestedMapConfiguration(); IngestDocument ingestDocument = createNestedMapIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); @@ -411,7 +451,7 @@ public void testBuildVectorOutput_withNestedList_successful() { Map config = createNestedListConfiguration(); IngestDocument ingestDocument = createNestedListIngestDocument(); TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = textEmbeddingProcessor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); List> nestedObj = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); @@ -425,7 +465,7 @@ public void testBuildVectorOutput_withNestedList_Level2_successful() { Map 
config = createNestedList2LevelConfiguration(); IngestDocument ingestDocument = create2LevelNestedListIngestDocument(); TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = textEmbeddingProcessor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + Map knnMap = textEmbeddingProcessor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); Map nestedLevel1 = (Map) ingestDocument.getSourceAndMetadata().get("nestedField"); @@ -440,7 +480,7 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithTargetKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); @@ -450,10 +490,32 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() { assertEquals(2, ((List) ingestDocument.getSourceAndMetadata().get("oriKey6_knn")).size()); } + public void test_doublyNestedList_withMapType_successful() { + Map config = createNestedListConfiguration(); + + Map toEmbeddings = new HashMap<>(); + toEmbeddings.put("textField", "text to embedding"); + List> l1List = new ArrayList<>(); + l1List.add(toEmbeddings); + List>> l2List = new ArrayList<>(); + l2List.add(l1List); + Map document = new HashMap<>(); + document.put("nestedField", l2List); + document.put(IndexFieldMapper.NAME, "my_index"); + + IngestDocument ingestDocument = new IngestDocument(document, new HashMap<>()); + TextEmbeddingProcessor processor = 
createInstanceWithNestedMapConfiguration(config); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(handler).accept(isNull(), argumentCaptor.capture()); + assertEquals("list type field [nestedField] is nested list type, cannot process it", argumentCaptor.getValue().getMessage()); + } + public void test_batchExecute_successful() { final int docCount = 5; List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); List> modelTensorList = createMockVectorWithLength(10); doAnswer(invocation -> { @@ -476,7 +538,7 @@ public void test_batchExecute_successful() { public void test_batchExecute_exception() { final int docCount = 5; List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); - TextEmbeddingProcessor processor = createInstance(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(2); listener.onFailure(new RuntimeException()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index 431d6a440..89a42df80 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -182,6 +182,7 @@ public void testTextEmbeddingProcessConstructor_whenEmptyModelId_throwIllegalArg public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("my_text_field", "value2"); 
sourceAndMetadata.put("key3", "value3"); @@ -231,6 +232,7 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep public void testExecute_withListTypeInput_successful() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("my_text_field", "value1"); sourceAndMetadata.put("another_text_field", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -263,6 +265,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { public void testExecute_MLClientAccessorThrowFail_handlerFailure() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("my_text_field", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -320,6 +323,7 @@ public void testExecute_hybridTypeInput_successful() throws Exception { public void testExecute_whenInferencesAreEmpty_thenSuccessful() { Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("my_field", "value1"); sourceAndMetadata.put("another_text_field", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); diff --git a/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java new file mode 100644 index 000000000..068edcf2f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import 
org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ProcessorDocumentUtilsTests extends OpenSearchQueryTestCase { + + private ClusterService clusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS); + + @Mock + private Environment environment; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + } + + public void test_with_different_configurations() throws URISyntaxException, IOException { + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(clusterService.state().metadata().index(anyString()).getSettings()).thenReturn(settings); + String processorDocumentTestJson = Files.readString( + Path.of(ProcessorDocumentUtils.class.getClassLoader().getResource("util/ProcessorDocumentUtils.json").toURI()) + ); + Map processorDocumentTestMap = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + processorDocumentTestJson, + false + ); + for (Map.Entry entry : processorDocumentTestMap.entrySet()) { + String testCaseName = entry.getKey(); + Map metadata = (Map) entry.getValue(); + + Map fieldMap = (Map) metadata.get("field_map"); + Map source = (Map) metadata.get("source"); + Map expectation = (Map) metadata.get("expectation"); + try { + ProcessorDocumentUtils.validateMapTypeValue( + "field_map", + source, + fieldMap, + "test_index", + clusterService, + environment, + false + ); + } catch 
(Exception e) { + if (expectation != null) { + if (expectation.containsKey("type")) { + assertEquals("test case: " + testCaseName + " failed", expectation.get("type"), e.getClass().getSimpleName()); + } + if (expectation.containsKey("message")) { + assertEquals("test case: " + testCaseName + " failed", expectation.get("message"), e.getMessage()); + } + } else { + fail("test case: " + testCaseName + " failed: " + e.getMessage()); + } + } + } + } + +} diff --git a/src/test/resources/util/ProcessorDocumentUtils.json b/src/test/resources/util/ProcessorDocumentUtils.json new file mode 100644 index 000000000..e69451a4e --- /dev/null +++ b/src/test/resources/util/ProcessorDocumentUtils.json @@ -0,0 +1,161 @@ +{ + "simpleMapConfiguration": { + "field_map": { + "body": "body_embedding" + }, + "source": { + "body": "This is a test body." + } + }, + "doublyMapConfiguration": { + "field_map": { + "passage": { + "body": "body_embedding" + } + }, + "source": { + "passage": { + "body": "This is a test body." + } + } + }, + "mapWithNestedConfiguration": { + "field_map": { + "passage": { + "bodies": "bodies_embedding" + } + }, + "source": { + "passage": { + "bodies": ["test body 1", "test body 2", "test body 3"] + } + } + }, + "nestedConfiguration": { + "field_map": { + "bodies": "bodies_embedding" + }, + "source": { + "bodies": ["test body 1", "test body 2", "test body 3"] + } + }, + "nestedWithMapConfiguration": { + "field_map": { + "bodies": { + "body": "body_embedding" + } + }, + "source": { + "bodies": [ + { + "body": "This is a test body.", + "seq": 1 + }, + { + "body": "This is another test body.", + "seq": 2 + }] + } + }, + "sourceMapFieldNotMapConfiguration": { + "field_map": { + "passage": "passage_embedding" + }, + "source": { + "passage": { + "body": "This is a test body." 
+ } + }, + "expectation": { + "type": "IllegalArgumentException", + "message": "[passage] configuration doesn't match actual value type, configuration type is: java.lang.String, actual value type is: java.util.HashMap" + } + }, + "sourceMapTypeHasNonNestedNonStringConfiguration": { + "field_map": { + "passage": { + "body": "body_embedding" + } + }, + "source": { + "passage": { + "body": 12345 + } + }, + "expectation": { + "type": "IllegalArgumentException", + "message": "map type field [body] is neither string nor nested type, cannot process it" + } + }, + "sourceMapTypeHasEmptyStringConfiguration": { + "field_map": { + "passage": { + "body": "body_embedding" + } + }, + "source": { + "passage": { + "body": "" + } + }, + "expectation": { + "type": "IllegalArgumentException", + "message": "map type field [body] has empty string value, cannot process it" + } + }, + "sourceListTypeHasNullConfiguration": { + "field_map": { + "bodies": "bodies_embedding" + }, + "source": { + "bodies": ["This is a test", null, "This is another test"] + }, + "expectation": { + "type": "IllegalArgumentException", + "message": "list type field [bodies] has null, cannot process it" + } + }, + "sourceListTypeHasEmptyConfiguration": { + "field_map": { + "bodies": "bodies_embedding" + }, + "source": { + "bodies": ["This is a test", "", "This is another test"] + }, + "expectation": { + "type": "IllegalArgumentException", + "message": "list type field [bodies] has empty string, cannot process it" + } + }, + "sourceListTypeHasNonStringConfiguration": { + "field_map": { + "bodies": "bodies_embedding" + }, + "source": { + "bodies": ["This is a test", 1, "This is another test"] + }, + "expectation": { + "type": "IllegalArgumentException", + "message": "list type field [bodies] has non string value, cannot process it" + } + }, + "sourceDoublyListTypeConfiguration": { + "field_map": { + "bodies": "bodies_embedding" + }, + "source": { + "bodies": [ + [ + "This is a test" + ], + [ + "This is another 
test" + ] + ] + }, + "expectation": { + "type": "IllegalArgumentException", + "message": "list type field [bodies] is nested list type, cannot process it" + } + } +}