diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fe4b36ab..ec655eae4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x) ### Features +- Support batchExecute in TextEmbeddingProcessor and SparseEncodingProcessor ([#743](https://github.com/opensearch-project/neural-search/issues/743)) ### Enhancements - Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731)) - Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733)) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index fe201abae..4956a445c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -5,20 +5,30 @@ package org.opensearch.neuralsearch.processor; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; +import lombok.AllArgsConstructor; +import lombok.Getter; import org.apache.commons.lang3.StringUtils; +import org.opensearch.common.collect.Tuple; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.env.Environment; import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import com.google.common.annotations.VisibleForTesting; @@ -119,6 +129,121 @@ public void execute(IngestDocument ingestDocument, BiConsumer inferenceList, Consumer> handler, Consumer onException); + + @Override + public void batchExecute(List ingestDocumentWrappers, Consumer> handler) { + if (CollectionUtils.isEmpty(ingestDocumentWrappers)) { + handler.accept(Collections.emptyList()); + return; + } + + List dataForInferences = getDataForInference(ingestDocumentWrappers); + List inferenceList = constructInferenceTexts(dataForInferences); + if (inferenceList.isEmpty()) { + handler.accept(ingestDocumentWrappers); + return; + } + Tuple, Map> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList); + inferenceList = sortedResult.v1(); + Map originalOrder = sortedResult.v2(); + doBatchExecute(inferenceList, results -> { + int startIndex = 0; + results = restoreToOriginalOrder(results, originalOrder); + for (DataForInference dataForInference : dataForInferences) { + if (dataForInference.getIngestDocumentWrapper().getException() != null + || CollectionUtils.isEmpty(dataForInference.getInferenceList())) { + continue; + } + List inferenceResults = results.subList(startIndex, startIndex + dataForInference.getInferenceList().size()); + startIndex += dataForInference.getInferenceList().size(); + setVectorFieldsToDocument( + dataForInference.getIngestDocumentWrapper().getIngestDocument(), + dataForInference.getProcessMap(), + inferenceResults + ); + } + handler.accept(ingestDocumentWrappers); + }, exception -> { + for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) { + // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only + // set exception to IngestDocumentWrapper which doesn't have exception before. + if (ingestDocumentWrapper.getException() == null) { + ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception); + } + } + handler.accept(ingestDocumentWrappers); + }); + } + + private Tuple, Map> sortByLengthAndReturnOriginalOrder(List inferenceList) { + List> docsWithIndex = new ArrayList<>(); + for (int i = 0; i < inferenceList.size(); ++i) { + docsWithIndex.add(Tuple.tuple(i, inferenceList.get(i))); + } + docsWithIndex.sort(Comparator.comparingInt(t -> t.v2().length())); + List sortedInferenceList = docsWithIndex.stream().map(Tuple::v2).collect(Collectors.toList()); + Map originalOrderMap = new HashMap<>(); + for (int i = 0; i < docsWithIndex.size(); ++i) { + originalOrderMap.put(i, docsWithIndex.get(i).v1()); + } + return Tuple.tuple(sortedInferenceList, originalOrderMap); + } + + private List restoreToOriginalOrder(List results, Map originalOrder) { + List sortedResults = Arrays.asList(results.toArray()); + for (int i = 0; i < results.size(); ++i) { + if (!originalOrder.containsKey(i)) continue; + int oldIndex = originalOrder.get(i); + sortedResults.set(oldIndex, results.get(i)); + } + return sortedResults; + } + + private List constructInferenceTexts(List dataForInferences) { + List inferenceTexts = new ArrayList<>(); + for (DataForInference dataForInference : dataForInferences) { + if (dataForInference.getIngestDocumentWrapper().getException() != null + || CollectionUtils.isEmpty(dataForInference.getInferenceList())) { + continue; + } + inferenceTexts.addAll(dataForInference.getInferenceList()); + } + return inferenceTexts; + } + + private List getDataForInference(List ingestDocumentWrappers) { + List dataForInferences = new ArrayList<>(); + for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) { + Map processMap = null; + List inferenceList = null; + try { + validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument()); + processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument()); + inferenceList = createInferenceList(processMap); + } catch (Exception e) { + ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e); + } finally { + dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList)); + } + } + return dataForInferences; + } + + @Getter + @AllArgsConstructor + private static class DataForInference { + private final IngestDocumentWrapper ingestDocumentWrapper; + private final Map processMap; + private final List inferenceList; + } + @SuppressWarnings({ "unchecked" }) private List createInferenceList(Map knnKeyMap) { List texts = new ArrayList<>(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 8acf95bf7..9e2336cf6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Map; import java.util.function.BiConsumer; +import java.util.function.Consumer; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; @@ -49,4 +50,13 @@ public void doExecute( handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } + + @Override + public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { + mlCommonsClientAccessor.inferenceSentencesWithMapResult( + this.modelId, + inferenceList, + ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) + ); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index c1b8f92a6..7e765624e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Map; import java.util.function.BiConsumer; +import java.util.function.Consumer; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; @@ -48,4 +49,9 @@ public void doExecute( handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } + + @Override + public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(handler::accept, onException)); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java new file mode 100644 index 000000000..05a327b82 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import com.google.common.collect.ImmutableList; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.IngestDocumentWrapper; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class InferenceProcessorTestCase extends OpenSearchTestCase { + + protected List createIngestDocumentWrappers(int count) { + List wrapperList = new ArrayList<>(); + for (int i = 0; i < count; ++i) { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + wrapperList.add(new IngestDocumentWrapper(i, new IngestDocument(sourceAndMetadata, new HashMap<>()), null)); + } + return wrapperList; + } + + protected List> createMockVectorWithLength(int size) { + float suffix = .234f; + List> result = new ArrayList<>(); + for (int i = 0; i < size * 2;) { + List number = new ArrayList<>(); + number.add(i++ + suffix); + number.add(i++ + suffix); + result.add(number); + } + return result; + } + + protected List> createMockVectorResult() { + List> modelTensorList = new ArrayList<>(); + List number1 = ImmutableList.of(1.234f, 2.354f); + List number2 = ImmutableList.of(3.234f, 4.354f); + List number3 = ImmutableList.of(5.234f, 6.354f); + List number4 = ImmutableList.of(7.234f, 8.354f); + List number5 = ImmutableList.of(9.234f, 10.354f); + List number6 = ImmutableList.of(11.234f, 12.354f); + List number7 = ImmutableList.of(13.234f, 14.354f); + modelTensorList.add(number1); + modelTensorList.add(number2); + modelTensorList.add(number3); + modelTensorList.add(number4); + modelTensorList.add(number5); + modelTensorList.add(number6); + modelTensorList.add(number7); + return modelTensorList; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java new file mode 100644 index 000000000..43c2ba1fb --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -0,0 +1,202 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.IngestDocumentWrapper; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class InferenceProcessorTests extends InferenceProcessorTestCase { + private MLCommonsClientAccessor clientAccessor; + private Environment environment; + + private static final String TAG = "tag"; + private static final String TYPE = "type"; + private static final String DESCRIPTION = "description"; + private static final String MAP_KEY = "map_key"; + private static final String MODEL_ID = "model_id"; + private static final Map FIELD_MAP = Map.of("key1", "embedding_key1", "key2", "embedding_key2"); + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + clientAccessor = mock(MLCommonsClientAccessor.class); + environment = mock(Environment.class); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(environment.settings()).thenReturn(settings); + } + + public void test_batchExecute_emptyInput() { + TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null); + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(Collections.emptyList(), resultHandler); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(captor.capture()); + assertTrue(captor.getValue().isEmpty()); + verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + } + + public void test_batchExecute_allFailedValidation() { + final int docCount = 2; + TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null); + List wrapperList = createIngestDocumentWrappers(docCount); + wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); + wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(wrapperList, resultHandler); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(captor.capture()); + assertEquals(docCount, captor.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertNotNull(captor.getValue().get(i).getException()); + assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); + } + verify(clientAccessor, never()).inferenceSentences(anyString(), anyList(), any()); + } + + public void test_batchExecute_partialFailedValidation() { + final int docCount = 2; + TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null); + List wrapperList = createIngestDocumentWrappers(docCount); + wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); + wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4")); + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(wrapperList, resultHandler); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(captor.capture()); + assertEquals(docCount, captor.getValue().size()); + assertNotNull(captor.getValue().get(0).getException()); + assertNull(captor.getValue().get(1).getException()); + for (int i = 0; i < docCount; ++i) { + assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); + } + ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); + verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); + assertEquals(2, inferenceTextCaptor.getValue().size()); + } + + public void test_batchExecute_happyCase() { + final int docCount = 2; + List> inferenceResults = createMockVectorWithLength(6); + TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, null); + List wrapperList = createIngestDocumentWrappers(docCount); + wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("value1", "value2")); + wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4")); + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(wrapperList, resultHandler); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(captor.capture()); + assertEquals(docCount, captor.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertNull(captor.getValue().get(i).getException()); + assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); + } + ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); + verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); + assertEquals(4, inferenceTextCaptor.getValue().size()); + } + + public void test_batchExecute_sort() { + final int docCount = 2; + List> inferenceResults = createMockVectorWithLength(100); + TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, null); + List wrapperList = createIngestDocumentWrappers(docCount); + wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("aaaaa", "bbb")); + wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("cc", "ddd")); + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(wrapperList, resultHandler); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(captor.capture()); + assertEquals(docCount, captor.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertNull(captor.getValue().get(i).getException()); + assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); + } + ArgumentCaptor> inferenceTextCaptor = ArgumentCaptor.forClass(List.class); + verify(clientAccessor).inferenceSentences(anyString(), inferenceTextCaptor.capture(), any()); + assertEquals(4, inferenceTextCaptor.getValue().size()); + assertEquals(Arrays.asList("cc", "bbb", "ddd", "aaaaa"), inferenceTextCaptor.getValue()); + + List doc1Embeddings = (List) (captor.getValue().get(0).getIngestDocument().getFieldValue("embedding_key1", List.class)); + List doc2Embeddings = (List) (captor.getValue().get(1).getIngestDocument().getFieldValue("embedding_key1", List.class)); + assertEquals(2, doc1Embeddings.size()); + assertEquals(2, doc2Embeddings.size()); + // inferenceResults are results for sorted-by-length array ("cc", "bbb", "ddd", "aaaaa") + assertEquals(inferenceResults.get(3), ((Map) doc1Embeddings.get(0)).get("map_key")); + assertEquals(inferenceResults.get(1), ((Map) doc1Embeddings.get(1)).get("map_key")); + assertEquals(inferenceResults.get(0), ((Map) doc2Embeddings.get(0)).get("map_key")); + assertEquals(inferenceResults.get(2), ((Map) doc2Embeddings.get(1)).get("map_key")); + } + + public void test_doBatchExecute_exception() { + final int docCount = 2; + List> inferenceResults = createMockVectorWithLength(6); + TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, new RuntimeException()); + List wrapperList = createIngestDocumentWrappers(docCount); + wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("value1", "value2")); + wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4")); + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(wrapperList, resultHandler); + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(captor.capture()); + assertEquals(docCount, captor.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertNotNull(captor.getValue().get(i).getException()); + assertEquals(wrapperList.get(i).getIngestDocument(), captor.getValue().get(i).getIngestDocument()); + } + verify(clientAccessor).inferenceSentences(anyString(), anyList(), any()); + } + + private class TestInferenceProcessor extends InferenceProcessor { + List vectors; + Exception exception; + + public TestInferenceProcessor(List vectors, Exception exception) { + super(TAG, DESCRIPTION, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment); + this.vectors = vectors; + this.exception = exception; + } + + @Override + public void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler + ) {} + + @Override + void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { + // use to verify if doBatchExecute is called from InferenceProcessor + clientAccessor.inferenceSentences(MODEL_ID, inferenceList, ActionListener.wrap(results -> {}, ex -> {})); + if (this.exception != null) { + onException.accept(this.exception); + } else { + handler.accept(this.vectors); + } + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 815ea851b..5b85ec923 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -24,9 +24,11 @@ import java.util.List; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.stream.IntStream; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -34,17 +36,17 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; -import org.opensearch.test.OpenSearchTestCase; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; -public class SparseEncodingProcessorTests extends OpenSearchTestCase { +public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { @Mock private MLCommonsClientAccessor mlCommonsClientAccessor; @@ -170,6 +172,49 @@ public void testExecute_withMapTypeInput_successful() { } + public void test_batchExecute_successful() { + final int docCount = 5; + List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); + SparseEncodingProcessor processor = createInstance(); + List> dataAsMapList = createMockMapResult(10); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(ingestDocumentWrappers, resultHandler); + ArgumentCaptor> resultCallback = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCallback.capture()); + assertEquals(docCount, resultCallback.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertEquals(ingestDocumentWrappers.get(i).getIngestDocument(), resultCallback.getValue().get(i).getIngestDocument()); + assertNull(resultCallback.getValue().get(i).getException()); + } + } + + public void test_batchExecute_exception() { + final int docCount = 5; + List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); + SparseEncodingProcessor processor = createInstance(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(ingestDocumentWrappers, resultHandler); + ArgumentCaptor> resultCallback = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCallback.capture()); + assertEquals(docCount, resultCallback.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertEquals(ingestDocumentWrappers.get(i).getIngestDocument(), resultCallback.getValue().get(i).getIngestDocument()); + assertNotNull(resultCallback.getValue().get(i).getException()); + } + } + private List> createMockMapResult(int number) { List> mockSparseEncodingResult = new ArrayList<>(); IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 44fda54c2..f963c48fc 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -4,10 +4,15 @@ */ package org.opensearch.neuralsearch.processor; +import java.io.IOException; +import java.net.URISyntaxException; import java.nio.file.Files; import java.nio.file.Path; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import org.apache.commons.lang3.StringUtils; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; @@ -24,6 +29,13 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { private static final String INDEX_NAME = "text_embedding_index"; private static final String PIPELINE_NAME = "pipeline-hybrid"; + private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI())); + private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI())); + private final String BULK_ITEM_TEMPLATE = Files.readString( + Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI()) + ); + + public TextEmbeddingProcessorIT() throws IOException, URISyntaxException {} @Before public void setUp() throws Exception { @@ -38,13 +50,33 @@ public void testTextEmbeddingProcessor() throws Exception { loadModel(modelId); createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); createTextEmbeddingIndex(); - ingestDocument(); + ingestDocument(INGEST_DOC1, null); assertEquals(1, getDocCount(INDEX_NAME)); } finally { wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); } } + public void testTextEmbeddingProcessor_batch() throws Exception { + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); + createTextEmbeddingIndex(); + ingestBatchDocumentWithBulk("batch_"); + assertEquals(2, getDocCount(INDEX_NAME)); + + ingestDocument(INGEST_DOC1, "1"); + ingestDocument(INGEST_DOC2, "2"); + + assertEquals(getDocById(INDEX_NAME, "1").get("_source"), getDocById(INDEX_NAME, "batch_1").get("_source")); + assertEquals(getDocById(INDEX_NAME, "2").get("_source"), getDocById(INDEX_NAME, "batch_2").get("_source")); + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + private String uploadTextEmbeddingModel() throws Exception { String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); return registerModelGroupAndUploadModel(requestBody); @@ -58,34 +90,19 @@ private void createTextEmbeddingIndex() throws Exception { ); } - private void ingestDocument() throws Exception { - String ingestDocument = "{\n" - + " \"title\": \"This is a good day\",\n" - + " \"description\": \"daily logging\",\n" - + " \"favor_list\": [\n" - + " \"test\",\n" - + " \"hello\",\n" - + " \"mock\"\n" - + " ],\n" - + " \"favorites\": {\n" - + " \"game\": \"overwatch\",\n" - + " \"movie\": null\n" - + " },\n" - + " \"nested_passages\": [\n" - + " {\n" - + " \"text\": \"hello\"\n" - + " },\n" - + " {\n" - + " \"text\": \"world\"\n" - + " }\n" - + " ]\n" - + "}\n"; + private void ingestDocument(String doc, String id) throws Exception { + String endpoint; + if (StringUtils.isEmpty(id)) { + endpoint = INDEX_NAME + "/_doc?refresh"; + } else { + endpoint = INDEX_NAME + "/_doc/" + id + "?refresh"; + } Response response = makeRequest( client(), "POST", - INDEX_NAME + "/_doc?refresh", + endpoint, null, - toHttpEntity(ingestDocument), + toHttpEntity(doc), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); Map map = XContentHelper.convertToMap( @@ -96,4 +113,37 @@ private void ingestDocument() throws Exception { assertEquals("created", map.get("result")); } + private void ingestBatchDocumentWithBulk(String idPrefix) throws Exception { + String doc1 = INGEST_DOC1.replace("\n", ""); + String doc2 = INGEST_DOC2.replace("\n", ""); + final String id1 = idPrefix + "1"; + final String id2 = idPrefix + "2"; + String item1 = BULK_ITEM_TEMPLATE.replace("{{index}}", INDEX_NAME) + .replace("{{id}}", id1) + .replace("{{doc}}", doc1) + .replace("{{comma}}", ","); + String item2 = BULK_ITEM_TEMPLATE.replace("{{index}}", INDEX_NAME) + .replace("{{id}}", id2) + .replace("{{doc}}", doc2) + .replace("{{comma}}", "\n"); + final String payload = item1 + item2; + Map params = new HashMap<>(); + params.put("refresh", "true"); + params.put("batch_size", "2"); + Response response = makeRequest( + client(), + "POST", + "_bulk", + params, + toHttpEntity(payload), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Map map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(response.getEntity()), + false + ); + assertEquals(false, map.get("errors")); + assertEquals(2, ((List) map.get("items")).size()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 60408d820..752615057 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -22,9 +22,11 @@ import java.util.Map; import java.util.Arrays; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -33,17 +35,17 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; -import org.opensearch.test.OpenSearchTestCase; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; -public class TextEmbeddingProcessorTests extends OpenSearchTestCase { +public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase { @Mock private MLCommonsClientAccessor mlCommonsClientAccessor; @@ -64,7 +66,7 @@ public void setup() { } @SneakyThrows - private TextEmbeddingProcessor createInstance(List> vector) { + private TextEmbeddingProcessor createInstance() { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -105,7 +107,7 @@ public void testExecute_successful() { sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -164,7 +166,7 @@ public void testExecute_withListTypeInput_successful() { sourceAndMetadata.put("key1", list1); sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(6)); + TextEmbeddingProcessor processor = createInstance(); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -182,7 +184,7 @@ public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentExcep Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", " "); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -194,7 +196,7 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", list1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -206,7 +208,7 @@ public void testExecute_listHasNonStringValue_throwIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -220,7 +222,7 @@ public void testExecute_listHasNull_throwIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key2", list); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -233,7 +235,7 @@ public void testExecute_withMapTypeInput_successful() { sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -255,7 +257,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -268,7 +270,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -280,7 +282,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { sourceAndMetadata.put("key1", "hello world"); sourceAndMetadata.put("key2", ret); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -291,7 +293,7 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(2); @@ -322,7 +324,7 @@ public void testExecute_hybridTypeInput_successful() throws Exception { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key2", map1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); IngestDocument document = processor.execute(ingestDocument); assert document.getSourceAndMetadata().containsKey("key2"); } @@ -339,13 +341,13 @@ public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testGetType_successful() { - TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextEmbeddingProcessor processor = createInstance(); assert processor.getType().equals(TextEmbeddingProcessor.TYPE); } @@ -448,35 +450,48 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() { assertEquals(2, ((List) ingestDocument.getSourceAndMetadata().get("oriKey6_knn")).size()); } - private List> createMockVectorResult() { - List> modelTensorList = new ArrayList<>(); - List number1 = ImmutableList.of(1.234f, 2.354f); - List number2 = ImmutableList.of(3.234f, 4.354f); - List number3 = ImmutableList.of(5.234f, 6.354f); - List number4 = ImmutableList.of(7.234f, 8.354f); - List number5 = ImmutableList.of(9.234f, 10.354f); - List number6 = ImmutableList.of(11.234f, 12.354f); - List number7 = ImmutableList.of(13.234f, 14.354f); - modelTensorList.add(number1); - modelTensorList.add(number2); - modelTensorList.add(number3); - modelTensorList.add(number4); - modelTensorList.add(number5); - modelTensorList.add(number6); - modelTensorList.add(number7); - return modelTensorList; - } - - private List> createMockVectorWithLength(int size) { - float suffix = .234f; - List> result = new ArrayList<>(); - for (int i = 0; i < size * 2;) { - List number = new ArrayList<>(); - number.add(i++ + suffix); - number.add(i++ + suffix); - result.add(number); + public void test_batchExecute_successful() { + final int docCount = 5; + List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); + TextEmbeddingProcessor processor = createInstance(); + + List> modelTensorList = createMockVectorWithLength(10); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(ingestDocumentWrappers, resultHandler); + ArgumentCaptor> resultCallback = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCallback.capture()); + assertEquals(docCount, resultCallback.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertEquals(ingestDocumentWrappers.get(i).getIngestDocument(), resultCallback.getValue().get(i).getIngestDocument()); + assertNull(resultCallback.getValue().get(i).getException()); + } + } + + public void test_batchExecute_exception() { + final int docCount = 5; + List ingestDocumentWrappers = createIngestDocumentWrappers(docCount); + TextEmbeddingProcessor processor = createInstance(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + Consumer resultHandler = mock(Consumer.class); + processor.batchExecute(ingestDocumentWrappers, resultHandler); + ArgumentCaptor> resultCallback = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCallback.capture()); + assertEquals(docCount, resultCallback.getValue().size()); + for (int i = 0; i < docCount; ++i) { + assertEquals(ingestDocumentWrappers.get(i).getIngestDocument(), resultCallback.getValue().get(i).getIngestDocument()); + assertNotNull(resultCallback.getValue().get(i).getException()); } - return result; } @SneakyThrows diff --git a/src/test/resources/processor/bulk_item_template.json b/src/test/resources/processor/bulk_item_template.json new file mode 100644 index 000000000..79881b630 --- /dev/null +++ b/src/test/resources/processor/bulk_item_template.json @@ -0,0 +1,2 @@ +{ "index": { "_index": "{{index}}", "_id": "{{id}}" } }, +{{doc}}{{comma}} diff --git a/src/test/resources/processor/ingest_doc1.json b/src/test/resources/processor/ingest_doc1.json new file mode 100644 index 000000000..e3302c75a --- /dev/null +++ b/src/test/resources/processor/ingest_doc1.json @@ -0,0 +1,21 @@ +{ + "title": "This is a good day", + "description": "daily logging", + "favor_list": [ + "test", + "hello", + "mock" + ], + "favorites": { + "game": "overwatch", + "movie": null + }, + "nested_passages": [ + { + "text": "hello" + }, + { + "text": "world" + } + ] +} diff --git a/src/test/resources/processor/ingest_doc2.json b/src/test/resources/processor/ingest_doc2.json new file mode 100644 index 000000000..400f9027a --- /dev/null +++ b/src/test/resources/processor/ingest_doc2.json @@ -0,0 +1,19 @@ +{ + "title": "this is a second doc", + "description": "the description is not very long", + "favor_list": [ + "favor" + ], + "favorites": { + "game": "golden state", + "movie": null + }, + "nested_passages": [ + { + "text": "apple" + }, + { + "text": "banana" + } + ] +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index baecf2932..516e9fe8c 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -403,6 +403,21 @@ protected int getDocCount(final String indexName) { return (Integer) responseMap.get("count"); } + /** + * Get one doc by its id + * @param indexName index name + * @param id doc id + * @return map of the doc data + */ + @SneakyThrows + protected Map getDocById(final String indexName, final String id) { + Request request = new Request("GET", "/" + indexName + "/_doc/" + id); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + String responseBody = EntityUtils.toString(response.getEntity()); + return createParser(XContentType.JSON.xContent(), responseBody).map(); + } + /** * Execute a search request initialized from a neural query builder *