diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 5bb54cc2c..1c09f5996 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -226,7 +226,7 @@ private void retryableInferenceSentencesWithSingleVectorResult( MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List vector = buildSingleVectorFromResponse(mlOutput); - log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector); + log.debug("Inference Response for input sentence is : {} ", vector); listener.onResponse(vector); }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index cf1a2f9bd..8672c6142 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -110,7 +110,7 @@ public Map getProcessors(Processor.Parameters paramet SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env), TextImageEmbeddingProcessor.TYPE, - new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env) + new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index 70ddc0d60..a0d9606e9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -18,8 +18,11 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.StringUtils; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; +import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; @@ -50,6 +53,7 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor { private final MLCommonsClientAccessor mlCommonsClientAccessor; private final Environment environment; + private final ClusterService clusterService; public TextImageEmbeddingProcessor( final String tag, @@ -58,7 +62,8 @@ public TextImageEmbeddingProcessor( final String embedding, final Map fieldMap, final MLCommonsClientAccessor clientAccessor, - final Environment environment + final Environment environment, + final ClusterService clusterService ) { super(tag, description); if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); @@ -69,6 +74,7 @@ public TextImageEmbeddingProcessor( this.fieldMap = fieldMap; this.mlCommonsClientAccessor = clientAccessor; this.environment = environment; + this.clusterService = clusterService; } private void validateEmbeddingConfiguration(final Map fieldMap) { @@ -176,7 +182,8 @@ private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) { } Class sourceValueClass = sourceValue.getClass(); if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1); + String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); + validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1, indexName); } else if (!String.class.isAssignableFrom(sourceValueClass)) { throw new IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { @@ -187,9 +194,15 @@ private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) { } @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final Supplier maxDepthSupplier) { + private void validateNestedTypeValue( + final String sourceKey, + final Object sourceValue, + final Supplier maxDepthSupplier, + final String indexName + ) { int maxDepth = maxDepthSupplier.get(); - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + Settings indexSettings = clusterService.state().metadata().index(indexName).getSettings(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings)) { throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { validateListTypeValue(sourceKey, (List) sourceValue); @@ -197,7 +210,7 @@ private void validateNestedTypeValue(final String sourceKey, final Object source ((Map) sourceValue).values() .stream() .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1, indexName)); } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java index df13c523b..c18ec6fb3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java @@ -15,6 +15,9 @@ import java.util.Map; +import lombok.AllArgsConstructor; + +import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; @@ -22,16 +25,12 @@ /** * Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. */ +@AllArgsConstructor public class TextImageEmbeddingProcessorFactory implements Factory { private final MLCommonsClientAccessor clientAccessor; - private final Environment environment; - - public TextImageEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) { - this.clientAccessor = clientAccessor; - this.environment = environment; - } + private final ClusterService clusterService; @Override public TextImageEmbeddingProcessor create( @@ -43,6 +42,15 @@ public TextImageEmbeddingProcessor create( String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD); Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); - return new TextImageEmbeddingProcessor(processorTag, description, modelId, embedding, filedMap, clientAccessor, environment); + return new TextImageEmbeddingProcessor( + processorTag, + description, + modelId, + embedding, + filedMap, + clientAccessor, + environment, + clusterService + ); } } diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 8cae15678..69791681e 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -11,6 +11,7 @@ import java.util.Map; import java.util.Optional; +import org.opensearch.ingest.IngestService; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; @@ -56,7 +57,17 @@ public void testQueryPhaseSearcher() { public void testProcessors() { NeuralSearch plugin = new NeuralSearch(); - Processor.Parameters processorParams = mock(Processor.Parameters.class); + Processor.Parameters processorParams = new Processor.Parameters( + null, + null, + null, + null, + null, + null, + mock(IngestService.class), + null, + null + ); Map processors = plugin.getProcessors(processorParams); assertNotNull(processors); assertNotNull(processors.get(TextEmbeddingProcessor.TYPE)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index c0cab4422..bae336d4a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -32,9 +32,14 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; +import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -48,9 +53,16 @@ public class TextImageEmbeddingProcessorTests extends OpenSearchTestCase { @Mock private MLCommonsClientAccessor mlCommonsClientAccessor; - @Mock private Environment env; + @Mock + private ClusterService clusterService; + @Mock + private ClusterState clusterState; + @Mock + private Metadata metadata; + @Mock + private IndexMetadata indexMetadata; @InjectMocks private TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory; @@ -62,6 +74,10 @@ public void setup() { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); when(env.settings()).thenReturn(settings); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.index(anyString())).thenReturn(indexMetadata); + when(indexMetadata.getSettings()).thenReturn(settings); } @SneakyThrows @@ -98,7 +114,16 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t // create with null type mapping IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, - () -> new TextImageEmbeddingProcessor(PROCESSOR_TAG, DESCRIPTION, modelId, embeddingField, null, mlCommonsClientAccessor, env) + () -> new TextImageEmbeddingProcessor( + PROCESSOR_TAG, + DESCRIPTION, + modelId, + embeddingField, + null, + mlCommonsClientAccessor, + env, + clusterService + ) ); assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); @@ -112,7 +137,8 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t embeddingField, Map.of("", "my_field"), mlCommonsClientAccessor, - env + env, + clusterService ) ); assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); @@ -131,7 +157,8 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t embeddingField, typeMapping, mlCommonsClientAccessor, - env + env, + clusterService ) ); assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); @@ -183,7 +210,11 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); - TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(accessor, env); + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + accessor, + env, + clusterService + ); Map config = new HashMap<>(); config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -223,6 +254,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", "hello world"); sourceAndMetadata.put("my_text_field", ret); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextImageEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); @@ -254,6 +286,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("my_text_field", map2); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextImageEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); @@ -267,6 +300,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("my_text_field", map2); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextImageEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java index 39d1f14de..cbf53b8fc 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java @@ -19,6 +19,7 @@ import lombok.SneakyThrows; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; @@ -30,7 +31,8 @@ public class TextImageEmbeddingProcessorFactoryTests extends OpenSearchTestCase public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( mock(MLCommonsClientAccessor.class), - mock(Environment.class) + mock(Environment.class), + mock(ClusterService.class) ); final Map processorFactories = new HashMap<>(); @@ -55,7 +57,8 @@ public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( mock(MLCommonsClientAccessor.class), - mock(Environment.class) + mock(Environment.class), + mock(ClusterService.class) ); final Map processorFactories = new HashMap<>(); @@ -88,7 +91,8 @@ public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() { TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( mock(MLCommonsClientAccessor.class), - mock(Environment.class) + mock(Environment.class), + mock(ClusterService.class) ); final Map processorFactories = new HashMap<>();