From a90784c761da86b4ab2169ace9923ddff43567d9 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 16 Jan 2024 11:31:12 -0800 Subject: [PATCH] Adding support for generic re-ranker interface and opensearch ml re-ranker for improving search relavancy. (#494) * Add rerank processor interfaces Signed-off-by: HenryL27 * add cross-encoder specific logic and factory Signed-off-by: HenryL27 * add unittests Signed-off-by: HenryL27 * add integration test Signed-off-by: HenryL27 * use string.format() instead of concatenation Signed-off-by: HenryL27 * rename generateScoringContext to generateRerankingContext Signed-off-by: HenryL27 * add name change in test too. whoops Signed-off-by: HenryL27 * start refactoring with contextSaourceFetchers Signed-off-by: HenryL27 * refactor to use contextSourceFetchers to get context Signed-off-by: HenryL27 * rename CrossEncoder to TextSimilarity Signed-off-by: HenryL27 * add query_context layer to search ext Signed-off-by: HenryL27 * add javadocs Signed-off-by: HenryL27 * update to new asyncProcessResponse api Signed-off-by: HenryL27 * rename reranktype to ML_OPENSEARCH Signed-off-by: HenryL27 * improve error messages for bad rerank type config Signed-off-by: HenryL27 * simplify configuration/factory logic Signed-off-by: HenryL27 * improve handling for non-flat-string context fields Signed-off-by: HenryL27 * rename TextSimilarity files to MLOpenSearch files Signed-off-by: HenryL27 * apply spotless after rebase Signed-off-by: HenryL27 * update changelog Signed-off-by: HenryL27 * after rebase Signed-off-by: HenryL27 * Address pr comments and fix XContent in search ext Signed-off-by: HenryL27 * move contextSourceFetchers to their own subdirectory Signed-off-by: HenryL27 * Apply suggestions from code review Co-authored-by: Martin Gaievski Signed-off-by: HenryL27 * CR changes Signed-off-by: HenryL27 * finish CR comments and fix broken unittest Signed-off-by: HenryL27 * fix unittest names Signed-off-by: HenryL27 --------- Signed-off-by: HenryL27 Co-authored-by: Martin Gaievski --- CHANGELOG.md | 1 + .../ml/MLCommonsClientAccessor.java | 45 +++ .../neuralsearch/plugin/NeuralSearch.java | 22 ++ .../factory/RerankProcessorFactory.java | 132 +++++++ .../rerank/MLOpenSearchRerankProcessor.java | 83 +++++ .../processor/rerank/RerankProcessor.java | 105 ++++++ .../processor/rerank/RerankType.java | 53 +++ .../rerank/RescoringRerankProcessor.java | 119 ++++++ .../rerank/context/ContextSourceFetcher.java | 40 ++ .../context/DocumentContextSourceFetcher.java | 101 ++++++ .../context/QueryContextSourceFetcher.java | 110 ++++++ .../query/ext/RerankSearchExtBuilder.java | 113 ++++++ .../ml/MLCommonsClientAccessorTests.java | 82 +++++ .../factory/RerankProcessorFactoryTests.java | 190 ++++++++++ .../rerank/MLOpenSearchRerankProcessorIT.java | 138 +++++++ .../MLOpenSearchRerankProcessorTests.java | 342 ++++++++++++++++++ .../ext/RerankSearchExtBuilderTests.java | 109 ++++++ ...rankMLOpenSearchPipelineConfiguration.json | 15 + .../UploadTextSimilarityModelRequestBody.json | 16 + .../neuralsearch/BaseNeuralSearchIT.java | 17 + 20 files changed, 1833 insertions(+) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java create mode 100644 src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json create mode 100644 src/test/resources/processor/UploadTextSimilarityModelRequestBody.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 99335681d..85524f4bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.11...2.x) ### Features +- Add rerank processor interface and ml-commons reranker ([#494](https://github.com/opensearch-project/neural-search/pull/494)) ### Enhancements ### Bug Fixes - Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524)) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index e12211d28..f9ddf73a9 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -19,6 +19,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -132,6 +133,25 @@ public void inferenceSentences( retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); } + /** + * Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the + * {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing + * the similarity scores of the texts w.r.t. the query text, in the order of the input texts. + * + * @param modelId {@link String} ML-Commons Model Id + * @param queryText {@link String} The query to compare all the inputText to + * @param inputText {@link List} of {@link String} The texts to compare to the query + * @param listener {@link ActionListener} receives the result of the inference + */ + public void inferenceSimilarity( + @NonNull final String modelId, + @NonNull final String queryText, + @NonNull final List inputText, + @NonNull final ActionListener> listener + ) { + retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener); + } + private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, @@ -173,12 +193,37 @@ private void retryableInferenceSentencesWithVectorResult( })); } + private void retryableInferenceSimilarityWithVectorResult( + final String modelId, + final String queryText, + final List inputText, + final int retryTime, + final ActionListener> listener + ) { + MLInput mlInput = createMLTextPairsInput(queryText, inputText); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList()); + listener.onResponse(scores); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener); + } else { + listener.onFailure(e); + } + })); + } + private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); } + private MLInput createMLTextPairsInput(final String query, final List inputText) { + final MLInputDataset inputDataset = new TextSimilarityInputDataSet(query, inputText); + return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset); + } + private List> buildVectorFromResponse(MLOutput mlOutput) { final List> vector = new ArrayList<>(); final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index aacb8d2e6..a77118f43 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -34,14 +34,17 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.plugins.ActionPlugin; @@ -54,6 +57,7 @@ import org.opensearch.script.ScriptService; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; @@ -150,4 +154,22 @@ public Map> getResponseProcessors( + Parameters parameters + ) { + return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor)); + } + + @Override + public List> getSearchExts() { + return List.of( + new SearchExtSpec<>( + RerankSearchExtBuilder.PARAM_FIELD_NAME, + in -> new RerankSearchExtBuilder(in), + parser -> RerankSearchExtBuilder.parse(parser) + ) + ); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java new file mode 100644 index 000000000..b02666855 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.StringJoiner; + +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankType; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +import com.google.common.collect.Sets; + +import lombok.AllArgsConstructor; + +/** + * Factory for rerank processors. Must: + * - Instantiate the right kind of rerank processor + * - Instantiate the appropriate context source fetchers + */ +@AllArgsConstructor +public class RerankProcessorFactory implements Processor.Factory { + + public static final String RERANK_PROCESSOR_TYPE = "rerank"; + public static final String CONTEXT_CONFIG_FIELD = "context"; + + private final MLCommonsClientAccessor clientAccessor; + + @Override + public SearchResponseProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) { + RerankType type = findRerankType(config); + boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type); + List contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag); + switch (type) { + case ML_OPENSEARCH: + Map rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel()); + String modelId = ConfigurationUtils.readStringProperty( + RERANK_PROCESSOR_TYPE, + tag, + rerankerConfig, + MLOpenSearchRerankProcessor.MODEL_ID_FIELD + ); + return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel())); + } + } + + private RerankType findRerankType(final Map config) throws IllegalArgumentException { + // Set of rerank type labels in the config + Set rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet()); + // A rerank type must be provided + if (rerankTypes.size() == 0) { + StringJoiner msgBuilder = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]"); + for (RerankType t : RerankType.values()) { + msgBuilder.add(t.getLabel()); + } + throw new IllegalArgumentException(msgBuilder.toString()); + } + // Only one rerank type may be provided + if (rerankTypes.size() > 1) { + StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted."); + rerankTypes.forEach(rt -> msgBuilder.add(rt)); + throw new IllegalArgumentException(msgBuilder.toString()); + } + return RerankType.from(rerankTypes.iterator().next()); + } + + /** + * Factory class for context fetchers. Constructs a list of context fetchers + * specified in the pipeline config (and maybe the query context fetcher) + */ + private static class ContextFetcherFactory { + + /** + * Map rerank types to whether they should include the query context source fetcher + * @param type the constructing RerankType + * @return does this RerankType depend on the QueryContextSourceFetcher? + */ + public static boolean shouldIncludeQueryContextFetcher(RerankType type) { + return type == RerankType.ML_OPENSEARCH; + } + + /** + * Create necessary queryContextFetchers for this processor + * @param config processor config object. Look for "context" field to find fetchers + * @param includeQueryContextFetcher should I include the queryContextFetcher? + * @return list of contextFetchers for the processor to use + */ + public static List createFetchers( + Map config, + boolean includeQueryContextFetcher, + String tag + ) { + Map contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD); + List fetchers = new ArrayList<>(); + for (String key : contextConfig.keySet()) { + Object cfg = contextConfig.get(key); + switch (key) { + case DocumentContextSourceFetcher.NAME: + fetchers.add(DocumentContextSourceFetcher.create(cfg)); + break; + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key)); + } + } + if (includeQueryContextFetcher) { + fetchers.add(new QueryContextSourceFetcher()); + } + return fetchers; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java new file mode 100644 index 000000000..d8d9e8ec3 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; + +/** + * Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore + */ +public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor { + + public static final String MODEL_ID_FIELD = "model_id"; + + protected final String modelId; + + protected final MLCommonsClientAccessor mlCommonsClientAccessor; + + /** + * Constructor + * @param description + * @param tag + * @param ignoreFailure + * @param modelId id of TEXT_SIMILARITY model + * @param contextSourceFetchers + * @param mlCommonsClientAccessor + */ + public MLOpenSearchRerankProcessor( + final String description, + final String tag, + final boolean ignoreFailure, + final String modelId, + final List contextSourceFetchers, + final MLCommonsClientAccessor mlCommonsClientAccessor + ) { + super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers); + this.modelId = modelId; + this.mlCommonsClientAccessor = mlCommonsClientAccessor; + } + + @Override + public void rescoreSearchResponse( + final SearchResponse response, + final Map rerankingContext, + final ActionListener> listener + ) { + Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD); + if (!(ctxObj instanceof List)) { + listener.onFailure( + new IllegalStateException( + String.format( + Locale.ROOT, + "No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?", + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + DocumentContextSourceFetcher.NAME + ) + ) + ); + return; + } + List ctxList = (List) ctxObj; + List contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); + mlCommonsClientAccessor.inferenceSimilarity( + modelId, + (String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD), + contexts, + listener + ); + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java new file mode 100644 index 000000000..93a2c8416 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Abstract base class for reranking processors + */ +@AllArgsConstructor +public abstract class RerankProcessor implements SearchResponseProcessor { + + public static final String TYPE = "rerank"; + + protected final RerankType subType; + @Getter + private final String description; + @Getter + private final String tag; + @Getter + private final boolean ignoreFailure; + protected List contextSourceFetchers; + + /** + * Generate the information that this processor needs in order to rerank. + * Concurrently hit all contextSourceFetchers + * @param searchRequest the search query + * @param searchResponse the search results, in case they're relevant + * @param listener be async + */ + public void generateRerankingContext( + final SearchRequest searchRequest, + final SearchResponse searchResponse, + final ActionListener> listener + ) { + Map overallContext = new ConcurrentHashMap<>(); + AtomicInteger successfulContexts = new AtomicInteger(contextSourceFetchers.size()); + for (ContextSourceFetcher csf : contextSourceFetchers) { + csf.fetchContext(searchRequest, searchResponse, ActionListener.wrap(context -> { + overallContext.putAll(context); + if (successfulContexts.decrementAndGet() == 0) { + listener.onResponse(overallContext); + } + }, e -> { listener.onFailure(e); })); + } + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Given the scoring context generated by the processor and the search results, + * rerank the search results. Do so asynchronously. + * @param searchResponse the search results to rerank + * @param rerankingContext the information this processor needs in order to rerank + * @param listener be async + */ + public abstract void rerank( + final SearchResponse searchResponse, + final Map rerankingContext, + final ActionListener listener + ); + + @Override + public SearchResponse processResponse(final SearchRequest request, final SearchResponse response) throws Exception { + throw new UnsupportedOperationException("Use asyncProcessResponse unless you can guarantee to not deadlock yourself"); + } + + @Override + public void processResponseAsync( + final SearchRequest request, + final SearchResponse response, + final PipelineProcessingContext ctx, + final ActionListener responseListener + ) { + try { + generateRerankingContext( + request, + response, + ActionListener.wrap(context -> { rerank(response, context, responseListener); }, e -> { + responseListener.onFailure(e); + }) + ); + } catch (Exception e) { + responseListener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java new file mode 100644 index 000000000..2063242dd --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import lombok.Getter; + +/** + * enum for distinguishing various reranking methods + */ +public enum RerankType { + + ML_OPENSEARCH("ml_opensearch"); + + @Getter + private final String label; + + private RerankType(String label) { + this.label = label; + } + + private static final Map LABEL_MAP; + static { + Map labelMap = new HashMap<>(); + for (RerankType type : RerankType.values()) { + labelMap.put(type.getLabel(), type); + } + LABEL_MAP = Collections.unmodifiableMap(labelMap); + } + + /** + * Construct a RerankType from the label + * @param label label of a RerankType + * @return RerankType represented by the label + */ + public static RerankType from(final String label) { + RerankType ans = LABEL_MAP.get(label); + if (ans == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Wrong rerank type name: %s", label)); + } + return ans; + } + + public static Map labelMap() { + return LABEL_MAP; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java new file mode 100644 index 000000000..27ccf51e6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.core.action.ActionListener; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.profile.SearchProfileShardResults; + +/** + * RerankProcessor that rescores all the documents and re-sorts them using the new scores + */ +public abstract class RescoringRerankProcessor extends RerankProcessor { + + /** + * Constructor. pass through to RerankProcessor constructor. + * @param type + * @param description + * @param tag + * @param ignoreFailure + * @param contextSourceFetchers + */ + public RescoringRerankProcessor( + final RerankType type, + final String description, + final String tag, + final boolean ignoreFailure, + final List contextSourceFetchers + ) { + super(type, description, tag, ignoreFailure, contextSourceFetchers); + } + + /** + * Generate a list of new scores for all of the documents, given the scoring context + * @param response search results to rescore + * @param rerankingContext extra information needed to score the search results; e.g. model id + * @param listener be async. recieves the list of new scores + */ + public abstract void rescoreSearchResponse( + final SearchResponse response, + final Map rerankingContext, + final ActionListener> listener + ); + + @Override + public void rerank( + final SearchResponse searchResponse, + final Map rerankingContext, + final ActionListener listener + ) { + try { + if (searchResponse.getHits().getTotalHits().value == 0) { + listener.onResponse(searchResponse); + return; + } + rescoreSearchResponse(searchResponse, rerankingContext, ActionListener.wrap(scores -> { + // Assign new scores + SearchHit[] hits = searchResponse.getHits().getHits(); + if (scores == null) { + throw new IllegalStateException("scores cannot be null"); + } + if (hits.length != scores.size()) { + throw new IllegalStateException("scores and hits are not the same length"); + } + // NOTE: Assumes that the new scores came back in the same order + for (int i = 0; i < hits.length; i++) { + hits[i].score(scores.get(i)); + } + // Re-sort by the new scores. Backwards comparison for desc ordering + Collections.sort(Arrays.asList(hits), (hit1, hit2) -> Float.compare(hit2.getScore(), hit1.getScore())); + // Reconstruct the search response, replacing the max score + SearchHits newHits = new SearchHits( + hits, + searchResponse.getHits().getTotalHits(), + hits[0].getScore(), + searchResponse.getHits().getSortFields(), + searchResponse.getHits().getCollapseField(), + searchResponse.getHits().getCollapseValues() + ); + SearchResponseSections newInternalResponse = new SearchResponseSections( + newHits, + searchResponse.getAggregations(), + searchResponse.getSuggest(), + searchResponse.isTimedOut(), + searchResponse.isTerminatedEarly(), + new SearchProfileShardResults(searchResponse.getProfileResults()), + searchResponse.getNumReducePhases(), + searchResponse.getInternalResponse().getSearchExtBuilders() + ); + SearchResponse newResponse = new SearchResponse( + newInternalResponse, + searchResponse.getScrollId(), + searchResponse.getTotalShards(), + searchResponse.getSuccessfulShards(), + searchResponse.getSkippedShards(), + searchResponse.getTook().millis(), + searchResponse.getPhaseTook(), + searchResponse.getShardFailures(), + searchResponse.getClusters(), + searchResponse.pointInTimeId() + ); + listener.onResponse(newResponse); + }, e -> { listener.onFailure(e); })); + } catch (Exception e) { + listener.onFailure(e); + } + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java new file mode 100644 index 000000000..ddb4a08fb --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank.context; + +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; + +/** + * Interface that gets context from some source and puts it in a map + * for a reranking processor to use + */ +public interface ContextSourceFetcher { + + /** + * Fetch the information needed in order to rerank. + * That could be as simple as grabbing a field from the search request or + * as complicated as a lookup to some external service + * @param searchRequest the search query + * @param searchResponse the search results, in case they're relevant + * @param listener be async + */ + void fetchContext( + final SearchRequest searchRequest, + final SearchResponse searchResponse, + final ActionListener> listener + ); + + /** + * Get the name of the contextSourceFetcher. This will be used as the field + * name in the context config for the pipeline + * @return Name of the fetcher + */ + String getName(); + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java new file mode 100644 index 000000000..857c1dd46 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank.context; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.ObjectPath; +import org.opensearch.search.SearchHit; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Context Source Fetcher that gets context from the search results (documents) + */ +@Log4j2 +@AllArgsConstructor +public class DocumentContextSourceFetcher implements ContextSourceFetcher { + + public static final String NAME = "document_fields"; + public static final String DOCUMENT_CONTEXT_LIST_FIELD = "document_context_list"; + + private final List contextFields; + + /** + * Fetch the information needed in order to rerank. + * That could be as simple as grabbing a field from the search request or + * as complicated as a lookup to some external service + * @param searchRequest the search query + * @param searchResponse the search results, in case they're relevant + * @param listener be async + */ + @Override + public void fetchContext( + final SearchRequest searchRequest, + final SearchResponse searchResponse, + final ActionListener> listener + ) { + List contexts = new ArrayList<>(); + for (SearchHit hit : searchResponse.getHits()) { + StringBuilder ctx = new StringBuilder(); + for (String field : this.contextFields) { + ctx.append(contextFromSearchHit(hit, field)); + } + contexts.add(ctx.toString()); + } + listener.onResponse(new HashMap<>(Map.of(DOCUMENT_CONTEXT_LIST_FIELD, contexts))); + } + + private String contextFromSearchHit(final SearchHit hit, final String field) { + if (hit.getFields().containsKey(field)) { + Object fieldValue = hit.field(field).getValue(); + return String.valueOf(fieldValue); + } else if (hit.hasSource() && hit.getSourceAsMap().containsKey(field)) { + Object sourceValue = ObjectPath.eval(field, hit.getSourceAsMap()); + return String.valueOf(sourceValue); + } else { + log.warn( + String.format( + Locale.ROOT, + "Could not find field %s in document %s for reranking! Using the empty string instead.", + field, + hit.getId() + ) + ); + return ""; + } + } + + @Override + public String getName() { + return NAME; + } + + /** + * Create a document context source fetcher from list of field names provided by configuration + * @param config configuration object grabbed from parsed API request. Should be a list of strings + * @return a new DocumentContextSourceFetcher or throws IllegalArgumentException if config is malformed + */ + public static DocumentContextSourceFetcher create(Object config) { + if (!(config instanceof List)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of field names", NAME)); + } + List fields = (List) config; + if (fields.size() == 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME)); + } + List fieldsAsStrings = fields.stream().map(field -> (String) field).collect(Collectors.toList()); + return new DocumentContextSourceFetcher(fieldsAsStrings); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java new file mode 100644 index 000000000..d7463bcd1 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank.context; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ObjectPath; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; + +/** + * Context Source Fetcher that gets context from the rerank query ext. + */ +public class QueryContextSourceFetcher implements ContextSourceFetcher { + + public static final String NAME = "query_context"; + public static final String QUERY_TEXT_FIELD = "query_text"; + public static final String QUERY_TEXT_PATH_FIELD = "query_text_path"; + + @Override + public void fetchContext( + final SearchRequest searchRequest, + final SearchResponse searchResponse, + final ActionListener> listener + ) { + try { + // Get RerankSearchExt query-specific context map + List exts = searchRequest.source().ext(); + Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); + Map rerankContext = new HashMap<>(); + if (!params.containsKey(NAME)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "must specify %s", NAME)); + } + Object ctxObj = params.remove(NAME); + if (!(ctxObj instanceof Map)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a map", NAME)); + } + // Put query context into reranking context + @SuppressWarnings("unchecked") + Map ctxMap = (Map) ctxObj; + if (ctxMap.containsKey(QUERY_TEXT_FIELD)) { + // Case "query_text": "" + if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) + ); + } + rerankContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD)); + } else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { + // Case "query_text_path": ser/de the query into a map and then find the text at the path specified + String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD); + Map map = requestToMap(searchRequest); + // Get the text at the path + Object queryText = ObjectPath.eval(path, map); + if (!(queryText instanceof String)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s must point to a string field", QUERY_TEXT_PATH_FIELD) + ); + } + rerankContext.put(QUERY_TEXT_FIELD, (String) queryText); + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Must specify either \"%s\" or \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) + ); + } + listener.onResponse(rerankContext); + } catch (Exception e) { + listener.onFailure(e); + } + } + + @Override + public String getName() { + return NAME; + } + + /** + * Convert a search request to a general map by streaming out as XContent and then back in, + * with the intention of representing the query as a user would see it + * @param request Search request to turn into xcontent + * @return Map representing the XContent-ified search request + * @throws IOException + */ + private static Map requestToMap(final SearchRequest request) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); + request.source().toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.close(); + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); + Map map = parser.map(); + return map; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java new file mode 100644 index 000000000..3909c7499 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query.ext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Holds ext data from the query for reranking processors. Since + * there can be multiple kinds of rerank processors with different + * contexts, all we can assume is that there's keys and objects. + * e.g. ext might look like + * { + * "query": {blah}, + * "ext": { + * "rerank": { + * "query_context": { + * "query_text": "some question to rerank about" + * } + * } + * } + * } + * or + * { + * "query": {blah}, + * "ext": { + * "rerank": { + * "query_context": { + * "query_path": "query.neural.embedding.query_text" + * } + * } + * } + * } + */ +@AllArgsConstructor +public class RerankSearchExtBuilder extends SearchExtBuilder { + + public final static String PARAM_FIELD_NAME = "rerank"; + @Getter + protected Map params; + + public RerankSearchExtBuilder(StreamInput in) throws IOException { + params = in.readMap(); + } + + @Override + public String getWriteableName() { + return PARAM_FIELD_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(params); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + for (String key : this.params.keySet()) { + builder.field(key, this.params.get(key)); + } + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.params); + } + + @Override + public boolean equals(Object obj) { + return (obj instanceof RerankSearchExtBuilder) && params.equals(((RerankSearchExtBuilder) obj).params); + } + + /** + * Pick out the first RerankSearchExtBuilder from a list of SearchExtBuilders + * @param builders list of SearchExtBuilders + * @return the RerankSearchExtBuilder + */ + public static RerankSearchExtBuilder fromExtBuilderList(List builders) { + Optional b = builders.stream().filter(bldr -> bldr instanceof RerankSearchExtBuilder).findFirst(); + if (b.isPresent()) { + return (RerankSearchExtBuilder) b.get(); + } else { + return null; + } + } + + /** + * Parse XContent to rerankSearchExtBuilder + * @param parser parser parsing this searchExt + * @return RerankSearchExtBuilder represented by this searchExt + * @throws IOException if problems parsing + */ + public static RerankSearchExtBuilder parse(XContentParser parser) throws IOException { + RerankSearchExtBuilder ans = new RerankSearchExtBuilder((Map) parser.map()); + return ans; + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index d9b816597..3749e63dc 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -325,6 +325,71 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); } + public void testInferenceSimilarity_whenValidInput_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createManyModelTensorOutputs(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { + final RuntimeException exception = new RuntimeException(); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(exception); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -352,4 +417,21 @@ private ModelTensorOutput createModelTensorOutput(final Map map) tensorsList.add(modelTensors); return new ModelTensorOutput(tensorsList); } + + private ModelTensorOutput createManyModelTensorOutputs(final Float[] output) { + final List tensorsList = new ArrayList<>(); + for (Float score : output) { + List tensorList = new ArrayList<>(); + String name = "logits"; + Number[] data = new Number[] { score }; + long[] shape = new long[] { 1 }; + MLResultDataType dataType = MLResultDataType.FLOAT32; + MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name()); + ModelTensor tensor = ModelTensor.builder().name(name).data(data).shape(shape).dataType(mlResultDataType).build(); + tensorList.add(tensor); + tensorsList.add(new ModelTensors(tensorList)); + } + ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + return modelTensorOutput; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java new file mode 100644 index 000000000..ea37b2afb --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -0,0 +1,190 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.junit.Before; +import org.mockito.Mock; +import org.opensearch.OpenSearchParseException; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankType; +import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; +import org.opensearch.search.pipeline.Processor.PipelineContext; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.test.OpenSearchTestCase; + +public class RerankProcessorFactoryTests extends OpenSearchTestCase { + + final String TAG = "default-tag"; + final String DESC = "processor description"; + + private RerankProcessorFactory factory; + + @Mock + private MLCommonsClientAccessor clientAccessor; + + @Mock + private PipelineContext pipelineContext; + + @Before + public void setup() { + pipelineContext = mock(PipelineContext.class); + clientAccessor = mock(MLCommonsClientAccessor.class); + factory = new RerankProcessorFactory(clientAccessor); + } + + public void testRerankProcessorFactory_whenEmptyConfig_thenFail() { + Map config = new HashMap<>(Map.of()); + assertThrows( + "no rerank type found", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_whenNonExistentType_thenFail() { + Map config = new HashMap<>( + Map.of("jpeo rvgh we iorgn", Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")) + ); + assertThrows( + "no rerank type found", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testCrossEncoder_whenCorrectParams_thenSuccessful() { + Map config = new HashMap<>( + Map.of( + RerankType.ML_OPENSEARCH.getLabel(), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof MLOpenSearchRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testCrossEncoder_whenMessyConfig_thenSuccessful() { + Map config = new HashMap<>( + Map.of( + "poafn aorr;anv", + Map.of(";oawhls", "aowirhg "), + RerankType.ML_OPENSEARCH.getLabel(), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof MLOpenSearchRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testCrossEncoder_whenMessyContext_thenFail() { + Map config = new HashMap<>( + Map.of( + RerankType.ML_OPENSEARCH.getLabel(), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>( + Map.of( + DocumentContextSourceFetcher.NAME, + new ArrayList<>(List.of("text_representation")), + "pqiohg rpowierhg", + "pw;oith4pt3ih go" + ) + ) + ) + ); + assertThrows( + String.format(Locale.ROOT, "unrecognized context field: %s", "pqiohg rpowierhg"), + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testCrossEncoder_whenEmptySubConfig_thenFail() { + Map config = new HashMap<>(Map.of(RerankType.ML_OPENSEARCH.getLabel(), Map.of())); + assertThrows( + String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), + OpenSearchParseException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testCrossEncoder_whenNoContextField_thenFail() { + Map config = new HashMap<>( + Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id"))) + ); + assertThrows( + String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), + OpenSearchParseException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testCrossEncoder_whenNoModelId_thenFail() { + Map config = new HashMap<>( + Map.of( + RerankType.ML_OPENSEARCH.getLabel(), + new HashMap<>(), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) + ) + ); + assertThrows( + String.format(Locale.ROOT, "[%s] required property is missing", MLOpenSearchRerankProcessor.MODEL_ID_FIELD), + OpenSearchParseException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testCrossEncoder_whenBadContextDocField_thenFail() { + Map config = new HashMap<>( + Map.of( + RerankType.ML_OPENSEARCH.getLabel(), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, "text_representation")) + ) + ); + assertThrows( + String.format(Locale.ROOT, "%s must be a list of strings", DocumentContextSourceFetcher.NAME), + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testCrossEncoder_whenEmptyContextDocField_thenFail() { + Map config = new HashMap<>( + Map.of( + RerankType.ML_OPENSEARCH.getLabel(), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>())) + ) + ); + assertThrows( + String.format(Locale.ROOT, "%s must be nonempty", DocumentContextSourceFetcher.NAME), + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java new file mode 100644 index 000000000..7bde28f7b --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -0,0 +1,138 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; + +import com.google.common.collect.ImmutableList; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT { + + private final static String PIPELINE_NAME = "rerank-mlos-pipeline"; + private final static String INDEX_NAME = "rerank-test"; + private final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; + private final static String TEXT_REP_2 = "Fish like to eat plankton"; + private final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}"; + private String modelId; + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. We use new model for every test, and old model + * can be undeployed and deleted to free resources after each test case execution. + */ + deleteModel(modelId); + deleteSearchPipeline(PIPELINE_NAME); + deleteIndex(INDEX_NAME); + } + + @Before + @SneakyThrows + public void setup() { + modelId = uploadTextSimilarityModel(); + loadModel(modelId); + } + + @SneakyThrows + public void testCrossEncoderRerankProcessor() { + createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json"); + setupIndex(); + runQueries(); + } + + private String uploadTextSimilarityModel() throws Exception { + String requestBody = Files.readString( + Path.of(classLoader.getResource("processor/UploadTextSimilarityModelRequestBody.json").toURI()) + ); + return registerModelGroupAndUploadModel(requestBody); + } + + private void setupIndex() throws Exception { + createIndexWithConfiguration(INDEX_NAME, INDEX_CONFIG, PIPELINE_NAME); + Response response1 = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(String.format(LOCALE, "{\"text_representation\": \"%s\"}", TEXT_REP_1)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Response response2 = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(String.format(LOCALE, "{\"text_representation\": \"%s\"}", TEXT_REP_2)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(response1.getEntity()), + false + ); + assertEquals("created", map.get("result")); + map = XContentHelper.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(response2.getEntity()), false); + assertEquals("created", map.get("result")); + } + + private void runQueries() throws Exception { + Map response1 = search("What do fish eat?"); + @SuppressWarnings("unchecked") + List> hits = (List>) ((Map) response1.get("hits")).get("hits"); + @SuppressWarnings("unchecked") + Map hit0Source = (Map) hits.get(0).get("_source"); + assert ((String) hit0Source.get("text_representation")).equals(TEXT_REP_2); + @SuppressWarnings("unchecked") + Map hit1Source = (Map) hits.get(1).get("_source"); + assert ((String) hit1Source.get("text_representation")).equals(TEXT_REP_1); + + Map response2 = search("Who loves fish?"); + @SuppressWarnings("unchecked") + List> hits2 = (List>) ((Map) response2.get("hits")).get("hits"); + @SuppressWarnings("unchecked") + Map hit2Source = (Map) hits2.get(0).get("_source"); + assert ((String) hit2Source.get("text_representation")).equals(TEXT_REP_1); + @SuppressWarnings("unchecked") + Map hit3Source = (Map) hits2.get(1).get("_source"); + assert ((String) hit3Source.get("text_representation")).equals(TEXT_REP_2); + } + + private Map search(String queryText) throws Exception { + String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_context\": {\"query_text\":\"%s\"}}}}"; + String jsonQuery = String.format(LOCALE, jsonQueryFrame, queryText); + log.info(jsonQuery); + Request request = new Request("POST", "/" + INDEX_NAME + "/_search"); + request.addParameter("search_pipeline", PIPELINE_NAME); + request.setJsonEntity(jsonQuery); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + + return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java new file mode 100644 index 000000000..50d0cf2bc --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -0,0 +1,342 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; +import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.Processor.PipelineContext; +import org.opensearch.test.OpenSearchTestCase; + +public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase { + + @Mock + private SearchRequest request; + + private SearchResponse response; + + @Mock + private MLCommonsClientAccessor mlCommonsClientAccessor; + + @Mock + private PipelineContext pipelineContext; + + @Mock + private PipelineProcessingContext ppctx; + + private RerankProcessorFactory factory; + + private MLOpenSearchRerankProcessor processor; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + factory = new RerankProcessorFactory(mlCommonsClientAccessor); + Map config = new HashMap<>( + Map.of( + RerankType.ML_OPENSEARCH.getLabel(), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) + ) + ); + processor = (MLOpenSearchRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for reranking with a cross encoder", + false, + config, + pipelineContext + ); + } + + private void setupParams(Map params) { + SearchSourceBuilder ssb = new SearchSourceBuilder(); + NeuralQueryBuilder nqb = new NeuralQueryBuilder(); + nqb.fieldName("embedding").k(3).modelId("embedding_id").queryText("Question about dolphins"); + ssb.query(nqb); + List exts = List.of( + new RerankSearchExtBuilder(new HashMap<>(Map.of(QueryContextSourceFetcher.NAME, new HashMap<>(params)))) + ); + ssb.ext(exts); + doReturn(ssb).when(request).source(); + } + + private void setupSimilarityRescoring() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + List scores = List.of(1f, 2f, 3f); + listener.onResponse(scores); + return null; + }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + } + + private void setupSearchResults() throws IOException { + XContentBuilder sourceContent = JsonXContent.contentBuilder() + .startObject() + .field("text_representation", "source passage") + .endObject(); + SearchHit sourceHit = new SearchHit(0, "0", Map.of(), Map.of()); + sourceHit.sourceRef(BytesReference.bytes(sourceContent)); + sourceHit.score(1.5f); + + DocumentField field = new DocumentField("text_representation", List.of("field passage")); + SearchHit fieldHit = new SearchHit(1, "1", Map.of("text_representation", field), Map.of()); + fieldHit.score(1.7f); + + SearchHit nullHit = new SearchHit(2, "2", Map.of(), Map.of()); + nullHit.score(0f); + + SearchHit[] hitArray = new SearchHit[] { fieldHit, sourceHit, nullHit }; + TotalHits totalHits = new TotalHits(3, TotalHits.Relation.EQUAL_TO); + + SearchHits searchHits = new SearchHits(hitArray, totalHits, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new Clusters(1, 1, 0), null); + } + + public void testRerankContext_whenQueryText_thenSucceed() throws IOException { + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_FIELD, "query text")); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateRerankingContext(request, response, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().containsKey(QueryContextSourceFetcher.QUERY_TEXT_FIELD)); + assert (argCaptor.getValue().get(QueryContextSourceFetcher.QUERY_TEXT_FIELD).equals("query text")); + } + + public void testRerankContext_whenQueryTextPath_thenSucceed() throws IOException { + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "query.neural.embedding.query_text")); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateRerankingContext(request, response, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().containsKey(QueryContextSourceFetcher.QUERY_TEXT_FIELD)); + assert (argCaptor.getValue().get(QueryContextSourceFetcher.QUERY_TEXT_FIELD).equals("Question about dolphins")); + } + + public void testRerankContext_whenQueryTextAndPath_thenFail() throws IOException { + setupParams( + Map.of( + QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, + "query.neural.embedding.query_text", + QueryContextSourceFetcher.QUERY_TEXT_FIELD, + "query text" + ) + ); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateRerankingContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals( + "Cannot specify both \"" + + QueryContextSourceFetcher.QUERY_TEXT_FIELD + + "\" and \"" + + QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + + "\"" + )); + } + + public void testRerankContext_whenNoQueryInfo_thenFail() throws IOException { + setupParams(Map.of()); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateRerankingContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals( + "Must specify either \"" + + QueryContextSourceFetcher.QUERY_TEXT_FIELD + + "\" or \"" + + QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + + "\"" + )); + } + + public void testRerankContext_whenQueryTextPathIsBadPointer_thenFail() throws IOException { + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "query.neural.embedding")); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateRerankingContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + " must point to a string field")); + } + + public void testRescoreSearchResponse_HappyPath() throws IOException { + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + Map scoringContext = Map.of( + QueryContextSourceFetcher.QUERY_TEXT_FIELD, + "query text", + DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD, + new ArrayList<>(List.of("dummy", "dummy", "dummy")) + ); + processor.rescoreSearchResponse(response, scoringContext, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 3); + assert (argCaptor.getValue().get(0) == 1f); + assert (argCaptor.getValue().get(1) == 2f); + assert (argCaptor.getValue().get(2) == 3f); + } + + public void testRescoreSearchResponse_whenNoContextList_thenFail() throws IOException { + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + Map scoringContext = Map.of(QueryContextSourceFetcher.QUERY_TEXT_FIELD, "query text"); + processor.rescoreSearchResponse(response, scoringContext, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalStateException); + assert (argCaptor.getValue() + .getMessage() + .equals( + String.format( + Locale.ROOT, + "No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?", + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + DocumentContextSourceFetcher.NAME + ) + )); + } + + public void testRerank_HappyPath() throws IOException { + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map scoringContext = Map.of( + QueryContextSourceFetcher.QUERY_TEXT_FIELD, + "query text", + DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD, + new ArrayList<>(List.of("dummy", "dummy", "dummy")) + ); + processor.rerank(response, scoringContext, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse rsp = argCaptor.getValue(); + assert (rsp.getHits().getAt(0).docId() == 2); + assert (rsp.getHits().getAt(0).getScore() == 3f); + assert (rsp.getHits().getAt(1).docId() == 0); + assert (rsp.getHits().getAt(1).getScore() == 2f); + assert (rsp.getHits().getAt(2).docId() == 1); + assert (rsp.getHits().getAt(2).getScore() == 1f); + } + + public void testRerank_whenScoresAndHitsHaveDiffLengths_thenFail() throws IOException { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + List scores = List.of(1f, 2f); + listener.onResponse(scores); + return null; + }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map scoringContext = Map.of( + QueryContextSourceFetcher.QUERY_TEXT_FIELD, + "query text", + DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD, + new ArrayList<>(List.of("dummy", "dummy", "dummy")) + ); + processor.rerank(response, scoringContext, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("scores and hits are not the same length")); + } + + public void testBasics() throws IOException { + assert (processor.getTag().equals("rerank processor")); + assert (processor.getDescription().equals("processor for reranking with a cross encoder")); + assert (!processor.isIgnoreFailure()); + assertThrows( + "Use asyncProcessResponse unless you can guarantee to not deadlock yourself", + UnsupportedOperationException.class, + () -> processor.processResponse(request, response) + ); + } + + public void testProcessResponseAsync() throws IOException { + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_FIELD, "query text")); + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + processor.processResponseAsync(request, response, ppctx, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse rsp = argCaptor.getValue(); + assert (rsp.getHits().getAt(0).docId() == 2); + assert (rsp.getHits().getAt(0).getScore() == 3f); + assert (rsp.getHits().getAt(1).docId() == 0); + assert (rsp.getHits().getAt(1).getScore() == 2f); + assert (rsp.getHits().getAt(2).docId() == 1); + assert (rsp.getHits().getAt(2).getScore() == 1f); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java new file mode 100644 index 000000000..ea0af1eb5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query.ext; + +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.test.OpenSearchTestCase; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class RerankSearchExtBuilderTests extends OpenSearchTestCase { + + Map params; + + @Before + public void setup() { + params = Map.of("query_context", Map.of("query_text", "question about the meaning of life, the universe, and everything")); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry( + SearchExtBuilder.class, + new ParseField(RerankSearchExtBuilder.PARAM_FIELD_NAME), + parser -> RerankSearchExtBuilder.parse(parser) + ) + ) + ); + } + + public void testStreaming() throws IOException { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + b1.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + RerankSearchExtBuilder b2 = new RerankSearchExtBuilder(in); + assert (b2.getParams().equals(params)); + assert (b1.equals(b2)); + } + + public void testToXContent() throws IOException { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(new HashMap<>(params)); + XContentBuilder builder = XContentType.JSON.contentBuilder(); + builder.startObject(); + b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + builder.endObject(); + String extString = builder.toString(); + log.info(extString); + XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); + SearchExtBuilder b2 = parser.namedObject(SearchExtBuilder.class, RerankSearchExtBuilder.PARAM_FIELD_NAME, parser); + assert (b2 instanceof RerankSearchExtBuilder); + RerankSearchExtBuilder b3 = (RerankSearchExtBuilder) b2; + log.info(b1.getParams().toString()); + log.info(b3.getParams().toString()); + assert (b3.getParams().equals(params)); + assert (b1.equals(b3)); + } + + public void testPullFromListOfExtBuilders() { + RerankSearchExtBuilder builder = new RerankSearchExtBuilder(params); + SearchExtBuilder otherBuilder = mock(SearchExtBuilder.class); + assert (!builder.equals(otherBuilder)); + List builders1 = List.of(otherBuilder, builder); + List builders2 = List.of(otherBuilder); + List builders3 = List.of(); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders1).equals(builder)); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders2) == null); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders3) == null); + } + + public void testHash() { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + RerankSearchExtBuilder b2 = new RerankSearchExtBuilder(params); + RerankSearchExtBuilder b3 = new RerankSearchExtBuilder(Map.of()); + assert (b1.hashCode() == b2.hashCode()); + assert (b1.hashCode() != b3.hashCode()); + assert (!b1.equals(b3)); + } + + public void testWriteableName() { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + assert (b1.getWriteableName().equals(RerankSearchExtBuilder.PARAM_FIELD_NAME)); + } +} diff --git a/src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json b/src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json new file mode 100644 index 000000000..fc6cfd124 --- /dev/null +++ b/src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json @@ -0,0 +1,15 @@ +{ + "description": "Pipeline for reranking with a cross encoder", + "response_processors": [ + { + "rerank": { + "ml_opensearch": { + "model_id": "%s" + }, + "context": { + "document_fields": ["text_representation"] + } + } + } + ] +} diff --git a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json new file mode 100644 index 000000000..3c23f6f21 --- /dev/null +++ b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json @@ -0,0 +1,16 @@ +{ + "name": "ms-marco-TinyBERT-L-2-v2", + "version": "1.0.0", + "function_name": "TEXT_SIMILARITY", + "description": "test model", + "model_format": "TORCH_SCRIPT", + "model_group_id": "%s", + "model_content_hash_value": "90e39a926101d1a4e542aade0794319404689b12acfd5d7e65c03d91c668b5cf", + "model_config": { + "model_type": "bert", + "embedding_dimension": 1, + "framework_type": "huggingface_transformers", + "all_config": "nobody will read this" + }, + "url": "https://github.com/opensearch-project/ml-commons/blob/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-torch_script.zip?raw=true" +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 04b0fcb51..ffbbed2bc 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -331,6 +331,23 @@ protected void createSearchRequestProcessor(final String modelId, final String p assertEquals("true", node.get("acknowledged").toString()); } + protected void createSearchPipelineViaConfig(String modelId, String pipelineName, String configPath) throws Exception { + Response pipelineCreateResponse = makeRequest( + client(), + "PUT", + "/_search/pipeline/" + pipelineName, + null, + toHttpEntity(String.format(LOCALE, Files.readString(Path.of(classLoader.getResource(configPath).toURI())), modelId)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map node = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(pipelineCreateResponse.getEntity()), + false + ); + assertEquals("true", node.get("acknowledged").toString()); + } + /** * Get the number of documents in a particular index *