From 8632f9c3c314b7a9a6e39192527dbe1bba904184 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 18 Dec 2024 13:47:20 -0800 Subject: [PATCH] Integrate explainability for hybrid query into RRF processor Signed-off-by: Martin Gaievski --- .../AbstractScoreHybridizationProcessor.java | 65 +++++++ .../processor/NormalizationProcessor.java | 36 +--- .../neuralsearch/processor/RRFProcessor.java | 19 +- .../RRFScoreCombinationTechnique.java | 11 +- .../RRFNormalizationTechnique.java | 71 ++++++-- .../processor/RRFProcessorTests.java | 33 ++++ .../query/HybridQueryExplainIT.java | 162 ++++++++++++++++-- .../neuralsearch/BaseNeuralSearchIT.java | 24 ++- 8 files changed, 339 insertions(+), 82 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java new file mode 100644 index 000000000..456e8415a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import java.util.Optional; + +/** + * Base class for all score hybridization processors. This class is responsible for executing the score hybridization process. + * It is a pipeline processor that is executed after the query phase and before the fetch phase. + */ +public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor { + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor. This method is called when there is no pipeline context + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty()); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor. This method is called when there is pipeline context + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline + * @param + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext + ) { + hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult + * @param searchPhaseContext + * @param requestContextOptional + * @param + */ + abstract void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional + ); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index d2008ae97..2b1a28d01 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -19,9 +19,7 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; -import org.opensearch.search.internal.SearchContext; import org.opensearch.search.pipeline.PipelineProcessingContext; -import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QuerySearchResult; import lombok.AllArgsConstructor; @@ -33,7 +31,7 @@ */ @Log4j2 @AllArgsConstructor -public class NormalizationProcessor implements SearchPhaseResultsProcessor { +public class NormalizationProcessor extends AbstractScoreHybridizationProcessor { public static final String TYPE = "normalization-processor"; private final String tag; @@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { private final ScoreCombinationTechnique combinationTechnique; private final NormalizationProcessorWorkflow normalizationWorkflow; - /** - * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor. This method is called when there is no pipeline context - * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution - * @param searchPhaseContext {@link SearchContext} - */ @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext - ) { - prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty()); - } - - /** - * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor - * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution - * @param searchPhaseContext {@link SearchContext} - * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline - * @param - */ - @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext, - final PipelineProcessingContext requestContext - ) { - prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); - } - - private void prepareAndExecuteNormalizationWorkflow( + void hybridizeScores( SearchPhaseResults searchPhaseResult, SearchPhaseContext searchPhaseContext, Optional requestContextOptional diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java index ca67f2d1c..100cf9fc6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -12,11 +12,12 @@ import java.util.Objects; import java.util.Optional; +import com.google.common.annotations.VisibleForTesting; import lombok.Getter; -import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.action.search.QueryPhaseResultConsumer; @@ -25,7 +26,6 @@ import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; @@ -39,7 +39,7 @@ */ @Log4j2 @AllArgsConstructor -public class RRFProcessor implements SearchPhaseResultsProcessor { +public class RRFProcessor extends AbstractScoreHybridizationProcessor { public static final String TYPE = "score-ranker-processor"; @Getter @@ -57,9 +57,10 @@ public class RRFProcessor implements SearchPhaseResultsProcessor { * @param searchPhaseContext {@link SearchContext} */ @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext + void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional ) { if (shouldSkipProcessor(searchPhaseResult)) { log.debug("Query results are not compatible with RRF processor"); @@ -67,7 +68,8 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - + boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain()) + && searchPhaseContext.getRequest().source().explain(); // make data transfer object to pass in, execute will get object with 4 or 5 fields, depending // on coming from NormalizationProcessor or RRFProcessor NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() @@ -75,7 +77,8 @@ public void process( .fetchSearchResultOptional(fetchSearchResult) .normalizationTechnique(normalizationTechnique) .combinationTechnique(combinationTechnique) - .explain(false) + .explain(explain) + .pipelineProcessingContext(requestContextOptional.orElse(null)) .build(); normalizationWorkflow.execute(normalizationExecuteDTO); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java index befe14dda..1a58fa0f3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java @@ -6,15 +6,19 @@ import lombok.ToString; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import java.util.List; import java.util.Map; +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; + @Log4j2 /** * Abstracts combination of scores based on reciprocal rank fusion algorithm */ @ToString(onlyExplicitlyIncluded = true) -public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique { +public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "rrf"; @@ -29,4 +33,9 @@ public float combine(final float[] scores) { } return sumScores; } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, List.of()); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java index 16ef83d05..4cc773592 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -6,27 +6,34 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Locale; import java.util.Set; +import java.util.function.BiConsumer; +import java.util.stream.IntStream; import org.apache.commons.lang3.Range; import org.apache.commons.lang3.math.NumberUtils; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.ToString; import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; + +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization; /** * Abstracts calculation of rank scores for each document returned as part of * reciprocal rank fusion. Rank scores are summed across subqueries in combination classes. */ @ToString(onlyExplicitlyIncluded = true) -public class RRFNormalizationTechnique implements ScoreNormalizationTechnique { +public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "rrf"; public static final int DEFAULT_RANK_CONSTANT = 60; @@ -58,21 +65,49 @@ public RRFNormalizationTechnique(final Map params, final ScoreNo public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { - if (Objects.isNull(compoundQueryTopDocs)) { - continue; - } - List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); - for (TopDocs topDocs : topDocsPerSubQuery) { - int docsCountPerSubQuery = topDocs.scoreDocs.length; - ScoreDoc[] scoreDocs = topDocs.scoreDocs; - for (int j = 0; j < docsCountPerSubQuery; j++) { - // using big decimal approach to minimize error caused by floating point ops - // score = 1.f / (float) (rankConstant + j + 1)) - scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP) - .floatValue(); - } - } + processTopDocs(compoundQueryTopDocs, (docId, score) -> {}); + } + } + + @Override + public String describe() { + return String.format(Locale.ROOT, "%s, rank_constant %s", TECHNIQUE_NAME, rankConstant); + } + + @Override + public Map explain(List queryTopDocs) { + Map> normalizedScores = new HashMap<>(); + + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + processTopDocs( + compoundQueryTopDocs, + (docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score) + ); } + + return getDocIdAtQueryForNormalization(normalizedScores, this); + } + + private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer scoreProcessor) { + if (Objects.isNull(compoundQueryTopDocs)) { + return; + } + + compoundQueryTopDocs.getTopDocs().forEach(topDocs -> { + IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> { + float normalizedScore = calculateNormalizedScore(position); + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard( + topDocs.scoreDocs[position].doc, + compoundQueryTopDocs.getSearchShard() + ); + scoreProcessor.accept(docIdAtSearchShard, normalizedScore); + topDocs.scoreDocs[position].score = normalizedScore; + }); + }); + } + + private float calculateNormalizedScore(int position) { + return BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + position + 1), 10, RoundingMode.HALF_UP).floatValue(); } private int getRankConstant(final Map params) { @@ -96,7 +131,7 @@ private void validateRankConstant(final int rankConstant) { } } - public static int getParamAsInteger(final Map parameters, final String fieldName) { + private static int getParamAsInteger(final Map parameters, final String fieldName) { try { return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName))); } catch (NumberFormatException e) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java index b7764128f..753c0b8fe 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java @@ -9,6 +9,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.OriginalIndices; @@ -111,6 +112,10 @@ public void testProcess_whenValidHybridInput_thenSucceed() { when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder()); + when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest); + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); verify(mockNormalizationWorkflow).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); @@ -177,6 +182,34 @@ public void testGetFetchSearchResults() { assertFalse(result.isPresent()); } + @SneakyThrows + public void testProcess_whenExplainIsTrue_thenExplanationIsAdded() { + QuerySearchResult result = createQuerySearchResult(true); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.explain(true); + searchRequest.source(sourceBuilder); + when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + // Capture the actual request + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass( + NormalizationProcessorWorkflowExecuteRequest.class + ); + verify(mockNormalizationWorkflow).execute(requestCaptor.capture()); + + // Verify the captured request + NormalizationProcessorWorkflowExecuteRequest capturedRequest = requestCaptor.getValue(); + assertTrue(capturedRequest.isExplain()); + assertNull(capturedRequest.getPipelineProcessingContext()); + } + private QuerySearchResult createQuerySearchResult(boolean isHybrid) { ShardId shardId = new ShardId("index", "uuid", 0); OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed()); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java index b7e4f753a..35ad2aac5 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -58,7 +58,8 @@ public class HybridQueryExplainIT extends BaseNeuralSearchIT { private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); - private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + private static final String NORMALIZATION_SEARCH_PIPELINE = "normalization-search-pipeline"; + private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[768]; @@ -78,7 +79,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { try { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -95,7 +96,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { hybridQueryBuilderNeuralThenTerm, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // search hits @@ -187,7 +188,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); assertEquals(3, getListOfValues(explanationsHit3Details, "details").size()); } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -196,7 +197,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() try { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); createSearchPipeline( - SEARCH_PIPELINE, + NORMALIZATION_SEARCH_PIPELINE, NORMALIZATION_TECHNIQUE_L2, DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), @@ -217,7 +218,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() hybridQueryBuilder, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // basic sanity check for search hits @@ -322,7 +323,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertEquals(0, getListOfValues(explanationsHit4, "details").size()); assertTrue((double) explanationsHit4.get("value") > 0.0f); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -331,7 +332,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe try { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with normalization processor, no explanation response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -348,7 +349,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe hybridQueryBuilderNeuralThenTerm, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // search hits @@ -463,7 +464,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe assertEquals("boost", explanationsHit3Details.get("description")); assertEquals(0, getListOfValues(explanationsHit3Details, "details").size()); } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -472,7 +473,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { try { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -483,7 +484,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { hybridQueryBuilder, null, MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); List> hitsNestedList = getNestedHits(searchResponseAsMap); @@ -521,7 +522,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { } assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1))); } finally { - wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -530,7 +531,7 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { try { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(QueryBuilders.multiMatchQuery(TEST_QUERY_TEXT3, TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2)); @@ -543,7 +544,7 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { hybridQueryBuilder, null, MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); List> hitsNestedList = getNestedHits(searchResponseAsMap); @@ -581,7 +582,136 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { } assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1))); } finally { - wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplain_whenRRFProcessor_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createRRFSearchPipeline(RRF_SEARCH_PIPELINE, true); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .vector(createRandomVector(TEST_DIMENSION)) + .k(10) + .build(); + hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); + hybridQueryBuilder.add(knnQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", RRF_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // basic sanity check for search hits + assertEquals(4, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + float actualMaxScore = getMaxScore(searchResponseAsMap).get(); + assertTrue(actualMaxScore > 0); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain, hit 1 + List> hitsNestedList = getNestedHits(searchResponseAsMap); + Map searchHit1 = hitsNestedList.get(0); + Map explanationForHit1 = getValueByKey(searchHit1, "_explanation"); + assertNotNull(explanationForHit1); + assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "rrf combination of:"; + assertEquals(expectedTopLevelDescription, explanationForHit1.get("description")); + List> hit1Details = getListOfValues(explanationForHit1, "details"); + assertEquals(2, hit1Details.size()); + // two sub-queries meaning we do have two detail objects with separate query level details + Map hit1DetailsForHit1 = hit1Details.get(0); + assertTrue((double) hit1DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant 60 normalization of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit1.get("description")); + assertTrue((double) explanationsHit1.get("value") > 0.5f); + assertEquals(0, ((List) explanationsHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertTrue((double) hit1DetailsForHit2.get("value") > 0.0f); + assertEquals("rrf, rank_constant 60 normalization of:", hit1DetailsForHit2.get("description")); + assertEquals(1, ((List) hit1DetailsForHit2.get("details")).size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals("within top 10", explanationsHit2.get("description")); + assertTrue((double) explanationsHit2.get("value") > 0.0f); + assertEquals(0, ((List) explanationsHit2.get("details")).size()); + + // hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map explanationForHit2 = getValueByKey(searchHit2, "_explanation"); + assertNotNull(explanationForHit2); + assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit2.get("description")); + List> hit2Details = getListOfValues(explanationForHit2, "details"); + assertEquals(2, hit2Details.size()); + + Map hit2DetailsForHit1 = hit2Details.get(0); + assertTrue((double) hit2DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant 60 normalization of:", hit2DetailsForHit1.get("description")); + assertEquals(1, ((List) hit2DetailsForHit1.get("details")).size()); + + Map hit2DetailsForHit2 = hit2Details.get(1); + assertTrue((double) hit2DetailsForHit2.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant 60 normalization of:", hit2DetailsForHit2.get("description")); + assertEquals(1, ((List) hit2DetailsForHit2.get("details")).size()); + + // hit 3 + Map searchHit3 = hitsNestedList.get(2); + Map explanationForHit3 = getValueByKey(searchHit3, "_explanation"); + assertNotNull(explanationForHit3); + assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit3.get("description")); + List> hit3Details = getListOfValues(explanationForHit3, "details"); + assertEquals(1, hit3Details.size()); + + Map hit3DetailsForHit1 = hit3Details.get(0); + assertTrue((double) hit3DetailsForHit1.get("value") > .0f); + assertEquals("rrf, rank_constant 60 normalization of:", hit3DetailsForHit1.get("description")); + assertEquals(1, ((List) hit3DetailsForHit1.get("details")).size()); + + Map explanationsHit3 = getListOfValues(hit3DetailsForHit1, "details").get(0); + assertEquals("within top 10", explanationsHit3.get("description")); + assertEquals(0, getListOfValues(explanationsHit3, "details").size()); + assertTrue((double) explanationsHit3.get("value") > 0.0f); + + // hit 4 + Map searchHit4 = hitsNestedList.get(3); + Map explanationForHit4 = getValueByKey(searchHit4, "_explanation"); + assertNotNull(explanationForHit4); + assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit4.get("description")); + List> hit4Details = getListOfValues(explanationForHit4, "details"); + assertEquals(1, hit4Details.size()); + + Map hit4DetailsForHit1 = hit4Details.get(0); + assertTrue((double) hit4DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant 60 normalization of:", hit4DetailsForHit1.get("description")); + assertEquals(1, ((List) hit4DetailsForHit1.get("details")).size()); + + Map explanationsHit4 = getListOfValues(hit4DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit4.get("description")); + assertEquals(0, getListOfValues(explanationsHit4, "details").size()); + assertTrue((double) explanationsHit4.get("value") > 0.0f); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, RRF_SEARCH_PIPELINE); } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 5107296c3..46e835cd0 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -1472,7 +1472,12 @@ protected enum ProcessorType { @SneakyThrows protected void createDefaultRRFSearchPipeline() { - String requestBody = XContentFactory.jsonBuilder() + createRRFSearchPipeline(RRF_SEARCH_PIPELINE, false); + } + + @SneakyThrows + protected void createRRFSearchPipeline(final String pipelineName, boolean addExplainResponseProcessor) { + XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .field("description", "Post processor for hybrid search") .startArray("phase_results_processors") @@ -1483,14 +1488,23 @@ protected void createDefaultRRFSearchPipeline() { .endObject() .endObject() .endObject() - .endArray() - .endObject() - .toString(); + .endArray(); + + if (addExplainResponseProcessor) { + builder.startArray("response_processors") + .startObject() + .startObject("explanation_response_processor") + .endObject() + .endObject() + .endArray(); + } + + String requestBody = builder.endObject().toString(); makeRequest( client(), "PUT", - String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE), + String.format(LOCALE, "/_search/pipeline/%s", pipelineName), null, toHttpEntity(String.format(LOCALE, requestBody)), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))