From a7777027fc549ee31d5b6f48a23613d5291190e9 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 9 Aug 2024 09:34:58 -0700 Subject: [PATCH 01/11] Working draft, publish basic normalization results Signed-off-by: Martin Gaievski --- .../neuralsearch/plugin/NeuralSearch.java | 7 +- .../processor/CompoundTopDocs.java | 29 ++++-- .../processor/DocIdAtQueryPhase.java | 9 ++ .../processor/NormalizationProcessor.java | 30 +++++- .../NormalizationProcessorWorkflow.java | 58 ++++++++++-- ...zationProcessorWorkflowExecuteRequest.java | 29 ++++++ .../processor/ProcessorExplainDto.java | 21 +++++ .../processor/ProcessorExplainPublisher.java | 94 +++++++++++++++++++ .../neuralsearch/processor/SearchShard.java | 8 ++ ...ithmeticMeanScoreCombinationTechnique.java | 6 ++ .../ScoreCombinationTechnique.java | 4 + .../processor/combination/ScoreCombiner.java | 49 +++++++++- .../ProcessorExplainPublisherFactory.java | 26 +++++ .../L2ScoreNormalizationTechnique.java | 48 ++++++++++ .../MinMaxScoreNormalizationTechnique.java | 63 +++++++++++++ .../ScoreNormalizationTechnique.java | 8 ++ .../normalization/ScoreNormalizer.java | 12 +++ .../neuralsearch/query/HybridQueryWeight.java | 22 ++++- .../processor/CompoundTopDocsTests.java | 18 +++- .../NormalizationProcessorTests.java | 6 +- .../ScoreCombinationTechniqueTests.java | 11 ++- .../ScoreNormalizationTechniqueTests.java | 20 ++-- .../L2ScoreNormalizationTechniqueTests.java | 26 +++-- ...inMaxScoreNormalizationTechniqueTests.java | 26 +++-- .../query/HybridQueryWeightTests.java | 6 +- 25 files changed, 582 insertions(+), 54 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainDto.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 8b173ba81..5391faa14 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -32,12 +32,14 @@ import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.factory.ProcessorExplainPublisherFactory; import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; 
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; @@ -80,6 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + public static final String PROCESSOR_EXPLAIN = "processor_explain"; @Override public Collection createComponents( @@ -181,7 +184,9 @@ public Map topDocs; @Setter private List scoreDocs; + @Getter + private SearchShard searchShard; - public CompoundTopDocs(final TotalHits totalHits, final List topDocs, final boolean isSortEnabled) { - initialize(totalHits, topDocs, isSortEnabled); + public CompoundTopDocs( + final TotalHits totalHits, + final List topDocs, + final boolean isSortEnabled, + final SearchShard searchShard + ) { + initialize(totalHits, topDocs, isSortEnabled, searchShard); } - private void initialize(TotalHits totalHits, List topDocs, boolean isSortEnabled) { + private void initialize(TotalHits totalHits, List topDocs, boolean isSortEnabled, SearchShard searchShard) { this.totalHits = totalHits; this.topDocs = topDocs; scoreDocs = cloneLargestScoreDocs(topDocs, isSortEnabled); + this.searchShard = searchShard; } /** @@ -72,14 +82,21 @@ private void initialize(TotalHits totalHits, List topDocs, boolean isSo * 6, 0.15 * 0, 9549511920.4881596047 */ - public CompoundTopDocs(final TopDocs topDocs) { + public CompoundTopDocs(final QuerySearchResult querySearchResult) { + final TopDocs topDocs = querySearchResult.topDocs().topDocs; + final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget(); boolean isSortEnabled = false; if (topDocs instanceof TopFieldDocs) { isSortEnabled = true; } ScoreDoc[] scoreDocs = topDocs.scoreDocs; if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) { - initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled); + SearchShard searchShard = new SearchShard( + searchShardTarget.getIndex(), + searchShardTarget.getShardId().id(), + searchShardTarget.getNodeId() + ); + initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled, searchShard); return; } // skipping first two elements, it's a start-stop element and delimiter for first series @@ -103,7 +120,7 @@ public CompoundTopDocs(final TopDocs topDocs) { scoreDocList.add(scoreDoc); } } - initialize(topDocs.totalHits, topDocsList, isSortEnabled); + initialize(topDocs.totalHits, topDocsList, isSortEnabled, searchShard); } private List cloneLargestScoreDocs(final List docs, boolean isSortEnabled) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java b/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java new file mode 100644 index 000000000..c81dd5f96 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java @@ -0,0 +1,9 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +//public record DocIdAtQueryPhase(Integer docId, SearchShardTarget searchShardTarget) { +public record DocIdAtQueryPhase(int docId, SearchShard searchShard) { +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 0563c92a0..6c6327b3b 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -20,6 +20,7 @@ import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QuerySearchResult; @@ -51,6 +52,23 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { public void process( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext + ) { + doProcessStuff(searchPhaseResult, searchPhaseContext, Optional.empty()); + } + + @Override + public void process( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + PipelineProcessingContext requestContext + ) { + doProcessStuff(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); + } + + private void doProcessStuff( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional ) { if (shouldSkipProcessor(searchPhaseResult)) { log.debug("Query results are not compatible with normalization processor"); @@ -58,7 +76,17 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique); + boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain()) + && searchPhaseContext.getRequest().source().explain(); + NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .explain(explain) + .pipelineProcessingContext(requestContextOptional.orElse(null)) + .build(); + normalizationWorkflow.execute(request); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index c64f1c1f4..4e90beb43 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -13,6 +13,7 @@ import java.util.Optional; import java.util.stream.Collectors; +import org.apache.lucene.search.Explanation; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Sort; @@ -27,10 +28,13 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.query.QuerySearchResult; import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; + +import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN; import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND; import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria; @@ 
-57,22 +61,35 @@ public void execute( final ScoreNormalizationTechnique normalizationTechnique, final ScoreCombinationTechnique combinationTechnique ) { + NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResultOptional) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .explain(false) + .build(); + execute(request); + } + + public void execute(final NormalizationProcessorWorkflowExecuteRequest request) { // save original state - List unprocessedDocIds = unprocessedDocIds(querySearchResults); + List unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults()); // pre-process data log.debug("Pre-process query results"); - List queryTopDocs = getQueryTopDocs(querySearchResults); + List queryTopDocs = getQueryTopDocs(request.getQuerySearchResults()); + + explain(request, queryTopDocs); // normalize log.debug("Do score normalization"); - scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); + scoreNormalizer.normalizeScores(queryTopDocs, request.getNormalizationTechnique()); CombineScoresDto combineScoresDTO = CombineScoresDto.builder() .queryTopDocs(queryTopDocs) - .scoreCombinationTechnique(combinationTechnique) - .querySearchResults(querySearchResults) - .sort(evaluateSortCriteria(querySearchResults, queryTopDocs)) + .scoreCombinationTechnique(request.getCombinationTechnique()) + .querySearchResults(request.getQuerySearchResults()) + .sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs)) .build(); // combine @@ -82,7 +99,33 @@ public void execute( // post-process data log.debug("Post-process query results after score normalization and combination"); updateOriginalQueryResults(combineScoresDTO); - updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds); + updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds); + } + + private void explain(NormalizationProcessorWorkflowExecuteRequest request, List queryTopDocs) { + if (request.isExplain()) { + // general description of techniques + String explanationDetailsMessage = String.format( + Locale.ROOT, + "%s, %s", + request.getNormalizationTechnique().describe(), + request.getCombinationTechnique().describe() + ); + + Explanation explanation = Explanation.match(0.0f, explanationDetailsMessage); + + // build final result object with all explain related information + if (Objects.nonNull(request.getPipelineProcessingContext())) { + ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder() + .explanation(explanation) + .normalizedScoresByDocId(scoreNormalizer.explain(queryTopDocs, request.getNormalizationTechnique())) + .combinedScoresByDocId(scoreCombiner.explain(queryTopDocs, request.getCombinationTechnique())) + .build(); + // store explain object to pipeline context + PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); + pipelineProcessingContext.setAttribute(PROCESSOR_EXPLAIN, processorExplainDto); + } + } } /** @@ -93,7 +136,6 @@ public void execute( private List getQueryTopDocs(final List querySearchResults) { List queryTopDocs = querySearchResults.stream() .filter(searchResult -> Objects.nonNull(searchResult.topDocs())) - .map(querySearchResult -> querySearchResult.topDocs().topDocs) .map(CompoundTopDocs::new) .collect(Collectors.toList()); if 
(queryTopDocs.size() != querySearchResults.size()) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java new file mode 100644 index 000000000..8056bd100 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.query.QuerySearchResult; + +import java.util.List; +import java.util.Optional; + +@Builder +@AllArgsConstructor +@Getter +public class NormalizationProcessorWorkflowExecuteRequest { + final List querySearchResults; + final Optional fetchSearchResultOptional; + final ScoreNormalizationTechnique normalizationTechnique; + final ScoreCombinationTechnique combinationTechnique; + boolean explain; + final PipelineProcessingContext pipelineProcessingContext; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainDto.java b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainDto.java new file mode 100644 index 000000000..1255779f6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainDto.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import org.apache.lucene.search.Explanation; + +import java.util.Map; + +@AllArgsConstructor +@Builder +@Getter +public class ProcessorExplainDto { + Explanation explanation; + Map normalizedScoresByDocId; + Map combinedScoresByDocId; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java new file mode 100644 index 000000000..b6381f7ce --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.search.Explanation; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +import java.util.Objects; + +import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN; + +@Getter +@AllArgsConstructor +public class ProcessorExplainPublisher implements SearchResponseProcessor { + + public static final String TYPE = "processor_explain_publisher"; + + private final String description; + private final 
String tag; + private final boolean ignoreFailure; + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + return processResponse(request, response, null); + } + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) { + if (Objects.nonNull(requestContext) && Objects.nonNull(requestContext.getAttribute(PROCESSOR_EXPLAIN))) { + ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(PROCESSOR_EXPLAIN); + Explanation explanation = processorExplainDto.getExplanation(); + SearchHits searchHits = response.getHits(); + SearchHit[] searchHitsArray = searchHits.getHits(); + for (SearchHit searchHit : searchHitsArray) { + SearchShardTarget searchShardTarget = searchHit.getShard(); + DocIdAtQueryPhase docIdAtQueryPhase = new DocIdAtQueryPhase( + searchHit.docId(), + new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()) + ); + Explanation normalizedExplanation = Explanation.match( + 0.0f, + processorExplainDto.getNormalizedScoresByDocId().get(docIdAtQueryPhase) + ); + Explanation combinedExplanation = Explanation.match( + 0.0f, + processorExplainDto.getCombinedScoresByDocId().get(docIdAtQueryPhase) + ); + Explanation finalExplanation = Explanation.match( + searchHit.getScore(), + "combined explanation from processor and query: ", + explanation, + normalizedExplanation, + combinedExplanation, + searchHit.getExplanation() + ); + searchHit.explanation(finalExplanation); + } + // delete processor explain data to avoid double processing + // requestContext.setAttribute(PROCESSOR_EXPLAIN, null); + } + + return response; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return ignoreFailure; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java new file mode 100644 index 000000000..57c893c90 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java @@ -0,0 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +public record SearchShard(String index, int shardId, String nodeId) { +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 001f1670d..de03688c2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; @@ -29,6 +30,11 @@ public ArithmeticMeanScoreCombinationTechnique(final Map params, weights = scoreCombinationUtil.getWeights(params); } + @Override + public String describe() { + return String.format(Locale.ROOT, "combination technique %s [%s]", TECHNIQUE_NAME, "score = (score1 + score2 + ... 
+ scoreN)/N"); + } + /** * Arithmetic mean method for combining scores. * score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java index dbeabe94b..c04b24b51 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java @@ -12,4 +12,8 @@ public interface ScoreCombinationTechnique { * @return combined score */ float combine(final float[] scores); + + default String describe() { + return "generic score combination technique"; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index a4e39f448..30a32b63a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -26,6 +27,7 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; /** * Abstracts combination of scores in query search results. @@ -60,7 +62,7 @@ public class ScoreCombiner { * Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score", * other steps are same for all techniques. * - * @param combineScoresDTO contains details of query top docs, score combination technique and sort is enabled or disabled. + * @param combineScoresDTO contains details of query top docs, score combination technique and sort is enabled or disabled. */ public void combineScores(final CombineScoresDto combineScoresDTO) { // iterate over results from each shard. Every CompoundTopDocs object has results from @@ -107,6 +109,7 @@ private void combineShardScores( updateQueryTopDocsWithCombinedScores( compoundQueryTopDocs, topDocsPerSubQuery, + normalizedScoresPerDoc, combinedNormalizedScoresByDocId, sortedDocsIds, getDocIdSortFieldsMap(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sort), @@ -129,7 +132,7 @@ private boolean isSortOrderByScore(Sort sort) { } /** - * @param sort sort criteria + * @param sort sort criteria * @param topDocsPerSubQuery top docs per subquery * @return list of top field docs which is deduced by typcasting top docs to top field docs. */ @@ -149,9 +152,9 @@ private List getTopFieldDocs(final Sort sort, final List } /** - * @param compoundTopDocs top docs that represent on shard + * @param compoundTopDocs top docs that represent on shard * @param combinedNormalizedScoresByDocId docId to normalized scores map - * @param sort sort criteria + * @param sort sort criteria * @return map of docId and sort fields if sorting is enabled. 
*/ private Map getDocIdSortFieldsMap( @@ -290,6 +293,7 @@ private Map combineScoresAndGetCombinedNormalizedScoresPerDocume private void updateQueryTopDocsWithCombinedScores( final CompoundTopDocs compoundQueryTopDocs, final List topDocsPerSubQuery, + Map normalizedScoresPerDoc, final Map combinedNormalizedScoresByDocId, final Collection sortedScores, Map docIdSortFieldMap, @@ -318,4 +322,41 @@ private TotalHits getTotalHits(final List topDocsPerSubQuery, final lon } return new TotalHits(maxHits, totalHits); } + + public Map explain( + final List queryTopDocs, + ScoreCombinationTechnique combinationTechnique + ) { + Map explain = new HashMap<>(); + queryTopDocs.forEach(compoundQueryTopDocs -> explainByShard(combinationTechnique, compoundQueryTopDocs, explain)); + return explain; + } + + private void explainByShard( + final ScoreCombinationTechnique scoreCombinationTechnique, + final CompoundTopDocs compoundQueryTopDocs, + Map explain + ) { + if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) { + return; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + + // - create map of normalized scores results returned from the single shard + Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery); + + // - create map of combined scores per doc id + Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); + + normalizedScoresPerDoc.entrySet().stream().forEach(entry -> { + float[] srcScores = entry.getValue(); + float combinedScore = combinedNormalizedScoresByDocId.get(entry.getKey()); + explain.put( + new DocIdAtQueryPhase(entry.getKey(), compoundQueryTopDocs.getSearchShard()), + "source scores " + Arrays.toString(srcScores) + " combined score " + combinedScore + ); + }); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java new file mode 100644 index 000000000..2633a89ad --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +import java.util.Map; + +public class ProcessorExplainPublisherFactory implements Processor.Factory { + + @Override + public SearchResponseProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + Processor.PipelineContext pipelineContext + ) throws Exception { + return new ProcessorExplainPublisher(description, tag, ignoreFailure); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 2bb6bbed7..617dada10 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -5,14 +5,19 @@ package 
org.opensearch.neuralsearch.processor.normalization; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.ToString; +import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; /** * Abstracts normalization of scores based on L2 method @@ -50,6 +55,49 @@ public void normalize(final List queryTopDocs) { } } + @Override + public String describe() { + return String.format( + Locale.ROOT, + "normalization technique %s [%s]", + TECHNIQUE_NAME, + "score = score/sqrt(score1^2 + score2^2 + ... + scoreN^2)" + ); + } + + @Override + public Map explain(List queryTopDocs) { + Map> normalizedScores = new HashMap<>(); + Map> sourceScores = new HashMap<>(); + List normsPerSubquery = getL2Norm(queryTopDocs); + + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { + normalizedScores.computeIfAbsent( + new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()), + k -> new ArrayList<>() + ).add(normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j))); + sourceScores.computeIfAbsent( + new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()), + k -> new ArrayList<>() + ).add(scoreDoc.score); + } + } + } + Map explain = sourceScores.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> { + List srcScores = entry.getValue(); + List normScores = normalizedScores.get(entry.getKey()); + return "source scores " + srcScores + " normalized scores " + normScores; + })); + return explain; + } + private List getL2Norm(final List queryTopDocs) { // find any non-empty compound top docs, it's either empty if shard does not have any results for all of sub-queries, // or it has results for all the sub-queries. 
In edge case of shard having results only for one sub-query, there will be TopDocs for diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 4fdf3c0a6..ba7b8b20a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -4,9 +4,14 @@ */ package org.opensearch.neuralsearch.processor.normalization; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -15,6 +20,7 @@ import com.google.common.primitives.Floats; import lombok.ToString; +import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; /** * Abstracts normalization of scores based on min-max method @@ -63,6 +69,63 @@ public void normalize(final List queryTopDocs) { } } + @Override + public String describe() { + return String.format( + Locale.ROOT, + "normalization technique %s [%s]", + TECHNIQUE_NAME, + "score = (score - min_score)/(max_score - min_score)" + ); + } + + @Override + public Map explain(List queryTopDocs) { + Map> normalizedScores = new HashMap<>(); + Map> sourceScores = new HashMap<>(); + + int numOfSubqueries = queryTopDocs.stream() + .filter(Objects::nonNull) + .filter(topDocs -> !topDocs.getTopDocs().isEmpty()) + .findAny() + .get() + .getTopDocs() + .size(); + // get min scores for each sub query + float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); + + // get max scores for each sub query + float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); + + // do normalization using actual score and min and max scores for corresponding sub query + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { + normalizedScores.computeIfAbsent( + new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()), + k -> new ArrayList<>() + ).add(normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j])); + sourceScores.computeIfAbsent( + new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()), + k -> new ArrayList<>() + ).add(scoreDoc.score); + } + } + } + + Map explain = sourceScores.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> { + List srcScores = entry.getValue(); + List normScores = normalizedScores.get(entry.getKey()); + return "source scores " + srcScores + " normalized scores " + normScores; + })); + return explain; + } + private float[] getMaxScores(final List queryTopDocs, final int numOfSubqueries) { float[] maxScores = new float[numOfSubqueries]; Arrays.fill(maxScores, Float.MIN_VALUE); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java index 
0b784c678..642bf284b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java @@ -5,8 +5,10 @@ package org.opensearch.neuralsearch.processor.normalization; import java.util.List; +import java.util.Map; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; /** * Abstracts normalization of scores in query search results. @@ -18,4 +20,10 @@ public interface ScoreNormalizationTechnique { * @param queryTopDocs original query results from multiple shards and multiple sub-queries */ void normalize(final List queryTopDocs); + + default String describe() { + return "score normalization technique"; + } + + Map explain(final List queryTopDocs); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index 263115f8f..fbfba2b66 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -5,9 +5,11 @@ package org.opensearch.neuralsearch.processor.normalization; import java.util.List; +import java.util.Map; import java.util.Objects; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; public class ScoreNormalizer { @@ -25,4 +27,14 @@ public void normalizeScores(final List queryTopDocs, final Scor private boolean canQueryResultsBeNormalized(final List queryTopDocs) { return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getTopDocs().size() > 0); } + + public Map explain( + final List queryTopDocs, + final ScoreNormalizationTechnique scoreNormalizationTechnique + ) { + if (canQueryResultsBeNormalized(queryTopDocs)) { + return scoreNormalizationTechnique.explain(queryTopDocs); + } + return Map.of(); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index dc1f5e112..08393922a 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -149,7 +149,27 @@ public boolean isCacheable(LeafReaderContext ctx) { */ @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { - throw new UnsupportedOperationException("Explain is not supported"); + boolean match = false; + double max = 0; + List subsOnNoMatch = new ArrayList<>(); + List subsOnMatch = new ArrayList<>(); + for (Weight wt : weights) { + Explanation e = wt.explain(context, doc); + if (e.isMatch()) { + match = true; + double score = e.getValue().doubleValue(); + subsOnMatch.add(e); + max = Math.max(max, score); + } else if (!match) { + subsOnNoMatch.add(e); + } + } + if (match) { + final String desc = "combination of:"; + return Explanation.match(max, desc, subsOnMatch); + } else { + return Explanation.noMatch("no matching clause", subsOnNoMatch); + } } @RequiredArgsConstructor diff --git a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java index 3b2f64063..eabc69894 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java @@ -14,6 +14,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; public class CompoundTopDocsTests extends OpenSearchQueryTestCase { + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() { TopDocs topDocs1 = new TopDocs( @@ -28,7 +29,7 @@ public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() { new ScoreDoc(5, RandomUtils.nextFloat()) } ); List topDocs = List.of(topDocs1, topDocs2); - CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs, false); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs, false, SEARCH_SHARD); assertNotNull(compoundTopDocs); assertEquals(topDocs, compoundTopDocs.getTopDocs()); } @@ -45,7 +46,8 @@ public void testBasics_whenCreateWithoutTopDocs_thenTopDocsIsNull() { new ScoreDoc(5, RandomUtils.nextFloat()) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(hybridQueryScoreTopDocs); assertNotNull(hybridQueryScoreTopDocs.getScoreDocs()); @@ -59,21 +61,27 @@ public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWit new ScoreDoc[] { new ScoreDoc(2, RandomUtils.nextFloat()), new ScoreDoc(4, RandomUtils.nextFloat()) } ); List topDocs = List.of(topDocs1, topDocs2); - CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs, false); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs, false, SEARCH_SHARD); assertNotNull(compoundTopDocs); assertNotNull(compoundTopDocs.getScoreDocs()); assertEquals(2, compoundTopDocs.getScoreDocs().size()); } public void testBasics_whenMultipleTopDocsIsNull_thenScoreDocsIsNull() { - CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), (List) null, false); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + (List) null, + false, + SEARCH_SHARD + ); assertNotNull(compoundTopDocs); assertNull(compoundTopDocs.getScoreDocs()); CompoundTopDocs compoundTopDocsWithNullArray = new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), Arrays.asList(null, null), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocsWithNullArray); assertNotNull(compoundTopDocsWithNullArray.getScoreDocs()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index e93c9b9ec..5f45b14fe 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -179,6 +179,7 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio } SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -247,6 +248,7 @@ public void testScoreCorrectness_whenCompoundDocs_thenDoNormalizationCombination SearchPhaseContext 
searchPhaseContext = mock(SearchPhaseContext.class); when(searchPhaseContext.getNumShards()).thenReturn(1); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -408,6 +410,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -417,7 +420,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz .collect(Collectors.toList()); TestUtils.assertQueryResultScores(querySearchResults); - verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() { @@ -495,6 +498,7 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); IllegalStateException exception = expectThrows( IllegalStateException.class, () -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index 918f3f45b..6ff6d9174 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -19,6 +19,8 @@ public class ScoreCombinationTechniqueTests extends OpenSearchTestCase { + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreCombiner scoreCombiner = new ScoreCombiner(); scoreCombiner.combineScores( @@ -46,7 +48,8 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(5, 0.001f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -57,7 +60,8 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), @@ -65,7 +69,8 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) ), - false + false, + SEARCH_SHARD ) ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java index 
67abd552f..b2b0007f6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -18,6 +18,8 @@ public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); scoreNormalizationMethod.normalizeScores(List.of(), ScoreNormalizationFactory.DEFAULT_METHOD); @@ -30,7 +32,8 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco new CompoundTopDocs( new TotalHits(1, TotalHits.Relation.EQUAL_TO), List.of(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) })), - false + false, + SEARCH_SHARD ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -61,7 +64,8 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } ) ), - false + false, + SEARCH_SHARD ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -98,7 +102,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } ) ), - false + false, + SEARCH_SHARD ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -147,7 +152,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -158,7 +164,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new ScoreDoc[] { new ScoreDoc(2, 2.2f), new ScoreDoc(4, 1.8f), new ScoreDoc(7, 0.9f), new ScoreDoc(9, 0.01f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), @@ -166,7 +173,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) ), - false + false, + SEARCH_SHARD ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index ba4bfee0d..734f9bb57 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.SearchShard; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; /** @@ -18,6 +19,7 @@ */ public class L2ScoreNormalizationTechniqueTests extends 
OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.0001f; + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { L2ScoreNormalizationTechnique normalizationTechnique = new L2ScoreNormalizationTechnique(); @@ -31,7 +33,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -46,7 +49,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc(4, l2Norm(scores[1], Arrays.asList(scores))) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -78,7 +82,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc(2, scoresQuery2[2]) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -101,7 +106,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc(2, l2Norm(scoresQuery2[2], Arrays.asList(scoresQuery2))) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -133,7 +139,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(2, scoresShard1and2Query3[2]) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -152,7 +159,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(15, scoresShard1and2Query3[6]) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -175,7 +183,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(2, l2Norm(scoresShard1and2Query3[2], Arrays.asList(scoresShard1and2Query3))) } ) ), - false + false, + SEARCH_SHARD ); CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( @@ -197,7 +206,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(15, l2Norm(scoresShard1and2Query3[6], Arrays.asList(scoresShard1and2Query3))) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index d0445f0ca..c7692b407 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.SearchShard; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; /** @@ -17,6 +18,7 @@ */ public class MinMaxScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.0001f; + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); public void 
testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { MinMaxScoreNormalizationTechnique normalizationTechnique = new MinMaxScoreNormalizationTechnique(); @@ -29,7 +31,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -42,7 +45,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, 0.001f) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -69,7 +73,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -87,7 +92,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(4, 0.75f), new ScoreDoc(2, 0.001f) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -113,7 +119,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -124,7 +131,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -142,7 +150,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(4, 0.75f), new ScoreDoc(2, 0.001f) } ) ), - false + false, + SEARCH_SHARD ); CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( @@ -154,7 +163,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, 0.001f) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 10d480475..0e32b5e78 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -19,6 +19,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Matches; import org.apache.lucene.search.MatchesIterator; @@ -146,7 +147,7 @@ public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { } @SneakyThrows - public void testExplain_whenCallExplain_thenFail() { + public void testExplain_whenCallExplain_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); 
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); @@ -171,7 +172,8 @@ public void testExplain_whenCallExplain_thenFail() { assertNotNull(weight); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - expectThrows(UnsupportedOperationException.class, () -> weight.explain(leafReaderContext, docId)); + Explanation explanation = weight.explain(leafReaderContext, docId); + assertNotNull(explanation); w.close(); reader.close(); From 50f3d2790e6cf11e49642c818d0bf82d54e044d9 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 14 Oct 2024 10:25:54 -0700 Subject: [PATCH 02/11] Added Explainable interface for techniques Signed-off-by: Martin Gaievski --- .../processor/CompoundTopDocs.java | 10 ++-- .../processor/DocIdAtQueryPhase.java | 7 ++- .../processor/ExplainableTechnique.java | 32 +++++++++++ .../NormalizationProcessorWorkflow.java | 54 ++++++++++++------- ...ithmeticMeanScoreCombinationTechnique.java | 16 +++--- ...eometricMeanScoreCombinationTechnique.java | 10 +++- ...HarmonicMeanScoreCombinationTechnique.java | 10 +++- .../ScoreCombinationTechnique.java | 4 -- .../processor/combination/ScoreCombiner.java | 2 +- .../L2ScoreNormalizationTechnique.java | 34 ++++-------- .../MinMaxScoreNormalizationTechnique.java | 33 ++++-------- .../ScoreNormalizationTechnique.java | 8 --- .../normalization/ScoreNormalizer.java | 3 +- .../processor/util/ExplainUtils.java | 46 ++++++++++++++++ 14 files changed, 175 insertions(+), 94 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ExplainableTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/util/ExplainUtils.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java index 4ea1c9b03..fe0ae6f44 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java @@ -85,17 +85,17 @@ private void initialize(TotalHits totalHits, List topDocs, boolean isSo public CompoundTopDocs(final QuerySearchResult querySearchResult) { final TopDocs topDocs = querySearchResult.topDocs().topDocs; final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget(); + SearchShard searchShard = new SearchShard( + searchShardTarget.getIndex(), + searchShardTarget.getShardId().id(), + searchShardTarget.getNodeId() + ); boolean isSortEnabled = false; if (topDocs instanceof TopFieldDocs) { isSortEnabled = true; } ScoreDoc[] scoreDocs = topDocs.scoreDocs; if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) { - SearchShard searchShard = new SearchShard( - searchShardTarget.getIndex(), - searchShardTarget.getShardId().id(), - searchShardTarget.getNodeId() - ); initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled, searchShard); return; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java b/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java index c81dd5f96..95d4fd7d9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java @@ -4,6 +4,11 @@ */ package org.opensearch.neuralsearch.processor; -//public record DocIdAtQueryPhase(Integer docId, SearchShardTarget searchShardTarget) { +/** + * Data class to store docId and search shard for a query. 
+ * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. + * @param docId + * @param searchShard + */ public record DocIdAtQueryPhase(int docId, SearchShard searchShard) { } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplainableTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplainableTechnique.java new file mode 100644 index 000000000..b2da43fc9 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplainableTechnique.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import java.util.List; +import java.util.Map; + +/** + * Abstracts explanation of score combination or normalization technique. + */ +public interface ExplainableTechnique { + + String GENERIC_DESCRIPTION_OF_TECHNIQUE = "generic score processing technique"; + + /** + * Returns a string with general description of the technique + */ + default String describe() { + return GENERIC_DESCRIPTION_OF_TECHNIQUE; + } + + /** + * Returns a map with explanation for each document id + * @param queryTopDocs collection of CompoundTopDocs for each shard result + * @return map of document per shard and corresponding explanation object + */ + default Map explain(final List queryTopDocs) { + return Map.of(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 4e90beb43..96f6bf2aa 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -103,29 +103,43 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request) } private void explain(NormalizationProcessorWorkflowExecuteRequest request, List queryTopDocs) { - if (request.isExplain()) { - // general description of techniques - String explanationDetailsMessage = String.format( - Locale.ROOT, - "%s, %s", - request.getNormalizationTechnique().describe(), - request.getCombinationTechnique().describe() - ); + if (!request.isExplain()) { + return; + } + Explanation describedTechniqueForExplain = describeTechniqueForExplain(request); - Explanation explanation = Explanation.match(0.0f, explanationDetailsMessage); + // build final result object with all explain related information + if (Objects.nonNull(request.getPipelineProcessingContext())) { + Map explainedNormalization = scoreNormalizer.explain( + queryTopDocs, + (ExplainableTechnique) request.getNormalizationTechnique() + ); - // build final result object with all explain related information - if (Objects.nonNull(request.getPipelineProcessingContext())) { - ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder() - .explanation(explanation) - .normalizedScoresByDocId(scoreNormalizer.explain(queryTopDocs, request.getNormalizationTechnique())) - .combinedScoresByDocId(scoreCombiner.explain(queryTopDocs, request.getCombinationTechnique())) - .build(); - // store explain object to pipeline context - PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); - pipelineProcessingContext.setAttribute(PROCESSOR_EXPLAIN, processorExplainDto); - } + ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder() + .explanation(describedTechniqueForExplain) + 
.normalizedScoresByDocId(explainedNormalization) + .combinedScoresByDocId(scoreCombiner.explain(queryTopDocs, request.getCombinationTechnique())) + .build(); + // store explain object to pipeline context + PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); + pipelineProcessingContext.setAttribute(PROCESSOR_EXPLAIN, processorExplainDto); } + + } + + private static Explanation describeTechniqueForExplain(NormalizationProcessorWorkflowExecuteRequest request) { + // general description of techniques + ExplainableTechnique explainableCombinationTechnique = (ExplainableTechnique) request.getCombinationTechnique(); + ExplainableTechnique explainableNormalizationTechnique = (ExplainableTechnique) request.getNormalizationTechnique(); + String explanationDetailsMessage = String.format( + Locale.ROOT, + "%s, %s", + explainableNormalizationTechnique.describe(), + explainableCombinationTechnique.describe() + ); + + Explanation explanation = Explanation.match(0.0f, explanationDetailsMessage); + return explanation; } /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index de03688c2..7ff6bea18 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -5,17 +5,19 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Set; import lombok.ToString; +import org.opensearch.neuralsearch.processor.ExplainableTechnique; + +import static org.opensearch.neuralsearch.processor.util.ExplainUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on arithmetic mean method */ @ToString(onlyExplicitlyIncluded = true) -public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique { +public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "arithmetic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; @@ -30,11 +32,6 @@ public ArithmeticMeanScoreCombinationTechnique(final Map params, weights = scoreCombinationUtil.getWeights(params); } - @Override - public String describe() { - return String.format(Locale.ROOT, "combination technique %s [%s]", TECHNIQUE_NAME, "score = (score1 + score2 + ... + scoreN)/N"); - } - /** * Arithmetic mean method for combining scores. * score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... 
+ weightN) @@ -60,4 +57,9 @@ public float combine(final float[] scores) { } return combinedScore / sumOfWeights; } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, weights); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index c4b6dfb3f..a0cd49136 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -9,12 +9,15 @@ import java.util.Set; import lombok.ToString; +import org.opensearch.neuralsearch.processor.ExplainableTechnique; + +import static org.opensearch.neuralsearch.processor.util.ExplainUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on geometrical mean method */ @ToString(onlyExplicitlyIncluded = true) -public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique { +public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "geometric_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; @@ -54,4 +57,9 @@ public float combine(final float[] scores) { } return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights); } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, weights); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index f5195f79f..7aabf0d61 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -9,12 +9,15 @@ import java.util.Set; import lombok.ToString; +import org.opensearch.neuralsearch.processor.ExplainableTechnique; + +import static org.opensearch.neuralsearch.processor.util.ExplainUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on harmonic mean method */ @ToString(onlyExplicitlyIncluded = true) -public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique { +public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "harmonic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; @@ -51,4 +54,9 @@ public float combine(final float[] scores) { } return sumOfHarmonics > 0 ? 
sumOfWeights / sumOfHarmonics : ZERO_SCORE; } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, weights); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java index c04b24b51..dbeabe94b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java @@ -12,8 +12,4 @@ public interface ScoreCombinationTechnique { * @return combined score */ float combine(final float[] scores); - - default String describe() { - return "generic score combination technique"; - } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 30a32b63a..9ab74ea32 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -355,7 +355,7 @@ private void explainByShard( float combinedScore = combinedNormalizedScoresByDocId.get(entry.getKey()); explain.put( new DocIdAtQueryPhase(entry.getKey(), compoundQueryTopDocs.getSearchShard()), - "source scores " + Arrays.toString(srcScores) + " combined score " + combinedScore + String.format("source scores [%s], combined score [%s]", Arrays.toString(srcScores), combinedScore) ); }); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 617dada10..1ed3bf88a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -10,7 +10,6 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -18,12 +17,15 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; +import org.opensearch.neuralsearch.processor.ExplainableTechnique; + +import static org.opensearch.neuralsearch.processor.util.ExplainUtils.getDocIdAtQueryPhaseStringMap; /** * Abstracts normalization of scores based on L2 method */ @ToString(onlyExplicitlyIncluded = true) -public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique { +public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "l2"; private static final float MIN_SCORE = 0.0f; @@ -55,14 +57,8 @@ public void normalize(final List queryTopDocs) { } } - @Override public String describe() { - return String.format( - Locale.ROOT, - "normalization technique %s [%s]", - TECHNIQUE_NAME, - "score = score/sqrt(score1^2 + score2^2 + ... 
+ scoreN^2)"
-        );
+        return String.format(Locale.ROOT, "normalization technique [%s]", TECHNIQUE_NAME);
     }

     @Override
@@ -79,23 +75,15 @@ public Map explain(List queryTopDocs
             for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                 TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
                 for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
-                    normalizedScores.computeIfAbsent(
-                        new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()),
-                        k -> new ArrayList<>()
-                    ).add(normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j)));
-                    sourceScores.computeIfAbsent(
-                        new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()),
-                        k -> new ArrayList<>()
-                    ).add(scoreDoc.score);
+                    DocIdAtQueryPhase docIdAtQueryPhase = new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
+                    float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
+                    normalizedScores.computeIfAbsent(docIdAtQueryPhase, k -> new ArrayList<>()).add(normalizedScore);
+                    sourceScores.computeIfAbsent(docIdAtQueryPhase, k -> new ArrayList<>()).add(scoreDoc.score);
+                    scoreDoc.score = normalizedScore;
                 }
             }
         }
-        Map<DocIdAtQueryPhase, String> explain = sourceScores.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> {
-            List<Float> srcScores = entry.getValue();
-            List<Float> normScores = normalizedScores.get(entry.getKey());
-            return "source scores " + srcScores + " normalized scores " + normScores;
-        }));
-        return explain;
+        return getDocIdAtQueryPhaseStringMap(normalizedScores, sourceScores);
     }

     private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
index ba7b8b20a..d22cd4d8e 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
@@ -11,7 +11,6 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
-import java.util.stream.Collectors;

 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
@@ -21,12 +20,15 @@
 import lombok.ToString;

 import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase;
+import org.opensearch.neuralsearch.processor.ExplainableTechnique;
+
+import static org.opensearch.neuralsearch.processor.util.ExplainUtils.getDocIdAtQueryPhaseStringMap;

 /**
  * Abstracts normalization of scores based on min-max method
  */
 @ToString(onlyExplicitlyIncluded = true)
-public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique {
+public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
     @ToString.Include
     public static final String TECHNIQUE_NAME = "min_max";
     private static final float MIN_SCORE = 0.001f;
@@ -71,12 +73,7 @@ public void normalize(final List queryTopDocs) {

     @Override
     public String describe() {
-        return String.format(
-            Locale.ROOT,
-            "normalization technique %s [%s]",
-            TECHNIQUE_NAME,
-            "score = (score - min_score)/(max_score - min_score)"
-        );
+        return String.format(Locale.ROOT, "normalization technique [%s]", TECHNIQUE_NAME);
     }

     @Override
@@ -106,24 +103,16 @@ public Map explain(List queryTopDocs
             for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                 TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
                 for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
-
normalizedScores.computeIfAbsent( - new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()), - k -> new ArrayList<>() - ).add(normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j])); - sourceScores.computeIfAbsent( - new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()), - k -> new ArrayList<>() - ).add(scoreDoc.score); + DocIdAtQueryPhase docIdAtQueryPhase = new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()); + float normalizedScore = normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j]); + normalizedScores.computeIfAbsent(docIdAtQueryPhase, k -> new ArrayList<>()).add(normalizedScore); + sourceScores.computeIfAbsent(docIdAtQueryPhase, k -> new ArrayList<>()).add(scoreDoc.score); + scoreDoc.score = normalizedScore; } } } - Map explain = sourceScores.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> { - List srcScores = entry.getValue(); - List normScores = normalizedScores.get(entry.getKey()); - return "source scores " + srcScores + " normalized scores " + normScores; - })); - return explain; + return getDocIdAtQueryPhaseStringMap(normalizedScores, sourceScores); } private float[] getMaxScores(final List queryTopDocs, final int numOfSubqueries) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java index 642bf284b..0b784c678 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java @@ -5,10 +5,8 @@ package org.opensearch.neuralsearch.processor.normalization; import java.util.List; -import java.util.Map; import org.opensearch.neuralsearch.processor.CompoundTopDocs; -import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; /** * Abstracts normalization of scores in query search results. 
@@ -20,10 +18,4 @@
      * @param queryTopDocs original query results from multiple shards and multiple sub-queries
      */
     void normalize(final List<CompoundTopDocs> queryTopDocs);
-
-    default String describe() {
-        return "score normalization technique";
-    }
-
-    Map<DocIdAtQueryPhase, String> explain(final List<CompoundTopDocs> queryTopDocs);
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java
index fbfba2b66..aaa65b4e1 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java
@@ -10,6 +10,7 @@

 import org.opensearch.neuralsearch.processor.CompoundTopDocs;
 import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase;
+import org.opensearch.neuralsearch.processor.ExplainableTechnique;

 public class ScoreNormalizer {

@@ -30,7 +31,7 @@ private boolean canQueryResultsBeNormalized(final List queryTop

     public Map<DocIdAtQueryPhase, String> explain(
         final List<CompoundTopDocs> queryTopDocs,
-        final ScoreNormalizationTechnique scoreNormalizationTechnique
+        final ExplainableTechnique scoreNormalizationTechnique
     ) {
         if (canQueryResultsBeNormalized(queryTopDocs)) {
             return scoreNormalizationTechnique.explain(queryTopDocs);
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/util/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/util/ExplainUtils.java
new file mode 100644
index 000000000..914457cc2
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/util/ExplainUtils.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor.util;
+
+import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Utility class for explain functionality
+ */
+public class ExplainUtils {
+
+    /**
+     * Creates map of DocIdAtQueryPhase to String containing source and normalized scores
+     * @param normalizedScores map of DocIdAtQueryPhase to normalized scores
+     * @param sourceScores map of DocIdAtQueryPhase to source scores
+     * @return map of DocIdAtQueryPhase to String containing source and normalized scores
+     */
+    public static Map<DocIdAtQueryPhase, String> getDocIdAtQueryPhaseStringMap(
+        final Map<DocIdAtQueryPhase, List<Float>> normalizedScores,
+        final Map<DocIdAtQueryPhase, List<Float>> sourceScores
+    ) {
+        Map<DocIdAtQueryPhase, String> explain = sourceScores.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> {
+            List<Float> srcScores = entry.getValue();
+            List<Float> normScores = normalizedScores.get(entry.getKey());
+            return String.format("source scores %s normalized scores %s", srcScores, normScores);
+        }));
+        return explain;
+    }
+
+    /**
+     * Creates a string describing the combination technique and its parameters
+     * @param techniqueName the name of the combination technique
+     * @param weights the weights used in the combination technique
+     * @return a string describing the combination technique and its parameters
+     */
+    public static String describeCombinationTechnique(final String techniqueName, final List<Float> weights) {
+        return String.format(Locale.ROOT, "combination technique [%s] with optional parameters [%s]", techniqueName, weights);
+    }
+}

From 64f69e870f00baaa383711ea2cf2cc663897a788 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Mon, 28 Oct 2024 13:26:19 -0700
Subject: [PATCH 03/11] Initial version with single sorted list of explains
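This change collapses the per-document explain maps into a single list of explain
entries per shard, sorted in the same order as that shard's hits in the final
search results, so the response processor can attach explanations by walking each
shard's list with a simple cursor instead of doing per-document lookups. A
condensed, illustrative sketch of the consuming side (assuming a SearchResponse
named response and the per-shard map taken from ProcessorExplainDto; cursorByShard
and the other local names are illustrative only and not part of this change):

    Map<SearchShard, Integer> cursorByShard = new HashMap<>();
    for (SearchHit hit : response.getHits().getHits()) {
        SearchShard shard = SearchShard.create(hit.getShard());
        // hits arrive in final sorted order, so the next unused entry
        // in this shard's sorted explain list belongs to this hit
        int next = cursorByShard.getOrDefault(shard, 0);
        CombinedExplainDetails details = explainDetailsByShard.get(shard).get(next);
        cursorByShard.put(shard, next + 1);
        // details now carries the normalization and combination explanation for the hit
    }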
Signed-off-by: Martin Gaievski --- .../NormalizationProcessorWorkflow.java | 55 +++++--- .../processor/ProcessorExplainPublisher.java | 126 ++++++++++------- .../neuralsearch/processor/SearchShard.java | 6 + ...ithmeticMeanScoreCombinationTechnique.java | 4 +- ...eometricMeanScoreCombinationTechnique.java | 4 +- ...HarmonicMeanScoreCombinationTechnique.java | 4 +- .../processor/combination/ScoreCombiner.java | 87 ++++++++---- .../explain/CombinedExplainDetails.java | 17 +++ .../DocIdAtSearchShard.java} | 6 +- .../processor/explain/ExplainDetails.java | 18 +++ .../processor/explain/ExplainUtils.java | 130 ++++++++++++++++++ .../{ => explain}/ExplainableTechnique.java | 6 +- .../{ => explain}/ProcessorExplainDto.java | 7 +- .../L2ScoreNormalizationTechnique.java | 22 +-- .../MinMaxScoreNormalizationTechnique.java | 22 +-- .../normalization/ScoreNormalizer.java | 7 +- .../processor/util/ExplainUtils.java | 46 ------- 17 files changed, 389 insertions(+), 178 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java rename src/main/java/org/opensearch/neuralsearch/processor/{DocIdAtQueryPhase.java => explain/DocIdAtSearchShard.java} (62%) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java rename src/main/java/org/opensearch/neuralsearch/processor/{ => explain}/ExplainableTechnique.java (77%) rename src/main/java/org/opensearch/neuralsearch/processor/{ => explain}/ProcessorExplainDto.java (61%) delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/util/ExplainUtils.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 96f6bf2aa..eac36c7fa 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.processor; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -23,6 +24,11 @@ import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; @@ -36,6 +42,7 @@ import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN; import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND; +import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.topLevelExpalantionForCombinedScore; import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria; /** @@ -106,19 
+113,42 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< if (!request.isExplain()) { return; } - Explanation describedTechniqueForExplain = describeTechniqueForExplain(request); + Explanation topLevelExplanationForTechniques = topLevelExpalantionForCombinedScore( + (ExplainableTechnique) request.getNormalizationTechnique(), + (ExplainableTechnique) request.getCombinationTechnique() + ); // build final result object with all explain related information if (Objects.nonNull(request.getPipelineProcessingContext())) { - Map explainedNormalization = scoreNormalizer.explain( + + Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs); + + Map normalizationExplain = scoreNormalizer.explain( queryTopDocs, (ExplainableTechnique) request.getNormalizationTechnique() ); + Map> combinationExplain = scoreCombiner.explain( + queryTopDocs, + request.getCombinationTechnique(), + sortForQuery + ); + Map> combinedExplain = new HashMap<>(); + + combinationExplain.forEach((searchShard, explainDetails) -> { + for (ExplainDetails explainDetail : explainDetails) { + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), searchShard); + ExplainDetails normalizedExplainDetails = normalizationExplain.get(docIdAtSearchShard); + CombinedExplainDetails combinedExplainDetails = CombinedExplainDetails.builder() + .normalizationExplain(normalizedExplainDetails) + .combinationExplain(explainDetail) + .build(); + combinedExplain.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(combinedExplainDetails); + } + }); ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder() - .explanation(describedTechniqueForExplain) - .normalizedScoresByDocId(explainedNormalization) - .combinedScoresByDocId(scoreCombiner.explain(queryTopDocs, request.getCombinationTechnique())) + .explanation(topLevelExplanationForTechniques) + .explainDetailsByShard(combinedExplain) .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); @@ -127,21 +157,6 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< } - private static Explanation describeTechniqueForExplain(NormalizationProcessorWorkflowExecuteRequest request) { - // general description of techniques - ExplainableTechnique explainableCombinationTechnique = (ExplainableTechnique) request.getCombinationTechnique(); - ExplainableTechnique explainableNormalizationTechnique = (ExplainableTechnique) request.getNormalizationTechnique(); - String explanationDetailsMessage = String.format( - Locale.ROOT, - "%s, %s", - explainableNormalizationTechnique.describe(), - explainableCombinationTechnique.describe() - ); - - Explanation explanation = Explanation.match(0.0f, explanationDetailsMessage); - return explanation; - } - /** * Getting list of CompoundTopDocs from list of QuerySearchResult. 
Each CompoundTopDocs is for individual shard * @param querySearchResults collection of QuerySearchResult for all shards diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java index b6381f7ce..ef4874485 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java @@ -9,12 +9,18 @@ import org.apache.lucene.search.Explanation; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.SearchResponseProcessor; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Objects; import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN; @@ -36,59 +42,87 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp @Override public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) { - if (Objects.nonNull(requestContext.getAttribute(PROCESSOR_EXPLAIN))) { - ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(PROCESSOR_EXPLAIN); - Explanation explanation = processorExplainDto.getExplanation(); - SearchHits searchHits = response.getHits(); - SearchHit[] searchHitsArray = searchHits.getHits(); - for (SearchHit searchHit : searchHitsArray) { - SearchShardTarget searchShardTarget = searchHit.getShard(); - DocIdAtQueryPhase docIdAtQueryPhase = new DocIdAtQueryPhase( - searchHit.docId(), - new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()) - ); - Explanation normalizedExplanation = Explanation.match( - 0.0f, - processorExplainDto.getNormalizedScoresByDocId().get(docIdAtQueryPhase) - ); - Explanation combinedExplanation = Explanation.match( - 0.0f, - processorExplainDto.getCombinedScoresByDocId().get(docIdAtQueryPhase) - ); - Explanation finalExplanation = Explanation.match( - searchHit.getScore(), - "combined explanation from processor and query: ", - explanation, - normalizedExplanation, - combinedExplanation, - searchHit.getExplanation() - ); - searchHit.explanation(finalExplanation); - } - // delete processor explain data to avoid double processing - // requestContext.setAttribute(PROCESSOR_EXPLAIN, null); + if (Objects.isNull(requestContext.getAttribute(PROCESSOR_EXPLAIN))) { + return response; } + ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(PROCESSOR_EXPLAIN); + Explanation processorExplanation = processorExplainDto.getExplanation(); + if (Objects.isNull(processorExplanation)) { + return response; + } + SearchHits searchHits = response.getHits(); + SearchHit[] searchHitsArray = searchHits.getHits(); + // create a map of searchShard and list of indexes of search hit objects in search hits array + // the list will keep original order of sorting as per final search results + Map> searchHitsByShard = new HashMap<>(); 
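+        // for each shard, track the index of the most recently consumed entry in that shard's
+        // sorted explain list; seeded with -1 below so the first hit from a shard reads entry 0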
+ Map explainsByShardCount = new HashMap<>(); + for (int i = 0; i < searchHitsArray.length; i++) { + SearchHit searchHit = searchHitsArray[i]; + SearchShardTarget searchShardTarget = searchHit.getShard(); + SearchShard searchShard = SearchShard.create(searchShardTarget); + searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i); + explainsByShardCount.putIfAbsent(searchShard, -1); + } + Map> combinedExplainDetails = processorExplainDto.getExplainDetailsByShard(); + for (int i = 0; i < searchHitsArray.length; i++) { + SearchHit searchHit = searchHitsArray[i]; + SearchShard searchShard = SearchShard.create(searchHit.getShard()); + int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1; + CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard); + // searchHit.explanation(getExplanation(searchHit, processorExplainDto, processorExplanation)); + Explanation normalizedExplanation = Explanation.match( + combinedExplainDetail.getNormalizationExplain().value(), + combinedExplainDetail.getNormalizationExplain().description() + ); + Explanation combinedExplanation = Explanation.match( + combinedExplainDetail.getCombinationExplain().value(), + combinedExplainDetail.getCombinationExplain().description() + ); + + Explanation finalExplanation = Explanation.match( + searchHit.getScore(), + processorExplanation.getDescription(), + normalizedExplanation, + combinedExplanation, + searchHit.getExplanation() + ); + searchHit.explanation(finalExplanation); + explainsByShardCount.put(searchShard, explanationIndexByShard); + } return response; } + /*private static Explanation getExplanation( + SearchHit searchHit, + ProcessorExplainDto processorExplainDto, + Explanation generalProcessorLevelExplanation + ) { + SearchShardTarget searchShardTarget = searchHit.getShard(); + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard( + searchHit.docId(), + new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()) + ); + SearchShard searchShard = new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()); + ExplainDetails normalizationExplainDetails = processorExplainDto.getNormalizedScoresByDocId().get(docIdAtSearchShard); + Explanation normalizedExplanation = Explanation.match( + normalizationExplainDetails.value(), + normalizationExplainDetails.description() + ); + List combinedExplainDetails = processorExplainDto.getCombinedScoresByShard().get(searchShard); + Explanation combinedExplanation = Explanation.match(combinedExplainDetails.value(), combinedExplainDetails.description()); + Explanation finalExplanation = Explanation.match( + searchHit.getScore(), + generalProcessorLevelExplanation.getDescription(), + normalizedExplanation, + combinedExplanation, + searchHit.getExplanation() + ); + return finalExplanation; + }*/ + @Override public String getType() { return TYPE; } - - @Override - public String getTag() { - return tag; - } - - @Override - public String getDescription() { - return description; - } - - @Override - public boolean isIgnoreFailure() { - return ignoreFailure; - } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java index 57c893c90..61f9d0be5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java @@ -4,5 +4,11 @@ */ package org.opensearch.neuralsearch.processor; +import org.opensearch.search.SearchShardTarget; + public record SearchShard(String index, int shardId, String nodeId) { + + public static SearchShard create(SearchShardTarget searchShardTarget) { + return new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 7ff6bea18..539eaa081 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -9,9 +9,9 @@ import java.util.Set; import lombok.ToString; -import org.opensearch.neuralsearch.processor.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import static org.opensearch.neuralsearch.processor.util.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on arithmetic mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index a0cd49136..38f98c018 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -9,9 +9,9 @@ import java.util.Set; import lombok.ToString; -import org.opensearch.neuralsearch.processor.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import static org.opensearch.neuralsearch.processor.util.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on geometrical mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index 7aabf0d61..ddfbb5223 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -9,9 +9,9 @@ import java.util.Set; import lombok.ToString; -import org.opensearch.neuralsearch.processor.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import static org.opensearch.neuralsearch.processor.util.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on harmonic mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java 
b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 9ab74ea32..644600b21 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -27,7 +27,8 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.extern.log4j.Log4j2; -import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; +import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainDetails; /** * Abstracts combination of scores in query search results. @@ -98,14 +99,9 @@ private void combineShardScores( // - sort documents by scores and take first "max number" of docs // create a collection of doc ids that are sorted by their combined scores - Collection sortedDocsIds; - if (sort != null) { - sortedDocsIds = getSortedDocIdsBySortCriteria(getTopFieldDocs(sort, topDocsPerSubQuery), sort); - } else { - sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId); - } + Collection sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId); - // - update query search results with normalized scores + // - update query search results with combined scores updateQueryTopDocsWithCombinedScores( compoundQueryTopDocs, topDocsPerSubQuery, @@ -323,40 +319,73 @@ private TotalHits getTotalHits(final List topDocsPerSubQuery, final lon return new TotalHits(maxHits, totalHits); } - public Map explain( + public Map> explain( final List queryTopDocs, - ScoreCombinationTechnique combinationTechnique + ScoreCombinationTechnique combinationTechnique, + Sort sort ) { - Map explain = new HashMap<>(); - queryTopDocs.forEach(compoundQueryTopDocs -> explainByShard(combinationTechnique, compoundQueryTopDocs, explain)); - return explain; + // In case of duplicate keys, keep the first value + HashMap> map = new HashMap<>(); + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + for (Map.Entry> docIdAtSearchShardExplainDetailsEntry : explainByShard( + combinationTechnique, + compoundQueryTopDocs, + sort + ).entrySet()) { + map.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue()); + } + } + return map; } - private void explainByShard( + private Map> explainByShard( final ScoreCombinationTechnique scoreCombinationTechnique, final CompoundTopDocs compoundQueryTopDocs, - Map explain + Sort sort ) { if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) { - return; + return Map.of(); } - List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); - // - create map of normalized scores results returned from the single shard - Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery); - - // - create map of combined scores per doc id + Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs()); + Map> explainsForShard = new HashMap<>(); Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); + Collection sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId); + SearchShard searchShard = compoundQueryTopDocs.getSearchShard(); + explainsForShard.put(searchShard, new ArrayList<>()); + for (Integer docId : sortedDocsIds) { + float combinedScore = 
combinedNormalizedScoresByDocId.get(docId); + explainsForShard.get(searchShard) + .add( + new ExplainDetails( + combinedScore, + String.format( + "source scores: %s, combined score %s", + Arrays.toString(normalizedScoresPerDoc.get(docId)), + combinedScore + ), + docId + ) + ); + } - normalizedScoresPerDoc.entrySet().stream().forEach(entry -> { - float[] srcScores = entry.getValue(); - float combinedScore = combinedNormalizedScoresByDocId.get(entry.getKey()); - explain.put( - new DocIdAtQueryPhase(entry.getKey(), compoundQueryTopDocs.getSearchShard()), - String.format("source scores [%s], combined score [%s]", Arrays.toString(srcScores), combinedScore) - ); - }); + return explainsForShard; + } + + private Collection getSortedDocsIds( + CompoundTopDocs compoundQueryTopDocs, + Sort sort, + Map combinedNormalizedScoresByDocId + ) { + Collection sortedDocsIds; + if (sort != null) { + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + sortedDocsIds = getSortedDocIdsBySortCriteria(getTopFieldDocs(sort, topDocsPerSubQuery), sort); + } else { + sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId); + } + return sortedDocsIds; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java new file mode 100644 index 000000000..8793f5d53 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; + +@AllArgsConstructor +@Builder +@Getter +public class CombinedExplainDetails { + private ExplainDetails normalizationExplain; + private ExplainDetails combinationExplain; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java similarity index 62% rename from src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java index 95d4fd7d9..70da8b73c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/DocIdAtQueryPhase.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java @@ -2,7 +2,9 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor; +package org.opensearch.neuralsearch.processor.explain; + +import org.opensearch.neuralsearch.processor.SearchShard; /** * Data class to store docId and search shard for a query. 
@@ -10,5 +12,5 @@ * @param docId * @param searchShard */ -public record DocIdAtQueryPhase(int docId, SearchShard searchShard) { +public record DocIdAtSearchShard(int docId, SearchShard searchShard) { } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java new file mode 100644 index 000000000..0bbdea27f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +/** + * Data class to store value and description for explain details. + * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. + * @param value + * @param description + */ +public record ExplainDetails(float value, String description, int docId) { + + public ExplainDetails(float value, String description) { + this(value, description, -1); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java new file mode 100644 index 000000000..d31a3e9bb --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import org.apache.lucene.search.Explanation; +import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Utility class for explain functionality + */ +public class ExplainUtils { + + /** + * Creates map of DocIdAtQueryPhase to String containing source and normalized scores + * @param normalizedScores map of DocIdAtQueryPhase to normalized scores + * @param sourceScores map of DocIdAtQueryPhase to source scores + * @return map of DocIdAtQueryPhase to String containing source and normalized scores + */ + public static Map getDocIdAtQueryForNormalization( + final Map> normalizedScores, + final Map> sourceScores + ) { + Map explain = sourceScores.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> { + List srcScores = entry.getValue(); + List normScores = normalizedScores.get(entry.getKey()); + return new ExplainDetails( + normScores.stream().reduce(0.0f, Float::max), + String.format("source scores: %s, normalized scores: %s", srcScores, normScores) + ); + })); + return explain; + } + + /** + * Creates map of DocIdAtQueryPhase to String containing source scores and combined score + * @param scoreCombinationTechnique the combination technique used + * @param normalizedScoresPerDoc map of DocIdAtQueryPhase to normalized scores + * @param searchShard the search shard + * @return map of DocIdAtQueryPhase to String containing source scores and combined score + */ + public static Map getDocIdAtQueryForCombination( + ScoreCombinationTechnique scoreCombinationTechnique, + Map normalizedScoresPerDoc, + SearchShard searchShard + ) { + Map explain = new HashMap<>(); + // - create map of combined scores per doc id + Map combinedNormalizedScoresByDocId = 
normalizedScoresPerDoc.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); + normalizedScoresPerDoc.forEach((key, srcScores) -> { + float combinedScore = combinedNormalizedScoresByDocId.get(key); + explain.put( + new DocIdAtSearchShard(key, searchShard), + new ExplainDetails( + combinedScore, + String.format("source scores: %s, combined score %s", Arrays.toString(srcScores), combinedScore) + ) + ); + }); + return explain; + } + + /* public static Map> getExplainsByShardForCombination( + ScoreCombinationTechnique scoreCombinationTechnique, + Map normalizedScoresPerDoc, + SearchShard searchShard + ) { + Map explain = new HashMap<>(); + // - create map of combined scores per doc id + Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); + normalizedScoresPerDoc.forEach((key, srcScores) -> { + float combinedScore = combinedNormalizedScoresByDocId.get(key); + explain.put( + new DocIdAtSearchShard(key, searchShard), + new ExplainDetails( + combinedScore, + String.format("source scores: %s, combined score %s", Arrays.toString(srcScores), combinedScore) + ) + ); + }); + return explain; + } + */ + + /** + * Creates a string describing the combination technique and its parameters + * @param techniqueName the name of the combination technique + * @param weights the weights used in the combination technique + * @return a string describing the combination technique and its parameters + */ + public static String describeCombinationTechnique(final String techniqueName, final List weights) { + return String.format(Locale.ROOT, "combination technique [%s] with optional parameters [%s]", techniqueName, weights); + } + + /** + * Creates an Explanation object for the top-level explanation of the combined score + * @param explainableNormalizationTechnique the normalization technique used + * @param explainableCombinationTechnique the combination technique used + * @return an Explanation object for the top-level explanation of the combined score + */ + public static Explanation topLevelExpalantionForCombinedScore( + final ExplainableTechnique explainableNormalizationTechnique, + final ExplainableTechnique explainableCombinationTechnique + ) { + String explanationDetailsMessage = String.format( + Locale.ROOT, + "combine score with techniques: %s, %s", + explainableNormalizationTechnique.describe(), + explainableCombinationTechnique.describe() + ); + + return Explanation.match(0.0f, explanationDetailsMessage); + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplainableTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java similarity index 77% rename from src/main/java/org/opensearch/neuralsearch/processor/ExplainableTechnique.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java index b2da43fc9..6f8dfcf1e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplainableTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java @@ -2,7 +2,9 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor; +package org.opensearch.neuralsearch.processor.explain; + +import org.opensearch.neuralsearch.processor.CompoundTopDocs; import java.util.List; import 
java.util.Map; @@ -26,7 +28,7 @@ default String describe() { * @param queryTopDocs collection of CompoundTopDocs for each shard result * @return map of document per shard and corresponding explanation object */ - default Map explain(final List queryTopDocs) { + default Map explain(final List queryTopDocs) { return Map.of(); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainDto.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java similarity index 61% rename from src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainDto.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java index 1255779f6..539ba1a94 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java @@ -2,13 +2,15 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor; +package org.opensearch.neuralsearch.processor.explain; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; import org.apache.lucene.search.Explanation; +import org.opensearch.neuralsearch.processor.SearchShard; +import java.util.List; import java.util.Map; @AllArgsConstructor @@ -16,6 +18,5 @@ @Getter public class ProcessorExplainDto { Explanation explanation; - Map normalizedScoresByDocId; - Map combinedScoresByDocId; + Map> explainDetailsByShard; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 1ed3bf88a..e1eb6c56a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -16,10 +16,11 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.ToString; -import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; -import org.opensearch.neuralsearch.processor.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import static org.opensearch.neuralsearch.processor.util.ExplainUtils.getDocIdAtQueryPhaseStringMap; +import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization; /** * Abstracts normalization of scores based on L2 method @@ -57,14 +58,15 @@ public void normalize(final List queryTopDocs) { } } + @Override public String describe() { return String.format(Locale.ROOT, "normalization technique [%s]", TECHNIQUE_NAME); } @Override - public Map explain(List queryTopDocs) { - Map> normalizedScores = new HashMap<>(); - Map> sourceScores = new HashMap<>(); + public Map explain(List queryTopDocs) { + Map> normalizedScores = new HashMap<>(); + Map> sourceScores = new HashMap<>(); List normsPerSubquery = getL2Norm(queryTopDocs); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { @@ -75,15 +77,15 @@ public Map explain(List queryTopDocs for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { - DocIdAtQueryPhase 
docIdAtQueryPhase = new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
+                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                     float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
-                    normalizedScores.computeIfAbsent(docIdAtQueryPhase, k -> new ArrayList<>()).add(normalizedScore);
-                    sourceScores.computeIfAbsent(docIdAtQueryPhase, k -> new ArrayList<>()).add(scoreDoc.score);
+                    normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
+                    sourceScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(scoreDoc.score);
                     scoreDoc.score = normalizedScore;
                 }
             }
         }
-        return getDocIdAtQueryPhaseStringMap(normalizedScores, sourceScores);
+        return getDocIdAtQueryForNormalization(normalizedScores, sourceScores);
     }

     private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
index d22cd4d8e..8e0f0f586 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
@@ -19,10 +19,11 @@
 import com.google.common.primitives.Floats;
 import lombok.ToString;

-import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase;
-import org.opensearch.neuralsearch.processor.ExplainableTechnique;
+import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
+import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
+import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

-import static org.opensearch.neuralsearch.processor.util.ExplainUtils.getDocIdAtQueryPhaseStringMap;
+import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization;

 /**
  * Abstracts normalization of scores based on min-max method
@@ -77,9 +78,9 @@ public String describe() {
     }

     @Override
-    public Map<DocIdAtQueryPhase, String> explain(List<CompoundTopDocs> queryTopDocs) {
-        Map<DocIdAtQueryPhase, List<Float>> normalizedScores = new HashMap<>();
-        Map<DocIdAtQueryPhase, List<Float>> sourceScores = new HashMap<>();
+    public Map<DocIdAtSearchShard, ExplainDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
+        Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
+        Map<DocIdAtSearchShard, List<Float>> sourceScores = new HashMap<>();

         int numOfSubqueries = queryTopDocs.stream()
             .filter(Objects::nonNull)
@@ -94,7 +95,6 @@ public Map explain(List queryTopDocs
         // get max scores for each sub query
         float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries);

-        // do normalization using actual score and min and max scores for corresponding sub query
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
             if (Objects.isNull(compoundQueryTopDocs)) {
                 continue;
@@ -103,16 +103,16 @@ public Map explain(List queryTopDocs
             for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                 TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
                 for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
-                    DocIdAtQueryPhase docIdAtQueryPhase = new DocIdAtQueryPhase(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
+                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                     float normalizedScore = normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j]);
-                    normalizedScores.computeIfAbsent(docIdAtQueryPhase, k -> new ArrayList<>()).add(normalizedScore);
-
sourceScores.computeIfAbsent(docIdAtQueryPhase, k -> new ArrayList<>()).add(scoreDoc.score); + normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore); + sourceScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(scoreDoc.score); scoreDoc.score = normalizedScore; } } } - return getDocIdAtQueryPhaseStringMap(normalizedScores, sourceScores); + return getDocIdAtQueryForNormalization(normalizedScores, sourceScores); } private float[] getMaxScores(final List queryTopDocs, final int numOfSubqueries) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index aaa65b4e1..2dcf5f768 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -9,8 +9,9 @@ import java.util.Objects; import org.opensearch.neuralsearch.processor.CompoundTopDocs; -import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; -import org.opensearch.neuralsearch.processor.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; public class ScoreNormalizer { @@ -29,7 +30,7 @@ private boolean canQueryResultsBeNormalized(final List queryTop return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getTopDocs().size() > 0); } - public Map explain( + public Map explain( final List queryTopDocs, final ExplainableTechnique scoreNormalizationTechnique ) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/util/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/util/ExplainUtils.java deleted file mode 100644 index 914457cc2..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/util/ExplainUtils.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.processor.util; - -import org.opensearch.neuralsearch.processor.DocIdAtQueryPhase; - -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * Utility class for explain functionality - */ -public class ExplainUtils { - - /** - * Creates map of DocIdAtQueryPhase to String containing source and normalized scores - * @param normalizedScores map of DocIdAtQueryPhase to normalized scores - * @param sourceScores map of DocIdAtQueryPhase to source scores - * @return map of DocIdAtQueryPhase to String containing source and normalized scores - */ - public static Map getDocIdAtQueryPhaseStringMap( - final Map> normalizedScores, - final Map> sourceScores - ) { - Map explain = sourceScores.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> { - List srcScores = entry.getValue(); - List normScores = normalizedScores.get(entry.getKey()); - return String.format("source scores %s normalized scores %s", srcScores, normScores); - })); - return explain; - } - - /** - * Creates a string describing the combination technique and its parameters - * @param techniqueName the name of the combination technique - * @param weights the weights used in the combination technique - * @return a string describing the combination technique and its 
parameters - */ - public static String describeCombinationTechnique(final String techniqueName, final List weights) { - return String.format(Locale.ROOT, "combination technique [%s] with optional parameters [%s]", techniqueName, weights); - } -} From 7f384da5c79b41d9237c093f0322cb650dd88d33 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 28 Oct 2024 18:07:29 -0700 Subject: [PATCH 04/11] Adding basic tests, minor refactoring Signed-off-by: Martin Gaievski --- .../NormalizationProcessorWorkflow.java | 2 +- .../processor/ProcessorExplainPublisher.java | 119 +++---- .../processor/combination/ScoreCombiner.java | 2 + .../processor/explain/ExplainUtils.java | 4 +- .../explain/ProcessorExplainDto.java | 8 +- .../ProcessorExplainPublisherTests.java | 57 ++++ .../neuralsearch/query/HybridQueryIT.java | 295 ++++++++++++++++++ .../neuralsearch/BaseNeuralSearchIT.java | 27 +- 8 files changed, 436 insertions(+), 78 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index eac36c7fa..084a94bf7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -148,7 +148,7 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder() .explanation(topLevelExplanationForTechniques) - .explainDetailsByShard(combinedExplain) + .explainPayload(Map.of(ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain)) .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java index ef4874485..a138e5d6d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisher.java @@ -24,6 +24,7 @@ import java.util.Objects; import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN; +import static org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR; @Getter @AllArgsConstructor @@ -42,85 +43,63 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp @Override public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) { - if (Objects.isNull(requestContext.getAttribute(PROCESSOR_EXPLAIN))) { + if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(PROCESSOR_EXPLAIN)))) { return response; } ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(PROCESSOR_EXPLAIN); - Explanation processorExplanation = processorExplainDto.getExplanation(); - if (Objects.isNull(processorExplanation)) { - return response; - } - SearchHits searchHits = response.getHits(); - SearchHit[] searchHitsArray = searchHits.getHits(); - // create a map of searchShard and list of indexes of search hit objects in search hits array - // the list will 
keep original order of sorting as per final search results - Map> searchHitsByShard = new HashMap<>(); - Map explainsByShardCount = new HashMap<>(); - for (int i = 0; i < searchHitsArray.length; i++) { - SearchHit searchHit = searchHitsArray[i]; - SearchShardTarget searchShardTarget = searchHit.getShard(); - SearchShard searchShard = SearchShard.create(searchShardTarget); - searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i); - explainsByShardCount.putIfAbsent(searchShard, -1); - } - Map> combinedExplainDetails = processorExplainDto.getExplainDetailsByShard(); + Map explainPayload = processorExplainDto.getExplainPayload(); + if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { + Explanation processorExplanation = processorExplainDto.getExplanation(); + if (Objects.isNull(processorExplanation)) { + return response; + } + SearchHits searchHits = response.getHits(); + SearchHit[] searchHitsArray = searchHits.getHits(); + // create a map of searchShard and list of indexes of search hit objects in search hits array + // the list will keep original order of sorting as per final search results + Map> searchHitsByShard = new HashMap<>(); + Map explainsByShardCount = new HashMap<>(); + for (int i = 0; i < searchHitsArray.length; i++) { + SearchHit searchHit = searchHitsArray[i]; + SearchShardTarget searchShardTarget = searchHit.getShard(); + SearchShard searchShard = SearchShard.create(searchShardTarget); + searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i); + explainsByShardCount.putIfAbsent(searchShard, -1); + } + if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map) { + @SuppressWarnings("unchecked") + Map> combinedExplainDetails = (Map< + SearchShard, + List>) explainPayload.get(NORMALIZATION_PROCESSOR); - for (int i = 0; i < searchHitsArray.length; i++) { - SearchHit searchHit = searchHitsArray[i]; - SearchShard searchShard = SearchShard.create(searchHit.getShard()); - int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1; - CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard); - // searchHit.explanation(getExplanation(searchHit, processorExplainDto, processorExplanation)); - Explanation normalizedExplanation = Explanation.match( - combinedExplainDetail.getNormalizationExplain().value(), - combinedExplainDetail.getNormalizationExplain().description() - ); - Explanation combinedExplanation = Explanation.match( - combinedExplainDetail.getCombinationExplain().value(), - combinedExplainDetail.getCombinationExplain().description() - ); + for (SearchHit searchHit : searchHitsArray) { + SearchShard searchShard = SearchShard.create(searchHit.getShard()); + int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1; + CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard); + Explanation normalizedExplanation = Explanation.match( + combinedExplainDetail.getNormalizationExplain().value(), + combinedExplainDetail.getNormalizationExplain().description() + ); + Explanation combinedExplanation = Explanation.match( + combinedExplainDetail.getCombinationExplain().value(), + combinedExplainDetail.getCombinationExplain().description() + ); - Explanation finalExplanation = Explanation.match( - searchHit.getScore(), - processorExplanation.getDescription(), - normalizedExplanation, - combinedExplanation, - searchHit.getExplanation() - ); - searchHit.explanation(finalExplanation); - 
explainsByShardCount.put(searchShard, explanationIndexByShard); + Explanation finalExplanation = Explanation.match( + searchHit.getScore(), + processorExplanation.getDescription(), + normalizedExplanation, + combinedExplanation, + searchHit.getExplanation() + ); + searchHit.explanation(finalExplanation); + explainsByShardCount.put(searchShard, explanationIndexByShard); + } + } } return response; } - /*private static Explanation getExplanation( - SearchHit searchHit, - ProcessorExplainDto processorExplainDto, - Explanation generalProcessorLevelExplanation - ) { - SearchShardTarget searchShardTarget = searchHit.getShard(); - DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard( - searchHit.docId(), - new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()) - ); - SearchShard searchShard = new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()); - ExplainDetails normalizationExplainDetails = processorExplainDto.getNormalizedScoresByDocId().get(docIdAtSearchShard); - Explanation normalizedExplanation = Explanation.match( - normalizationExplainDetails.value(), - normalizationExplainDetails.description() - ); - List combinedExplainDetails = processorExplainDto.getCombinedScoresByShard().get(searchShard); - Explanation combinedExplanation = Explanation.match(combinedExplainDetails.value(), combinedExplainDetails.description()); - Explanation finalExplanation = Explanation.match( - searchHit.getScore(), - generalProcessorLevelExplanation.getDescription(), - normalizedExplanation, - combinedExplanation, - searchHit.getExplanation() - ); - return finalExplanation; - }*/ - @Override public String getType() { return TYPE; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 644600b21..7314c3383 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -9,6 +9,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.Objects; @@ -362,6 +363,7 @@ private Map> explainByShard( new ExplainDetails( combinedScore, String.format( + Locale.ROOT, "source scores: %s, combined score %s", Arrays.toString(normalizedScoresPerDoc.get(docId)), combinedScore diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java index d31a3e9bb..0de386ca7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java @@ -37,7 +37,7 @@ public static Map getDocIdAtQueryForNormaliz List normScores = normalizedScores.get(entry.getKey()); return new ExplainDetails( normScores.stream().reduce(0.0f, Float::max), - String.format("source scores: %s, normalized scores: %s", srcScores, normScores) + String.format(Locale.ROOT, "source scores: %s, normalized scores: %s", srcScores, normScores) ); })); return explain; @@ -66,7 +66,7 @@ public static Map getDocIdAtQueryForCombinat new DocIdAtSearchShard(key, searchShard), new ExplainDetails( combinedScore, - String.format("source scores: %s, combined score %s", 
Arrays.toString(srcScores), combinedScore) + String.format(Locale.ROOT, "source scores: %s, combined score %s", Arrays.toString(srcScores), combinedScore) ) ); }); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java index 539ba1a94..ed29cd6e3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java @@ -8,9 +8,7 @@ import lombok.Builder; import lombok.Getter; import org.apache.lucene.search.Explanation; -import org.opensearch.neuralsearch.processor.SearchShard; -import java.util.List; import java.util.Map; @AllArgsConstructor @@ -18,5 +16,9 @@ @Getter public class ProcessorExplainDto { Explanation explanation; - Map> explainDetailsByShard; + Map explainPayload; + + public enum ExplanationType { + NORMALIZATION_PROCESSOR + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java new file mode 100644 index 000000000..0f5abfec4 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.test.OpenSearchTestCase; + +import static org.mockito.Mockito.mock; + +public class ProcessorExplainPublisherTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { + ProcessorExplainPublisher processorExplainPublisher = new ProcessorExplainPublisher(DESCRIPTION, PROCESSOR_TAG, false); + + assertEquals(DESCRIPTION, processorExplainPublisher.getDescription()); + assertEquals(PROCESSOR_TAG, processorExplainPublisher.getTag()); + assertFalse(processorExplainPublisher.isIgnoreFailure()); + } + + @SneakyThrows + public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() { + ProcessorExplainPublisher processorExplainPublisher = new ProcessorExplainPublisher(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = new SearchResponse( + null, + null, + 1, + 1, + 0, + 1000, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY + ); + + SearchResponse processedResponse = processorExplainPublisher.processResponse(searchRequest, searchResponse); + assertEquals(searchResponse, processedResponse); + + SearchResponse processedResponse2 = processorExplainPublisher.processResponse(searchRequest, searchResponse, null); + assertEquals(searchResponse, processedResponse2); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + SearchResponse processedResponse3 = processorExplainPublisher.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + assertEquals(searchResponse, processedResponse3); + } +} diff --git 
a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 610e08dd0..e24a4da5b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -12,9 +12,13 @@ import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -24,6 +28,7 @@ import org.apache.commons.lang.RandomStringUtils; import org.apache.commons.lang.math.RandomUtils; +import org.apache.commons.lang3.Range; import org.apache.lucene.search.join.ScoreMode; import org.junit.Before; import org.opensearch.client.ResponseException; @@ -34,6 +39,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.primitives.Floats; @@ -82,6 +88,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final int INTEGER_FIELD_PRICE_6_VALUE = 350; protected static final int SINGLE_SHARD = 1; protected static final int MULTIPLE_SHARDS = 3; + public static final String NORMALIZATION_TECHNIQUE_L2 = "l2"; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -833,6 +840,294 @@ public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { } } + @SneakyThrows + public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + // create search pipeline with both normalization processor and explain response processor + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); + + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // search hits + assertEquals(3, getHitCount(searchResponseAsMap1)); + + List> hitsNestedList = 
getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(3, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain + Map searchHit1 = hitsNestedList.get(0); + Map explanationForHit1 = (Map) searchHit1.get("_explanation"); + assertNotNull(explanationForHit1); + assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedGeneralCombineScoreDescription = + "combine score with techniques: normalization technique [min_max], combination technique [arithmetic_mean] with optional parameters [[]]"; + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); + List> hit1Details = (List>) explanationForHit1.get("details"); + assertEquals(3, hit1Details.size()); + Map hit1DetailsForHit1 = hit1Details.get(0); + assertEquals(1.0, hit1DetailsForHit1.get("value")); + assertTrue(((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[1\\.0\\]")); + assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertEquals(0.5, hit1DetailsForHit2.get("value")); + assertEquals("source scores: [0.0, 1.0], combined score 0.5", hit1DetailsForHit2.get("description")); + assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); + + Map hit1DetailsForHit3 = hit1Details.get(2); + double actualHit1ScoreHit3 = ((double) hit1DetailsForHit3.get("value")); + assertTrue(actualHit1ScoreHit3 > 0.0); + assertEquals("combination of:", hit1DetailsForHit3.get("description")); + assertEquals(1, ((List) hit1DetailsForHit3.get("details")).size()); + + Map hit1SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(0); + assertEquals(actualHit1ScoreHit3, ((double) hit1SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("sum of:", hit1SubDetailsForHit3.get("description")); + assertEquals(1, ((List) hit1SubDetailsForHit3.get("details")).size()); + + // search hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map explanationForHit2 = (Map) searchHit2.get("_explanation"); + assertNotNull(explanationForHit2); + assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit2.get("description")); + List> hit2Details = (List>) explanationForHit2.get("details"); + assertEquals(3, hit2Details.size()); + Map hit2DetailsForHit1 = hit2Details.get(0); + assertEquals(1.0, hit2DetailsForHit1.get("value")); + assertTrue(((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[1\\.0\\]")); + assertEquals(0, ((List) hit2DetailsForHit1.get("details")).size()); + + Map hit2DetailsForHit2 = hit2Details.get(1); + assertEquals(0.5, hit2DetailsForHit2.get("value")); + assertEquals("source scores: [1.0, 0.0], combined score 0.5", hit2DetailsForHit2.get("description")); + assertEquals(0, ((List) 
hit2DetailsForHit2.get("details")).size()); + + Map hit2DetailsForHit3 = hit2Details.get(2); + double actualHit2ScoreHit3 = ((double) hit2DetailsForHit3.get("value")); + assertTrue(actualHit2ScoreHit3 > 0.0); + assertEquals("combination of:", hit2DetailsForHit3.get("description")); + assertEquals(1, ((List) hit2DetailsForHit3.get("details")).size()); + + Map hit2SubDetailsForHit3 = (Map) ((List) hit2DetailsForHit3.get("details")).get(0); + assertEquals(actualHit2ScoreHit3, ((double) hit2SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", hit2SubDetailsForHit3.get("description")); + assertEquals(1, ((List) hit2SubDetailsForHit3.get("details")).size()); + + // search hit 3 + Map searchHit3 = hitsNestedList.get(2); + Map explanationForHit3 = (Map) searchHit3.get("_explanation"); + assertNotNull(explanationForHit3); + assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit3.get("description")); + List> hit3Details = (List>) explanationForHit3.get("details"); + assertEquals(3, hit3Details.size()); + Map hit3DetailsForHit1 = hit3Details.get(0); + assertEquals(0.001, hit3DetailsForHit1.get("value")); + assertTrue( + ((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[0\\.001\\]") + ); + assertEquals(0, ((List) hit3DetailsForHit1.get("details")).size()); + + Map hit3DetailsForHit2 = hit3Details.get(1); + assertEquals(0.0005, hit3DetailsForHit2.get("value")); + assertEquals("source scores: [0.0, 0.001], combined score 5.0E-4", hit3DetailsForHit2.get("description")); + assertEquals(0, ((List) hit3DetailsForHit2.get("details")).size()); + + Map hit3DetailsForHit3 = hit3Details.get(2); + double actualHit3ScoreHit3 = ((double) hit3DetailsForHit3.get("value")); + assertTrue(actualHit3ScoreHit3 > 0.0); + assertEquals("combination of:", hit3DetailsForHit3.get("description")); + assertEquals(1, ((List) hit3DetailsForHit3.get("details")).size()); + + Map hit3SubDetailsForHit3 = (Map) ((List) hit3DetailsForHit3.get("details")).get(0); + assertEquals(actualHit3ScoreHit3, ((double) hit3SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("sum of:", hit3SubDetailsForHit3.get("description")); + assertEquals(1, ((List) hit3SubDetailsForHit3.get("details")).size()); + } finally { + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + NORMALIZATION_TECHNIQUE_L2, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), + true + ); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .vector(createRandomVector(TEST_DIMENSION)) + .k(10) + .build(); + hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); + hybridQueryBuilder.add(knnQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // basic sanity check for 
search hits + assertEquals(4, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + float actualMaxScore = getMaxScore(searchResponseAsMap).get(); + assertTrue(actualMaxScore > 0); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain + List> hitsNestedList = getNestedHits(searchResponseAsMap); + Map searchHit1 = hitsNestedList.get(0); + Map explanationForHit1 = (Map) searchHit1.get("_explanation"); + assertNotNull(explanationForHit1); + assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedGeneralCombineScoreDescription = + "combine score with techniques: normalization technique [l2], combination technique [arithmetic_mean] with optional parameters [" + + Arrays.toString(new float[] { 0.3f, 0.7f }) + + "]"; + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); + List> hit1Details = (List>) explanationForHit1.get("details"); + assertEquals(3, hit1Details.size()); + Map hit1DetailsForHit1 = hit1Details.get(0); + assertTrue((double) hit1DetailsForHit1.get("value") > 0.5f); + assertTrue( + ((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\], normalized scores: \\[.*, .*\\]") + ); + assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertEquals(actualMaxScore, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertTrue(((String) hit1DetailsForHit2.get("description")).matches("source scores: \\[.*, .*\\], combined score .*")); + assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); + + Map hit1DetailsForHit3 = hit1Details.get(2); + assertEquals(1.0, (double) hit1DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertTrue(((String) hit1DetailsForHit3.get("description")).matches("combination of:")); + assertEquals(2, ((List) hit1DetailsForHit3.get("details")).size()); + + // hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map explanationForHit2 = (Map) searchHit2.get("_explanation"); + assertNotNull(explanationForHit2); + assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit2.get("description")); + List> hit2Details = (List>) explanationForHit2.get("details"); + assertEquals(3, hit2Details.size()); + Map hit2DetailsForHit1 = hit2Details.get(0); + assertTrue((double) hit2DetailsForHit1.get("value") > 0.5f); + assertTrue( + ((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\], normalized scores: \\[.*, .*\\]") + ); + assertEquals(0, ((List) hit2DetailsForHit1.get("details")).size()); + + Map hit2DetailsForHit2 = hit2Details.get(1); + assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit2DetailsForHit2.get("value"))); + assertTrue(((String) hit2DetailsForHit2.get("description")).matches("source scores: \\[.*, .*\\], combined score .*")); + assertEquals(0, ((List) hit2DetailsForHit2.get("details")).size()); + + Map hit2DetailsForHit3 = hit2Details.get(2); + assertEquals(1.0, (double) hit2DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertTrue(((String) 
hit2DetailsForHit3.get("description")).matches("combination of:")); + assertEquals(2, ((List) hit2DetailsForHit3.get("details")).size()); + + // hit 3 + Map searchHit3 = hitsNestedList.get(2); + Map explanationForHit3 = (Map) searchHit3.get("_explanation"); + assertNotNull(explanationForHit3); + assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit3.get("description")); + List> hit3Details = (List>) explanationForHit3.get("details"); + assertEquals(3, hit3Details.size()); + Map hit3DetailsForHit1 = hit3Details.get(0); + assertTrue((double) hit3DetailsForHit1.get("value") > 0.5f); + assertTrue(((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[.*\\]")); + assertEquals(0, ((List) hit3DetailsForHit1.get("details")).size()); + + Map hit3DetailsForHit2 = hit3Details.get(1); + assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit3DetailsForHit2.get("value"))); + assertTrue(((String) hit3DetailsForHit2.get("description")).matches("source scores: \\[0.0, .*\\], combined score .*")); + assertEquals(0, ((List) hit3DetailsForHit2.get("details")).size()); + + Map hit3DetailsForHit3 = hit3Details.get(2); + assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit3DetailsForHit3.get("value"))); + assertTrue(((String) hit3DetailsForHit3.get("description")).matches("combination of:")); + assertEquals(1, ((List) hit3DetailsForHit3.get("details")).size()); + + // hit 4 + Map searchHit4 = hitsNestedList.get(3); + Map explanationForHit4 = (Map) searchHit4.get("_explanation"); + assertNotNull(explanationForHit4); + assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit4.get("description")); + List> hit4Details = (List>) explanationForHit4.get("details"); + assertEquals(3, hit4Details.size()); + Map hit4DetailsForHit1 = hit4Details.get(0); + assertTrue((double) hit4DetailsForHit1.get("value") > 0.5f); + assertTrue(((String) hit4DetailsForHit1.get("description")).matches("source scores: \\[1.0\\], normalized scores: \\[.*\\]")); + assertEquals(0, ((List) hit4DetailsForHit1.get("details")).size()); + + Map hit4DetailsForHit2 = hit4Details.get(1); + assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit4DetailsForHit2.get("value"))); + assertTrue(((String) hit4DetailsForHit2.get("description")).matches("source scores: \\[.*, 0.0\\], combined score .*")); + assertEquals(0, ((List) hit4DetailsForHit2.get("details")).size()); + + Map hit4DetailsForHit3 = hit4Details.get(2); + assertEquals(1.0, (double) hit4DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertTrue(((String) hit4DetailsForHit3.get("description")).matches("combination of:")); + assertEquals(1, ((List) hit4DetailsForHit3.get("details")).size()); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index afc545447..9eb802708 100644 --- 
a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -48,6 +48,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; import org.opensearch.search.sort.SortBuilder; @@ -1167,11 +1169,24 @@ protected void createSearchPipeline( final String normalizationMethod, String combinationMethod, final Map combinationParams + ) { + createSearchPipeline(pipelineId, normalizationMethod, combinationMethod, combinationParams, false); + } + + @SneakyThrows + protected void createSearchPipeline( + final String pipelineId, + final String normalizationMethod, + final String combinationMethod, + final Map combinationParams, + boolean addExplainResponseProcessor ) { StringBuilder stringBuilderForContentBody = new StringBuilder(); stringBuilderForContentBody.append("{\"description\": \"Post processor pipeline\",") .append("\"phase_results_processors\": [{ ") - .append("\"normalization-processor\": {") + .append("\"") + .append(NormalizationProcessor.TYPE) + .append("\": {") .append("\"normalization\": {") .append("\"technique\": \"%s\"") .append("},") @@ -1184,7 +1199,15 @@ protected void createSearchPipeline( } stringBuilderForContentBody.append(" }"); } - stringBuilderForContentBody.append("}").append("}}]}"); + stringBuilderForContentBody.append("}").append("}}]"); + if (addExplainResponseProcessor) { + stringBuilderForContentBody.append(", \"response_processors\": [ ") + .append("{\"") + .append(ProcessorExplainPublisher.TYPE) + .append("\": {}}") + .append("]"); + } + stringBuilderForContentBody.append("}"); makeRequest( client(), "PUT", From f0fce087f0475ad8ce1408d6cb79c90d340621c0 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 31 Oct 2024 11:38:53 -0700 Subject: [PATCH 05/11] Adjust the format of final message Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/plugin/NeuralSearch.java | 6 +- ...her.java => ExplainResponseProcessor.java} | 10 +- .../NormalizationProcessorWorkflow.java | 4 +- ...ithmeticMeanScoreCombinationTechnique.java | 2 +- ...eometricMeanScoreCombinationTechnique.java | 2 +- ...HarmonicMeanScoreCombinationTechnique.java | 2 +- .../combination/ScoreCombinationUtil.java | 4 +- .../processor/combination/ScoreCombiner.java | 55 +++---- .../processor/explain/ExplainUtils.java | 80 +++------- .../ProcessorExplainPublisherFactory.java | 4 +- .../L2ScoreNormalizationTechnique.java | 2 +- .../MinMaxScoreNormalizationTechnique.java | 2 +- .../neuralsearch/query/HybridQueryWeight.java | 9 +- ...ava => ExplainResponseProcessorTests.java} | 18 +-- ...ticMeanScoreCombinationTechniqueTests.java | 4 +- ...ricMeanScoreCombinationTechniqueTests.java | 4 +- ...nicMeanScoreCombinationTechniqueTests.java | 4 +- .../neuralsearch/query/HybridQueryIT.java | 61 ++++--- .../neuralsearch/query/HybridQuerySortIT.java | 149 ++++++++++++++++++ .../neuralsearch/BaseNeuralSearchIT.java | 4 +- 21 files changed, 272 insertions(+), 155 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/{ProcessorExplainPublisher.java => ExplainResponseProcessor.java} (93%) rename 
src/test/java/org/opensearch/neuralsearch/processor/{ProcessorExplainPublisherTests.java => ExplainResponseProcessorTests.java} (63%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..a4b44d388 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features ### Enhancements +- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 5391faa14..e5272ab4b 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -32,7 +32,7 @@ import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; -import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; +import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; @@ -82,7 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); - public static final String PROCESSOR_EXPLAIN = "processor_explain"; + public static final String EXPLAIN_RESPONSE_KEY = "explain_response"; @Override public Collection createComponents( @@ -185,7 +185,7 @@ public Map explainPayload = processorExplainDto.getExplainPayload(); if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { Explanation processorExplanation = processorExplainDto.getExplanation(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 084a94bf7..5b6fa158c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -40,7 +40,7 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; -import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN; +import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND; import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.topLevelExpalantionForCombinedScore; import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria; @@ -152,7 +152,7 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); - 
pipelineProcessingContext.setAttribute(PROCESSOR_EXPLAIN, processorExplainDto); + pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, processorExplainDto); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 539eaa081..1d31b4c31 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -11,6 +11,7 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** @@ -20,7 +21,6 @@ public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "arithmetic_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index 38f98c018..0dcd5c39c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -11,6 +11,7 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** @@ -20,7 +21,6 @@ public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "geometric_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index ddfbb5223..4fd112bc5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -11,6 +11,7 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** @@ -20,7 +21,6 @@ public class HarmonicMeanScoreCombinationTechnique 
implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "harmonic_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index a915057df..5f18baf09 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -23,8 +23,8 @@ * Collection of utility methods for score combination technique classes */ @Log4j2 -class ScoreCombinationUtil { - private static final String PARAM_NAME_WEIGHTS = "weights"; +public class ScoreCombinationUtil { + public static final String PARAM_NAME_WEIGHTS = "weights"; private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 7314c3383..8194ecf74 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -5,11 +5,9 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.Objects; @@ -31,6 +29,8 @@ import org.opensearch.neuralsearch.processor.SearchShard; import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getScoreCombinationExplainDetailsForDocument; + /** * Abstracts combination of scores in query search results. */ @@ -64,7 +64,7 @@ public class ScoreCombiner { * Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score", * other steps are the same for all techniques. * - * @param combineScoresDTO contains details of query top docs, score combination technique and sort is enabled or disabled. + * @param combineScoresDTO contains details of query top docs, score combination technique, and whether sort is enabled or disabled. */ public void combineScores(final CombineScoresDto combineScoresDTO) { // iterate over results from each shard. Every CompoundTopDocs object has results from @@ -106,7 +106,6 @@ private void combineShardScores( updateQueryTopDocsWithCombinedScores( compoundQueryTopDocs, topDocsPerSubQuery, - normalizedScoresPerDoc, combinedNormalizedScoresByDocId, sortedDocsIds, getDocIdSortFieldsMap(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sort), @@ -129,7 +128,7 @@ private boolean isSortOrderByScore(Sort sort) { } /** - * @param sort sort criteria + * @param sort sort criteria * @param topDocsPerSubQuery top docs per subquery * @return list of top field docs which is deduced by typecasting top docs to top field docs. */
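The ScoreCombiner hunks above centralize the per-document combination of normalized subquery scores. For a concrete sense of what arithmetic_mean with the optional weights parameter computes, here is a minimal, self-contained sketch; the scores and weights are hypothetical, and this snippet is not part of the patch:

public class ArithmeticMeanSketch {
    public static void main(String[] args) {
        // Hypothetical normalized scores, one per subquery of the hybrid query
        float[] normalizedScores = { 0.8f, 0.5f };
        // Hypothetical values of the optional "weights" parameter
        float[] weights = { 0.3f, 0.7f };
        float weightedSum = 0.0f;
        float sumOfWeights = 0.0f;
        for (int i = 0; i < normalizedScores.length; i++) {
            weightedSum += weights[i] * normalizedScores[i];
            sumOfWeights += weights[i];
        }
        // (0.3 * 0.8 + 0.7 * 0.5) / (0.3 + 0.7) = 0.59
        System.out.println("combined score: " + weightedSum / sumOfWeights);
    }
}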
@@ -149,9 +148,9 @@ private List getTopFieldDocs(final Sort sort, final List } /** - * @param compoundTopDocs top docs that represent on shard + * @param compoundTopDocs top docs that represent one shard * @param combinedNormalizedScoresByDocId docId to combined normalized score map - * @param sort sort criteria + * @param sort sort criteria * @return map of docId and sort fields if sorting is enabled. */ private Map getDocIdSortFieldsMap( @@ -290,7 +289,6 @@ private Map combineScoresAndGetCombinedNormalizedScoresPerDocume private void updateQueryTopDocsWithCombinedScores( final CompoundTopDocs compoundQueryTopDocs, final List topDocsPerSubQuery, - Map normalizedScoresPerDoc, final Map combinedNormalizedScoresByDocId, final Collection sortedScores, Map docIdSortFieldMap, @@ -322,21 +320,21 @@ private TotalHits getTotalHits(final List topDocsPerSubQuery, final lon public Map> explain( final List queryTopDocs, - ScoreCombinationTechnique combinationTechnique, - Sort sort + final ScoreCombinationTechnique combinationTechnique, + final Sort sort ) { // In case of duplicate keys, keep the first value - HashMap> map = new HashMap<>(); + HashMap> explanations = new HashMap<>(); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { for (Map.Entry> docIdAtSearchShardExplainDetailsEntry : explainByShard( combinationTechnique, compoundQueryTopDocs, sort ).entrySet()) { - map.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue()); + explanations.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue()); } } - return map; + return explanations; } private Map> explainByShard( @@ -349,31 +347,20 @@ private Map> explainByShard( } // - create map of normalized scores results returned from the single shard Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs()); - Map> explainsForShard = new HashMap<>(); Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); Collection sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId); - SearchShard searchShard = compoundQueryTopDocs.getSearchShard(); - explainsForShard.put(searchShard, new ArrayList<>()); - for (Integer docId : sortedDocsIds) { - float combinedScore = combinedNormalizedScoresByDocId.get(docId); - explainsForShard.get(searchShard) - .add( - new ExplainDetails( - combinedScore, - String.format( - Locale.ROOT, - "source scores: %s, combined score %s", - Arrays.toString(normalizedScoresPerDoc.get(docId)), - combinedScore - ), - docId - ) - ); - } - - return explainsForShard; + List listOfExplainsForShard = sortedDocsIds.stream() + .map( + docId -> getScoreCombinationExplainDetailsForDocument( + docId, + combinedNormalizedScoresByDocId, + normalizedScoresPerDoc.get(docId) + ) + ) + .toList(); + return Map.of(compoundQueryTopDocs.getSearchShard(), listOfExplainsForShard); } private Collection getSortedDocsIds(
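The refactored explainByShard above returns a single-entry map from the shard to an ordered list of per-document explain details, delegating message formatting to the getScoreCombinationExplainDetailsForDocument helper shown in the next diff. A minimal sketch of that data shape, using stand-in record types (the plugin's SearchShard and ExplainDetails are assumed to have these shapes) and hypothetical values:

import java.util.List;
import java.util.Map;

public class ExplainByShardSketch {
    // Stand-ins for the plugin's SearchShard and ExplainDetails types (assumed shapes)
    record SearchShard(String index, int shardId, String nodeId) {}
    record ExplainDetails(float value, String description, int docId) {}

    public static void main(String[] args) {
        // Hypothetical per-shard explain data: one map entry per shard, with
        // per-document details in the same order as the sorted doc ids
        Map<SearchShard, List<ExplainDetails>> explainsForShard = Map.of(
            new SearchShard("my-index", 0, "node-1"),
            List.of(
                new ExplainDetails(0.59f, "normalized scores: [0.8, 0.5] combined to a final score: 0.59", 3),
                new ExplainDetails(0.35f, "normalized scores: [0.0, 0.5] combined to a final score: 0.35", 7)
            )
        );
        explainsForShard.forEach((shard, details) -> System.out.println(shard + " -> " + details));
    }
}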
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java index 0de386ca7..a66acb786 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java @@ -5,16 +5,15 @@ package org.opensearch.neuralsearch.processor.explain; import org.apache.lucene.search.Explanation; -import org.opensearch.neuralsearch.processor.SearchShard; -import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import java.util.Arrays; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + /** * Utility class for explain functionality */ @@ -37,65 +36,36 @@ public static Map getDocIdAtQueryForNormaliz List normScores = normalizedScores.get(entry.getKey()); return new ExplainDetails( normScores.stream().reduce(0.0f, Float::max), - String.format(Locale.ROOT, "source scores: %s, normalized scores: %s", srcScores, normScores) + String.format(Locale.ROOT, "source scores: %s normalized to scores: %s", srcScores, normScores) ); })); return explain; } /** - * Creates map of DocIdAtQueryPhase to String containing source scores and combined score - * @param scoreCombinationTechnique the combination technique used - * @param normalizedScoresPerDoc map of DocIdAtQueryPhase to normalized scores - * @param searchShard the search shard - * @return map of DocIdAtQueryPhase to String containing source scores and combined score + * Returns the detailed score combination explanation for a single document + * @param docId id of the document + * @param combinedNormalizedScoresByDocId map of document id to its combined score + * @param normalizedScoresPerDoc normalized scores of the document, one per subquery + * @return explain details holding the combined score and its description */ - public static Map getDocIdAtQueryForCombination( - ScoreCombinationTechnique scoreCombinationTechnique, - Map normalizedScoresPerDoc, - SearchShard searchShard - ) { - Map explain = new HashMap<>(); - // - create map of combined scores per doc id - Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); - normalizedScoresPerDoc.forEach((key, srcScores) -> { - float combinedScore = combinedNormalizedScoresByDocId.get(key); - explain.put( - new DocIdAtSearchShard(key, searchShard), - new ExplainDetails( - combinedScore, - String.format(Locale.ROOT, "source scores: %s, combined score %s", Arrays.toString(srcScores), combinedScore) - ) - ); - }); - return explain; - } - - /* public static Map> getExplainsByShardForCombination( - ScoreCombinationTechnique scoreCombinationTechnique, - Map normalizedScoresPerDoc, - SearchShard searchShard + public static ExplainDetails getScoreCombinationExplainDetailsForDocument( + final Integer docId, + final Map combinedNormalizedScoresByDocId, + final float[] normalizedScoresPerDoc ) { - Map explain = new HashMap<>(); - // - create map of combined scores per doc id - Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); - normalizedScoresPerDoc.forEach((key, srcScores) -> { - float combinedScore = combinedNormalizedScoresByDocId.get(key); - explain.put( - new DocIdAtSearchShard(key, searchShard), - new ExplainDetails( - combinedScore, - String.format("source scores: %s, combined score %s", Arrays.toString(srcScores), combinedScore) - ) - ); - }); - return explain + float combinedScore = combinedNormalizedScoresByDocId.get(docId); + return new ExplainDetails( + combinedScore, + String.format( + Locale.ROOT, + "normalized scores: %s combined to a final score: %s",
Arrays.toString(normalizedScoresPerDoc), + combinedScore + ), + docId + ); } - */ /** * Creates a string describing the combination technique and its parameters @@ -104,7 +74,7 @@ public static Map getDocIdAtQueryForCombinat * @return a string describing the combination technique and its parameters */ public static String describeCombinationTechnique(final String techniqueName, final List weights) { - return String.format(Locale.ROOT, "combination technique [%s] with optional parameters [%s]", techniqueName, weights); + return String.format(Locale.ROOT, "combination [%s] with optional parameter [%s]: %s", techniqueName, PARAM_NAME_WEIGHTS, weights); } /** @@ -119,7 +89,7 @@ public static Explanation topLevelExpalantionForCombinedScore( ) { String explanationDetailsMessage = String.format( Locale.ROOT, - "combine score with techniques: %s, %s", + "combined score with techniques: %s, %s", explainableNormalizationTechnique.describe(), explainableCombinationTechnique.describe() ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java index 2633a89ad..cbca8daca 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java @@ -4,7 +4,7 @@ */ package org.opensearch.neuralsearch.processor.factory; -import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; +import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -21,6 +21,6 @@ public SearchResponseProcessor create( Map config, Processor.PipelineContext pipelineContext ) throws Exception { - return new ProcessorExplainPublisher(description, tag, ignoreFailure); + return new ExplainResponseProcessor(description, tag, ignoreFailure); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index e1eb6c56a..5c5436564 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -60,7 +60,7 @@ public void normalize(final List queryTopDocs) { @Override public String describe() { - return String.format(Locale.ROOT, "normalization technique [%s]", TECHNIQUE_NAME); + return String.format(Locale.ROOT, "normalization [%s]", TECHNIQUE_NAME); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 8e0f0f586..63efb4332 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -74,7 +74,7 @@ public void normalize(final List queryTopDocs) { @Override public String describe() { - return String.format(Locale.ROOT, "normalization technique [%s]", TECHNIQUE_NAME); + return String.format(Locale.ROOT, "normalization [%s]", TECHNIQUE_NAME); } @Override 
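Both techniques whose describe() strings were shortened above leave the scoring math untouched: min_max rescales each subquery score against that subquery's minimum and maximum scores, and l2 divides each score by the L2 norm of all scores in the subquery. A minimal, self-contained sketch of the two formulas with hypothetical scores follows; note that the plugin's actual min_max implementation also clamps the lowest score to a small positive value, which is why the single-shard test earlier expects a normalized score of 0.001:

public class NormalizationSketch {
    public static void main(String[] args) {
        // Hypothetical raw scores of one subquery across its hits
        float[] subqueryScores = { 2.0f, 5.0f, 8.0f };

        // min_max: (score - min) / (max - min), mapping scores into [0, 1]
        float min = 2.0f;
        float max = 8.0f;
        for (float score : subqueryScores) {
            float minMaxScore = (score - min) / (max - min); // e.g. 5.0 -> 0.5
            System.out.println("min_max normalized: " + minMaxScore);
        }

        // l2: score / sqrt(sum of squared scores in the subquery)
        float sumOfSquares = 0.0f;
        for (float score : subqueryScores) {
            sumOfSquares += score * score;
        }
        float l2Norm = (float) Math.sqrt(sumOfSquares); // sqrt(4 + 25 + 64) ~ 9.64
        for (float score : subqueryScores) {
            System.out.println("l2 normalized: " + score / l2Norm); // e.g. 5.0 -> ~0.52
        }
    }
}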
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index 08393922a..c0c91be6e 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -140,11 +140,10 @@ public boolean isCacheable(LeafReaderContext ctx) { } /** - * Explain is not supported for hybrid query - * + * Returns a shard level {@link Explanation} that describes how the weight and scoring are calculated. * @param context the readers context to create the {@link Explanation} for. - * @param doc the document's id relative to the given context's reader - * @return + * @param doc the document's id relative to the given context's reader + * @return shard level {@link Explanation}, each sub-query explanation is a single nested element * @throws IOException */ @Override @@ -165,7 +164,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio } } if (match) { - final String desc = "combination of:"; + final String desc = "base scores from subqueries:"; return Explanation.match(max, desc, subsOnMatch); } else { return Explanation.noMatch("no matching clause", subsOnNoMatch); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java similarity index 63% rename from src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java index 0f5abfec4..796a899be 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java @@ -13,21 +13,21 @@ import static org.mockito.Mockito.mock; -public class ProcessorExplainPublisherTests extends OpenSearchTestCase { +public class ExplainResponseProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { - ProcessorExplainPublisher processorExplainPublisher = new ProcessorExplainPublisher(DESCRIPTION, PROCESSOR_TAG, false); + ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); - assertEquals(DESCRIPTION, processorExplainPublisher.getDescription()); - assertEquals(PROCESSOR_TAG, processorExplainPublisher.getTag()); - assertFalse(processorExplainPublisher.isIgnoreFailure()); + assertEquals(DESCRIPTION, explainResponseProcessor.getDescription()); + assertEquals(PROCESSOR_TAG, explainResponseProcessor.getTag()); + assertFalse(explainResponseProcessor.isIgnoreFailure()); } @SneakyThrows public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() { - ProcessorExplainPublisher processorExplainPublisher = new ProcessorExplainPublisher(DESCRIPTION, PROCESSOR_TAG, false); + ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); SearchRequest searchRequest = mock(SearchRequest.class); SearchResponse searchResponse = new SearchResponse( null, @@ -40,14 +40,14 @@ public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProc SearchResponse.Clusters.EMPTY ); - 
SearchResponse processedResponse = processorExplainPublisher.processResponse(searchRequest, searchResponse); + SearchResponse processedResponse = explainResponseProcessor.processResponse(searchRequest, searchResponse); assertEquals(searchResponse, processedResponse); - SearchResponse processedResponse2 = processorExplainPublisher.processResponse(searchRequest, searchResponse, null); + SearchResponse processedResponse2 = explainResponseProcessor.processResponse(searchRequest, searchResponse, null); assertEquals(searchResponse, processedResponse2); PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); - SearchResponse processedResponse3 = processorExplainPublisher.processResponse( + SearchResponse processedResponse3 = explainResponseProcessor.processResponse( searchRequest, searchResponse, pipelineProcessingContext diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 7d3b3fb61..deac02933 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java index 495e2f4cd..d46705902 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 0c6e1f81d..0cfdeb4c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ 
b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index e24a4da5b..ce75d9fba 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -891,24 +891,26 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { assertNotNull(explanationForHit1); assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); String expectedGeneralCombineScoreDescription = - "combine score with techniques: normalization technique [min_max], combination technique [arithmetic_mean] with optional parameters [[]]"; + "combined score with techniques: normalization [min_max], combination [arithmetic_mean] with optional parameter [weights]: []"; assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); List> hit1Details = (List>) explanationForHit1.get("details"); assertEquals(3, hit1Details.size()); Map hit1DetailsForHit1 = hit1Details.get(0); assertEquals(1.0, hit1DetailsForHit1.get("value")); - assertTrue(((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[1\\.0\\]")); + assertTrue( + ((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[1\\.0\\]") + ); assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); Map hit1DetailsForHit2 = hit1Details.get(1); assertEquals(0.5, hit1DetailsForHit2.get("value")); - assertEquals("source scores: [0.0, 1.0], combined score 0.5", hit1DetailsForHit2.get("description")); + assertEquals("normalized scores: [0.0, 1.0] combined to a final score: 0.5", hit1DetailsForHit2.get("description")); assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); Map hit1DetailsForHit3 = hit1Details.get(2); double actualHit1ScoreHit3 = ((double) hit1DetailsForHit3.get("value")); assertTrue(actualHit1ScoreHit3 > 0.0); - assertEquals("combination of:", hit1DetailsForHit3.get("description")); + assertEquals("base scores from subqueries:", hit1DetailsForHit3.get("description")); assertEquals(1, ((List) hit1DetailsForHit3.get("details")).size()); Map hit1SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(0); @@ -926,18 +928,20 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { assertEquals(3, hit2Details.size()); Map hit2DetailsForHit1 = hit2Details.get(0); assertEquals(1.0, hit2DetailsForHit1.get("value")); - assertTrue(((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[1\\.0\\]")); + assertTrue( + ((String) 
hit2DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[1\\.0\\]") + ); assertEquals(0, ((List) hit2DetailsForHit1.get("details")).size()); Map hit2DetailsForHit2 = hit2Details.get(1); assertEquals(0.5, hit2DetailsForHit2.get("value")); - assertEquals("source scores: [1.0, 0.0], combined score 0.5", hit2DetailsForHit2.get("description")); + assertEquals("normalized scores: [1.0, 0.0] combined to a final score: 0.5", hit2DetailsForHit2.get("description")); assertEquals(0, ((List) hit2DetailsForHit2.get("details")).size()); Map hit2DetailsForHit3 = hit2Details.get(2); double actualHit2ScoreHit3 = ((double) hit2DetailsForHit3.get("value")); assertTrue(actualHit2ScoreHit3 > 0.0); - assertEquals("combination of:", hit2DetailsForHit3.get("description")); + assertEquals("base scores from subqueries:", hit2DetailsForHit3.get("description")); assertEquals(1, ((List) hit2DetailsForHit3.get("details")).size()); Map hit2SubDetailsForHit3 = (Map) ((List) hit2DetailsForHit3.get("details")).get(0); @@ -956,19 +960,19 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { Map hit3DetailsForHit1 = hit3Details.get(0); assertEquals(0.001, hit3DetailsForHit1.get("value")); assertTrue( - ((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[0\\.001\\]") + ((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[0\\.001\\]") ); assertEquals(0, ((List) hit3DetailsForHit1.get("details")).size()); Map hit3DetailsForHit2 = hit3Details.get(1); assertEquals(0.0005, hit3DetailsForHit2.get("value")); - assertEquals("source scores: [0.0, 0.001], combined score 5.0E-4", hit3DetailsForHit2.get("description")); + assertEquals("normalized scores: [0.0, 0.001] combined to a final score: 5.0E-4", hit3DetailsForHit2.get("description")); assertEquals(0, ((List) hit3DetailsForHit2.get("details")).size()); Map hit3DetailsForHit3 = hit3Details.get(2); double actualHit3ScoreHit3 = ((double) hit3DetailsForHit3.get("value")); assertTrue(actualHit3ScoreHit3 > 0.0); - assertEquals("combination of:", hit3DetailsForHit3.get("description")); + assertEquals("base scores from subqueries:", hit3DetailsForHit3.get("description")); assertEquals(1, ((List) hit3DetailsForHit3.get("details")).size()); Map hit3SubDetailsForHit3 = (Map) ((List) hit3DetailsForHit3.get("details")).get(0); @@ -1027,27 +1031,28 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertNotNull(explanationForHit1); assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); String expectedGeneralCombineScoreDescription = - "combine score with techniques: normalization technique [l2], combination technique [arithmetic_mean] with optional parameters [" - + Arrays.toString(new float[] { 0.3f, 0.7f }) - + "]"; + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameter [weights]: " + + Arrays.toString(new float[] { 0.3f, 0.7f }); assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); List> hit1Details = (List>) explanationForHit1.get("details"); assertEquals(3, hit1Details.size()); Map hit1DetailsForHit1 = hit1Details.get(0); assertTrue((double) hit1DetailsForHit1.get("value") > 0.5f); assertTrue( - ((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\], normalized scores: \\[.*, .*\\]") + ((String) 
hit1DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\] normalized to scores: \\[.*, .*\\]") ); assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); Map hit1DetailsForHit2 = hit1Details.get(1); assertEquals(actualMaxScore, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit1DetailsForHit2.get("description")).matches("source scores: \\[.*, .*\\], combined score .*")); + assertTrue( + ((String) hit1DetailsForHit2.get("description")).matches("normalized scores: \\[.*, .*\\] combined to a final score: .*") + ); assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); Map hit1DetailsForHit3 = hit1Details.get(2); assertEquals(1.0, (double) hit1DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit1DetailsForHit3.get("description")).matches("combination of:")); + assertTrue(((String) hit1DetailsForHit3.get("description")).matches("base scores from subqueries:")); assertEquals(2, ((List) hit1DetailsForHit3.get("details")).size()); // hit 2 @@ -1062,18 +1067,20 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() Map hit2DetailsForHit1 = hit2Details.get(0); assertTrue((double) hit2DetailsForHit1.get("value") > 0.5f); assertTrue( - ((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\], normalized scores: \\[.*, .*\\]") + ((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\] normalized to scores: \\[.*, .*\\]") ); assertEquals(0, ((List) hit2DetailsForHit1.get("details")).size()); Map hit2DetailsForHit2 = hit2Details.get(1); assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit2DetailsForHit2.get("value"))); - assertTrue(((String) hit2DetailsForHit2.get("description")).matches("source scores: \\[.*, .*\\], combined score .*")); + assertTrue( + ((String) hit2DetailsForHit2.get("description")).matches("normalized scores: \\[.*, .*\\] combined to a final score: .*") + ); assertEquals(0, ((List) hit2DetailsForHit2.get("details")).size()); Map hit2DetailsForHit3 = hit2Details.get(2); assertEquals(1.0, (double) hit2DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit2DetailsForHit3.get("description")).matches("combination of:")); + assertTrue(((String) hit2DetailsForHit3.get("description")).matches("base scores from subqueries:")); assertEquals(2, ((List) hit2DetailsForHit3.get("details")).size()); // hit 3 @@ -1087,17 +1094,19 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertEquals(3, hit3Details.size()); Map hit3DetailsForHit1 = hit3Details.get(0); assertTrue((double) hit3DetailsForHit1.get("value") > 0.5f); - assertTrue(((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[.*\\]")); + assertTrue(((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[.*\\]")); assertEquals(0, ((List) hit3DetailsForHit1.get("details")).size()); Map hit3DetailsForHit2 = hit3Details.get(1); assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit3DetailsForHit2.get("value"))); - assertTrue(((String) hit3DetailsForHit2.get("description")).matches("source scores: \\[0.0, .*\\], combined score .*")); + assertTrue( + ((String) hit3DetailsForHit2.get("description")).matches("normalized scores: \\[0.0, .*\\] combined to a final score: .*") + ); assertEquals(0, ((List) hit3DetailsForHit2.get("details")).size()); 
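+ // detail layout per hit: [0] normalization ("source scores ... normalized to scores ..."),
+ // [1] combination ("normalized scores ... combined to a final score ..."),
+ // [2] raw subquery scores from HybridQueryWeight.explain ("base scores from subqueries:")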
Map hit3DetailsForHit3 = hit3Details.get(2); assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit3DetailsForHit3.get("value"))); - assertTrue(((String) hit3DetailsForHit3.get("description")).matches("combination of:")); + assertTrue(((String) hit3DetailsForHit3.get("description")).matches("base scores from subqueries:")); assertEquals(1, ((List) hit3DetailsForHit3.get("details")).size()); // hit 4 @@ -1111,17 +1120,19 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertEquals(3, hit4Details.size()); Map hit4DetailsForHit1 = hit4Details.get(0); assertTrue((double) hit4DetailsForHit1.get("value") > 0.5f); - assertTrue(((String) hit4DetailsForHit1.get("description")).matches("source scores: \\[1.0\\], normalized scores: \\[.*\\]")); + assertTrue(((String) hit4DetailsForHit1.get("description")).matches("source scores: \\[1.0\\] normalized to scores: \\[.*\\]")); assertEquals(0, ((List) hit4DetailsForHit1.get("details")).size()); Map hit4DetailsForHit2 = hit4Details.get(1); assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit4DetailsForHit2.get("value"))); - assertTrue(((String) hit4DetailsForHit2.get("description")).matches("source scores: \\[.*, 0.0\\], combined score .*")); + assertTrue( + ((String) hit4DetailsForHit2.get("description")).matches("normalized scores: \\[.*, 0.0\\] combined to a final score: .*") + ); assertEquals(0, ((List) hit4DetailsForHit2.get("details")).size()); Map hit4DetailsForHit3 = hit4Details.get(2); assertEquals(1.0, (double) hit4DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit4DetailsForHit3.get("description")).matches("combination of:")); + assertTrue(((String) hit4DetailsForHit3.get("description")).matches("base scores from subqueries:")); assertEquals(1, ((List) hit4DetailsForHit3.get("details")).size()); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java index b5e812780..86f2dc620 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -21,6 +21,9 @@ import org.opensearch.neuralsearch.BaseNeuralSearchIT; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQueryWhenSortIsEnabled; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import org.opensearch.search.sort.SortOrder; import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; @@ -531,6 +534,152 @@ public void testSortingWithRescoreWhenConcurrentSegmentSearchEnabledAndDisabled_ } } + @SneakyThrows + public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { + try { + // Setup + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + // Assert + // scores for search hits + HybridQueryBuilder 
hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY + ); + + Map fieldSortOrderMap = new HashMap<>(); + fieldSortOrderMap.put("stock", SortOrder.DESC); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()), + null, + null, + createSortBuilders(fieldSortOrderMap, false), + false, + null, + 0 + ); + List> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6); + assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, true); + + // explain + Map searchHit1 = nestedHits.get(0); + Map explanationForHit1 = (Map) searchHit1.get("_explanation"); + assertNotNull(explanationForHit1); + assertNull(searchHit1.get("_score")); + String expectedGeneralCombineScoreDescription = + "combined score with techniques: normalization [min_max], combination [arithmetic_mean] with optional parameter [weights]: []"; + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); + List> hit1Details = (List>) explanationForHit1.get("details"); + assertEquals(3, hit1Details.size()); + Map hit1DetailsForHit1 = hit1Details.get(0); + assertEquals(1.0, hit1DetailsForHit1.get("value")); + assertTrue( + ((String) hit1DetailsForHit1.get("description")).matches( + "source scores: \\[0.4700036, 1.0\\] normalized to scores: \\[1.0, 1.0\\]" + ) + ); + assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertEquals(0.6666667, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("normalized scores: [1.0, 0.0, 1.0] combined to a final score: 0.6666667", hit1DetailsForHit2.get("description")); + assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); + + Map hit1DetailsForHit3 = hit1Details.get(2); + double actualHit1ScoreHit3 = ((double) hit1DetailsForHit3.get("value")); + assertTrue(actualHit1ScoreHit3 > 0.0); + assertEquals("base scores from subqueries:", hit1DetailsForHit3.get("description")); + assertEquals(2, ((List) hit1DetailsForHit3.get("details")).size()); + + Map hit1SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(0); + assertEquals(0.47, ((double) hit1SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(name:mission in 0) [PerFieldSimilarity], result of:", hit1SubDetailsForHit3.get("description")); + assertEquals(1, ((List) hit1SubDetailsForHit3.get("details")).size()); + + Map hit2SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(1); + assertEquals(1.0f, ((double) hit2SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("stock:[20 TO 400]", hit2SubDetailsForHit3.get("description")); + assertEquals(0, ((List) hit2SubDetailsForHit3.get("details")).size()); + // hit 4 + Map searchHit4 = nestedHits.get(3); + Map explanationForHit4 = (Map) searchHit4.get("_explanation"); + assertNotNull(explanationForHit4); + assertNull(searchHit4.get("_score")); + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit4.get("description")); + List> hit4Details = (List>) explanationForHit4.get("details"); + assertEquals(3, hit4Details.size()); + Map hit1DetailsForHit4 = hit4Details.get(0); + assertEquals(1.0, 
hit1DetailsForHit4.get("value")); + assertTrue( + ((String) hit1DetailsForHit4.get("description")).matches( + "source scores: \\[0.9808291, 1.0\\] normalized to scores: \\[1.0, 1.0\\]" + ) + ); + assertEquals(0, ((List) hit1DetailsForHit4.get("details")).size()); + + Map hit2DetailsForHit4 = hit4Details.get(1); + assertEquals(0.6666667, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("normalized scores: [0.0, 1.0, 1.0] combined to a final score: 0.6666667", hit2DetailsForHit4.get("description")); + assertEquals(0, ((List) hit2DetailsForHit4.get("details")).size()); + + Map hit3DetailsForHit4 = hit4Details.get(2); + double actualHit3ScoreHit4 = ((double) hit3DetailsForHit4.get("value")); + assertTrue(actualHit3ScoreHit4 > 0.0); + assertEquals("base scores from subqueries:", hit3DetailsForHit4.get("description")); + assertEquals(2, ((List) hit3DetailsForHit4.get("details")).size()); + + Map hit1SubDetailsForHit4 = (Map) ((List) hit3DetailsForHit4.get("details")).get(0); + assertEquals(0.98, ((double) hit1SubDetailsForHit4.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(name:part in 0) [PerFieldSimilarity], result of:", hit1SubDetailsForHit4.get("description")); + assertEquals(1, ((List) hit1SubDetailsForHit4.get("details")).size()); + + Map hit2SubDetailsForHit4 = (Map) ((List) hit3DetailsForHit4.get("details")).get(1); + assertEquals(1.0f, ((double) hit2SubDetailsForHit4.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("stock:[20 TO 400]", hit2SubDetailsForHit4.get("description")); + assertEquals(0, ((List) hit2SubDetailsForHit4.get("details")).size()); + + // hit 6 + Map searchHit6 = nestedHits.get(5); + Map explanationForHit6 = (Map) searchHit6.get("_explanation"); + assertNotNull(explanationForHit6); + assertNull(searchHit6.get("_score")); + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit6.get("description")); + List> hit6Details = (List>) explanationForHit6.get("details"); + assertEquals(3, hit6Details.size()); + Map hit1DetailsForHit6 = hit6Details.get(0); + assertEquals(1.0, hit1DetailsForHit6.get("value")); + assertEquals("source scores: [1.0] normalized to scores: [1.0]", hit1DetailsForHit6.get("description")); + assertEquals(0, ((List) hit1DetailsForHit6.get("details")).size()); + + Map hit2DetailsForHit6 = hit6Details.get(1); + assertEquals(0.333, (double) hit2DetailsForHit6.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("normalized scores: [0.0, 0.0, 1.0] combined to a final score: 0.33333334", hit2DetailsForHit6.get("description")); + assertEquals(0, ((List) hit2DetailsForHit6.get("details")).size()); + + Map hit3DetailsForHit6 = hit6Details.get(2); + double actualHit3ScoreHit6 = ((double) hit3DetailsForHit6.get("value")); + assertTrue(actualHit3ScoreHit6 > 0.0); + assertEquals("base scores from subqueries:", hit3DetailsForHit6.get("description")); + assertEquals(1, ((List) hit3DetailsForHit6.get("details")).size()); + + Map hit1SubDetailsForHit6 = (Map) ((List) hit3DetailsForHit6.get("details")).get(0); + assertEquals(1.0, ((double) hit1SubDetailsForHit6.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("stock:[20 TO 400]", hit1SubDetailsForHit6.get("description")); + assertEquals(0, ((List) hit1SubDetailsForHit6.get("details")).size()); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int 
lte, int gte) { MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 9eb802708..1d666d788 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -49,7 +49,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.processor.NormalizationProcessor; -import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; +import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; import org.opensearch.search.sort.SortBuilder; @@ -1203,7 +1203,7 @@ protected void createSearchPipeline( if (addExplainResponseProcessor) { stringBuilderForContentBody.append(", \"response_processors\": [ ") .append("{\"") - .append(ProcessorExplainPublisher.TYPE) + .append(ExplainResponseProcessor.TYPE) .append("\": {}}") .append("]"); } From a25acc5498264968f0cb04ea6acc0bc3124599dc Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 1 Nov 2024 13:58:04 -0700 Subject: [PATCH 06/11] Doing some refactoring Signed-off-by: Martin Gaievski --- .../neuralsearch/plugin/NeuralSearch.java | 8 +- .../processor/CompoundTopDocs.java | 6 +- ...java => ExplanationResponseProcessor.java} | 21 +- .../processor/NormalizationProcessor.java | 22 +- .../NormalizationProcessorWorkflow.java | 8 +- ...zationProcessorWorkflowExecuteRequest.java | 3 + .../neuralsearch/processor/SearchShard.java | 10 +- .../explain/CombinedExplainDetails.java | 3 + .../processor/explain/DocIdAtSearchShard.java | 2 +- .../processor/explain/ExplainDetails.java | 7 +- .../processor/explain/ExplainUtils.java | 5 +- ...plainDto.java => ExplanationResponse.java} | 5 +- ... 
ExplanationResponseProcessorFactory.java} | 9 +- .../ExplainResponseProcessorTests.java | 57 -- .../ExplanationResponseProcessorTests.java | 530 ++++++++++++++++++ .../neuralsearch/query/HybridQuerySortIT.java | 4 +- .../neuralsearch/BaseNeuralSearchIT.java | 4 +- 17 files changed, 601 insertions(+), 103 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/{ExplainResponseProcessor.java => ExplanationResponseProcessor.java} (83%) rename src/main/java/org/opensearch/neuralsearch/processor/explain/{ProcessorExplainDto.java => ExplanationResponse.java} (80%) rename src/main/java/org/opensearch/neuralsearch/processor/factory/{ProcessorExplainPublisherFactory.java => ExplanationResponseProcessorFactory.java} (65%) delete mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index e5272ab4b..8deabd141 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -32,14 +32,14 @@ import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; -import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; -import org.opensearch.neuralsearch.processor.factory.ProcessorExplainPublisherFactory; +import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; @@ -185,8 +185,8 @@ public Map topDocs, boolean isSo public CompoundTopDocs(final QuerySearchResult querySearchResult) { final TopDocs topDocs = querySearchResult.topDocs().topDocs; final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget(); - SearchShard searchShard = new SearchShard( - searchShardTarget.getIndex(), - searchShardTarget.getShardId().id(), - searchShardTarget.getNodeId() - ); + SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget); boolean isSortEnabled = false; if (topDocs instanceof TopFieldDocs) { isSortEnabled = true; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java similarity index 83% rename from src/main/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessor.java rename to src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 4aee8fe6e..4cfaf9837 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -10,7 +10,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; -import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto; +import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; @@ -24,13 +24,16 @@ import java.util.Objects; import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; -import static org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR; +import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR; +/** + * Processor to add explanation details to search response + */ @Getter @AllArgsConstructor -public class ExplainResponseProcessor implements SearchResponseProcessor { +public class ExplanationResponseProcessor implements SearchResponseProcessor { - public static final String TYPE = "explain_response_processor"; + public static final String TYPE = "explanation_response_processor"; private final String description; private final String tag; @@ -46,10 +49,10 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) { return response; } - ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY); - Map explainPayload = processorExplainDto.getExplainPayload(); + ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY); + Map explainPayload = explanationResponse.getExplainPayload(); if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { - Explanation processorExplanation = processorExplainDto.getExplanation(); + Explanation processorExplanation = explanationResponse.getExplanation(); if (Objects.isNull(processorExplanation)) { return response; } @@ -62,7 +65,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp for (int i = 0; i < searchHitsArray.length; i++) { SearchHit searchHit = searchHitsArray[i]; SearchShardTarget searchShardTarget = searchHit.getShard(); - SearchShard searchShard = SearchShard.create(searchShardTarget); + SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget); searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i); explainsByShardCount.putIfAbsent(searchShard, -1); } @@ -73,7 +76,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp List>) explainPayload.get(NORMALIZATION_PROCESSOR); for (SearchHit searchHit : searchHitsArray) { - SearchShard searchShard = SearchShard.create(searchHit.getShard()); + SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard()); int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1; CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard); Explanation normalizedExplanation = Explanation.match( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java 
b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 6c6327b3b..7f0314ef7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -44,7 +44,7 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { /** * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor + * are set as part of class constructor. This method is called when there is no pipeline context * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution * @param searchPhaseContext {@link SearchContext} */ @@ -53,19 +53,27 @@ public void process( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext ) { - doProcessStuff(searchPhaseResult, searchPhaseContext, Optional.empty()); + prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty()); } + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline + * @param + */ @Override public void process( - SearchPhaseResults searchPhaseResult, - SearchPhaseContext searchPhaseContext, - PipelineProcessingContext requestContext + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext ) { - doProcessStuff(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); + prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); } - private void doProcessStuff( + private void prepareAndExecuteNormalizationWorkflow( SearchPhaseResults searchPhaseResult, SearchPhaseContext searchPhaseContext, Optional requestContextOptional diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 5b6fa158c..118d0a25c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -28,7 +28,7 @@ import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; import org.opensearch.neuralsearch.processor.explain.ExplainDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto; +import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; @@ -146,13 +146,13 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< } }); - ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder() + 
ExplanationResponse explanationResponse = ExplanationResponse.builder() .explanation(topLevelExplanationForTechniques) - .explainPayload(Map.of(ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain)) + .explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain)) .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); - pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, processorExplainDto); + pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java index 8056bd100..ea0b54b9c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java @@ -19,6 +19,9 @@ @Builder @AllArgsConstructor @Getter +/** + * DTO class to hold request parameters for normalization and combination + */ public class NormalizationProcessorWorkflowExecuteRequest { final List querySearchResults; final Optional fetchSearchResultOptional; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java index 61f9d0be5..505b19ae0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java @@ -6,9 +6,17 @@ import org.opensearch.search.SearchShardTarget; +/** + * DTO class to store index, shardId and nodeId for a search shard. + */ public record SearchShard(String index, int shardId, String nodeId) { - public static SearchShard create(SearchShardTarget searchShardTarget) { + /** + * Create SearchShard from SearchShardTarget + * @param searchShardTarget + * @return SearchShard + */ + public static SearchShard createSearchShard(final SearchShardTarget searchShardTarget) { return new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java index 8793f5d53..4a9793fd4 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java @@ -8,6 +8,9 @@ import lombok.Builder; import lombok.Getter; +/** + * DTO class to hold explain details for normalization and combination + */ @AllArgsConstructor @Builder @Getter diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java index 70da8b73c..9ce4ebf97 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java @@ -7,7 +7,7 @@ import org.opensearch.neuralsearch.processor.SearchShard; /** - * Data class to store docId and search shard for a query. + * DTO class to store docId and search shard for a query. 
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. * @param docId * @param searchShard diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java index 0bbdea27f..ca83cc436 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java @@ -5,14 +5,13 @@ package org.opensearch.neuralsearch.processor.explain; /** - * Data class to store value and description for explain details. + * DTO class to store value and description for explain details. * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. * @param value * @param description */ -public record ExplainDetails(float value, String description, int docId) { - +public record ExplainDetails(int docId, float value, String description) { public ExplainDetails(float value, String description) { - this(value, description, -1); + this(-1, value, description); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java index a66acb786..1eca4232f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java @@ -56,14 +56,14 @@ public static ExplainDetails getScoreCombinationExplainDetailsForDocument( ) { float combinedScore = combinedNormalizedScoresByDocId.get(docId); return new ExplainDetails( + docId, combinedScore, String.format( Locale.ROOT, "normalized scores: %s combined to a final score: %s", Arrays.toString(normalizedScoresPerDoc), combinedScore - ), - docId + ) ); } @@ -96,5 +96,4 @@ public static Explanation topLevelExpalantionForCombinedScore( return Explanation.match(0.0f, explanationDetailsMessage); } - } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java similarity index 80% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java index ed29cd6e3..f14050214 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java @@ -11,10 +11,13 @@ import java.util.Map; +/** + * DTO class to hold explain details for normalization and combination + */ @AllArgsConstructor @Builder @Getter -public class ProcessorExplainDto { +public class ExplanationResponse { Explanation explanation; Map explainPayload; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java similarity index 65% rename from src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java rename to src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java index cbca8daca..ffe1da12f 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java @@ -4,13 +4,16 @@ */ package org.opensearch.neuralsearch.processor.factory; -import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; import java.util.Map; -public class ProcessorExplainPublisherFactory implements Processor.Factory { +/** + * Factory class for creating ExplanationResponseProcessor + */ +public class ExplanationResponseProcessorFactory implements Processor.Factory { @Override public SearchResponseProcessor create( @@ -21,6 +24,6 @@ public SearchResponseProcessor create( Map config, Processor.PipelineContext pipelineContext ) throws Exception { - return new ExplainResponseProcessor(description, tag, ignoreFailure); + return new ExplanationResponseProcessor(description, tag, ignoreFailure); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java deleted file mode 100644 index 796a899be..000000000 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.processor; - -import lombok.SneakyThrows; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.search.pipeline.PipelineProcessingContext; -import org.opensearch.test.OpenSearchTestCase; - -import static org.mockito.Mockito.mock; - -public class ExplainResponseProcessorTests extends OpenSearchTestCase { - private static final String PROCESSOR_TAG = "mockTag"; - private static final String DESCRIPTION = "mockDescription"; - - public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { - ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); - - assertEquals(DESCRIPTION, explainResponseProcessor.getDescription()); - assertEquals(PROCESSOR_TAG, explainResponseProcessor.getTag()); - assertFalse(explainResponseProcessor.isIgnoreFailure()); - } - - @SneakyThrows - public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() { - ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); - SearchRequest searchRequest = mock(SearchRequest.class); - SearchResponse searchResponse = new SearchResponse( - null, - null, - 1, - 1, - 0, - 1000, - new ShardSearchFailure[0], - SearchResponse.Clusters.EMPTY - ); - - SearchResponse processedResponse = explainResponseProcessor.processResponse(searchRequest, searchResponse); - assertEquals(searchResponse, processedResponse); - - SearchResponse processedResponse2 = explainResponseProcessor.processResponse(searchRequest, searchResponse, null); - assertEquals(searchResponse, processedResponse2); - - PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); - SearchResponse processedResponse3 = explainResponseProcessor.processResponse( - searchRequest, - 
searchResponse, - pipelineProcessingContext - ); - assertEquals(searchResponse, processedResponse3); - } -} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java new file mode 100644 index 000000000..ce2df0b13 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java @@ -0,0 +1,530 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.apache.commons.lang3.Range; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.RemoteClusterAware; + +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.TreeMap; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; + +public class ExplanationResponseProcessorTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + + assertEquals(DESCRIPTION, explanationResponseProcessor.getDescription()); + assertEquals(PROCESSOR_TAG, explanationResponseProcessor.getTag()); + assertFalse(explanationResponseProcessor.isIgnoreFailure()); + } + + @SneakyThrows + public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() { + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = new SearchResponse( + null, + null, + 1, + 1, + 0, + 1000, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY + ); + + SearchResponse processedResponse = explanationResponseProcessor.processResponse(searchRequest, searchResponse); + assertEquals(searchResponse, processedResponse); + + SearchResponse processedResponse2 = explanationResponseProcessor.processResponse(searchRequest, searchResponse, null); + assertEquals(searchResponse, processedResponse2); + + 
PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + SearchResponse processedResponse3 = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + assertEquals(searchResponse, processedResponse3); + } + + @SneakyThrows + public void testParsingOfExplanations_whenResponseHasExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + int numResponses = 1; + int numIndices = 2; + Iterator> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); + Map.Entry entry = indicesIterator.next(); + String clusterAlias = entry.getKey(); + Index[] indices = entry.getValue(); + + int requestedSize = 2; + PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(null)); + TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); + + final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; + int scoreFactor = randomIntBetween(1, numResponses); + float maxScore = numDocs * scoreFactor; + + SearchHit[] searchHitArray = randomSearchHitArray( + numDocs, + numResponses, + clusterAlias, + indices, + maxScore, + scoreFactor, + null, + priorityQueue + ); + for (SearchHit searchHit : searchHitArray) { + Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]")); + searchHit.explanation(explanation); + } + + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(numResponses, TotalHits.Relation.EQUAL_TO), 1.0f); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + Explanation generalExplanation = Explanation.match( + maxScore, + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]" + ); + Map> combinedExplainDetails = Map.of( + SearchShard.createSearchShard(searchHitArray[0].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .build() + ), + SearchShard.createSearchShard(searchHitArray[1].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .build() + ) + ); + Map explainPayload = Map.of( + ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationResponse explanationResponse = ExplanationResponse.builder() + .explanation(generalExplanation) + .explainPayload(explainPayload) + .build(); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, 
maxScore); + } + + @SneakyThrows + public void testParsingOfExplanations_whenFieldSortingAndExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + int numResponses = 1; + int numIndices = 2; + Iterator> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); + Map.Entry entry = indicesIterator.next(); + String clusterAlias = entry.getKey(); + Index[] indices = entry.getValue(); + final SortField[] sortFields = new SortField[] { + new SortField("random-text-field-1", SortField.Type.INT, randomBoolean()), + new SortField("random-text-field-2", SortField.Type.STRING, randomBoolean()) }; + + int requestedSize = 2; + PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields)); + TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); + + final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; + int scoreFactor = randomIntBetween(1, numResponses); + float maxScore = Float.NaN; + + SearchHit[] searchHitArray = randomSearchHitArray( + numDocs, + numResponses, + clusterAlias, + indices, + maxScore, + scoreFactor, + sortFields, + priorityQueue + ); + for (SearchHit searchHit : searchHitArray) { + Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]")); + searchHit.explanation(explanation); + } + + SearchHits searchHits = new SearchHits(searchHitArray, totalHits, maxScore, sortFields, null, null); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + Explanation generalExplanation = Explanation.match( + maxScore, + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]" + ); + Map> combinedExplainDetails = Map.of( + SearchShard.createSearchShard(searchHitArray[0].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .build() + ), + SearchShard.createSearchShard(searchHitArray[1].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .build() + ) + ); + Map explainPayload = Map.of( + ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationResponse explanationResponse = ExplanationResponse.builder() + .explanation(generalExplanation) + .explainPayload(explainPayload) + .build(); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + @SneakyThrows + public void 
testParsingOfExplanations_whenScoreSortingAndExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + int numResponses = 1; + int numIndices = 2; + Iterator> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); + Map.Entry entry = indicesIterator.next(); + String clusterAlias = entry.getKey(); + Index[] indices = entry.getValue(); + final SortField[] sortFields = new SortField[] { SortField.FIELD_SCORE }; + + int requestedSize = 2; + PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields)); + TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); + + final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; + int scoreFactor = randomIntBetween(1, numResponses); + float maxScore = Float.NaN; + + SearchHit[] searchHitArray = randomSearchHitArray( + numDocs, + numResponses, + clusterAlias, + indices, + maxScore, + scoreFactor, + sortFields, + priorityQueue + ); + for (SearchHit searchHit : searchHitArray) { + Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]")); + searchHit.explanation(explanation); + } + + SearchHits searchHits = new SearchHits(searchHitArray, totalHits, maxScore, sortFields, null, null); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + Explanation generalExplanation = Explanation.match( + maxScore, + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]" + ); + Map> combinedExplainDetails = Map.of( + SearchShard.createSearchShard(searchHitArray[0].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .build() + ), + SearchShard.createSearchShard(searchHitArray[1].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .build() + ) + ); + Map explainPayload = Map.of( + ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationResponse explanationResponse = ExplanationResponse.builder() + .explanation(generalExplanation) + .explainPayload(explainPayload) + .build(); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + private static SearchResponse getSearchResponse(SearchHits searchHits) { + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + 
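+        // wrap the internal sections into a full response: 1 total shard, 1 successful,
+        // 0 skipped, a took time of 1000 ms, and no shard failures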
SearchResponse searchResponse = new SearchResponse( + internalSearchResponse, + null, + 1, + 1, + 0, + 1000, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY + ); + return searchResponse; + } + + private static void assertOnExplanationResults(SearchResponse processedResponse, float maxScore) { + assertNotNull(processedResponse); + Explanation explanationHit1 = processedResponse.getHits().getHits()[0].getExplanation(); + assertNotNull(explanationHit1); + assertEquals( + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]", + explanationHit1.getDescription() + ); + assertEquals(maxScore, (float) explanationHit1.getValue(), DELTA_FOR_SCORE_ASSERTION); + + Explanation[] detailsHit1 = explanationHit1.getDetails(); + assertEquals(3, detailsHit1.length); + assertEquals("source scores: [1.0] normalized to scores: [0.5]", detailsHit1[0].getDescription()); + assertEquals(1.0f, (float) detailsHit1[0].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertEquals("normalized scores: [0.5] combined to a final score: 0.5", detailsHit1[1].getDescription()); + assertEquals(0.5f, (float) detailsHit1[1].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertEquals("base scores from subqueries:", detailsHit1[2].getDescription()); + assertEquals(1.0f, (float) detailsHit1[2].getValue(), DELTA_FOR_SCORE_ASSERTION); + + Explanation explanationHit2 = processedResponse.getHits().getHits()[1].getExplanation(); + assertNotNull(explanationHit2); + assertEquals( + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]", + explanationHit2.getDescription() + ); + assertTrue(Range.of(0.0f, maxScore).contains((float) explanationHit2.getValue())); + + Explanation[] detailsHit2 = explanationHit2.getDetails(); + assertEquals(3, detailsHit2.length); + assertEquals("source scores: [0.5] normalized to scores: [0.25]", detailsHit2[0].getDescription()); + assertEquals(.5f, (float) detailsHit2[0].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertEquals("normalized scores: [0.25] combined to a final score: 0.25", detailsHit2[1].getDescription()); + assertEquals(.25f, (float) detailsHit2[1].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertEquals("base scores from subqueries:", detailsHit2[2].getDescription()); + assertEquals(1.0f, (float) detailsHit2[2].getValue(), DELTA_FOR_SCORE_ASSERTION); + } + + private static Map randomRealisticIndices(int numIndices, int numClusters) { + String[] indicesNames = new String[numIndices]; + for (int i = 0; i < numIndices; i++) { + indicesNames[i] = randomAlphaOfLengthBetween(5, 10); + } + Map indicesPerCluster = new TreeMap<>(); + for (int i = 0; i < numClusters; i++) { + Index[] indices = new Index[indicesNames.length]; + for (int j = 0; j < indices.length; j++) { + String indexName = indicesNames[j]; + String indexUuid = frequently() ? 
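+                // usually a random UUID; occasionally reuse the index name as its UUID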
randomAlphaOfLength(10) : indexName; + indices[j] = new Index(indexName, indexUuid); + } + String clusterAlias; + if (frequently() || indicesPerCluster.containsKey(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY)) { + clusterAlias = randomAlphaOfLengthBetween(5, 10); + } else { + clusterAlias = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; + } + indicesPerCluster.put(clusterAlias, indices); + } + return indicesPerCluster; + } + + private static SearchHit[] randomSearchHitArray( + int numDocs, + int numResponses, + String clusterAlias, + Index[] indices, + float maxScore, + int scoreFactor, + SortField[] sortFields, + PriorityQueue priorityQueue + ) { + SearchHit[] hits = new SearchHit[numDocs]; + + int[] sortFieldFactors = new int[sortFields == null ? 0 : sortFields.length]; + for (int j = 0; j < sortFieldFactors.length; j++) { + sortFieldFactors[j] = randomIntBetween(1, numResponses); + } + + for (int j = 0; j < numDocs; j++) { + ShardId shardId = new ShardId(randomFrom(indices), randomIntBetween(0, 10)); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLengthBetween(3, 8), + shardId, + clusterAlias, + OriginalIndices.NONE + ); + SearchHit hit = new SearchHit(randomIntBetween(0, Integer.MAX_VALUE)); + + float score = Float.NaN; + if (!Float.isNaN(maxScore)) { + score = (maxScore - j) * scoreFactor; + hit.score(score); + } + + hit.shard(shardTarget); + if (sortFields != null) { + Object[] rawSortValues = new Object[sortFields.length]; + DocValueFormat[] docValueFormats = new DocValueFormat[sortFields.length]; + for (int k = 0; k < sortFields.length; k++) { + SortField sortField = sortFields[k]; + if (sortField == SortField.FIELD_SCORE) { + hit.score(score); + rawSortValues[k] = score; + } else { + rawSortValues[k] = sortField.getReverse() ? 
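+                        // descending fields get decreasing synthetic values, ascending fields reuse the doc index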
numDocs * sortFieldFactors[k] - j : j; + } + docValueFormats[k] = DocValueFormat.RAW; + } + hit.sortValues(rawSortValues, docValueFormats); + } + hits[j] = hit; + priorityQueue.add(hit); + } + return hits; + } + + private static final class SearchHitComparator implements Comparator { + + private final SortField[] sortFields; + + SearchHitComparator(SortField[] sortFields) { + this.sortFields = sortFields; + } + + @Override + public int compare(SearchHit a, SearchHit b) { + if (sortFields == null) { + int scoreCompare = Float.compare(b.getScore(), a.getScore()); + if (scoreCompare != 0) { + return scoreCompare; + } + } else { + for (int i = 0; i < sortFields.length; i++) { + SortField sortField = sortFields[i]; + if (sortField == SortField.FIELD_SCORE) { + int scoreCompare = Float.compare(b.getScore(), a.getScore()); + if (scoreCompare != 0) { + return scoreCompare; + } + } else { + Integer aSortValue = (Integer) a.getRawSortValues()[i]; + Integer bSortValue = (Integer) b.getRawSortValues()[i]; + final int compare; + if (sortField.getReverse()) { + compare = Integer.compare(bSortValue, aSortValue); + } else { + compare = Integer.compare(aSortValue, bSortValue); + } + if (compare != 0) { + return compare; + } + } + } + } + SearchShardTarget aShard = a.getShard(); + SearchShardTarget bShard = b.getShard(); + int shardIdCompareTo = aShard.getShardId().compareTo(bShard.getShardId()); + if (shardIdCompareTo != 0) { + return shardIdCompareTo; + } + int clusterAliasCompareTo = aShard.getClusterAlias().compareTo(bShard.getClusterAlias()); + if (clusterAliasCompareTo != 0) { + return clusterAliasCompareTo; + } + return Integer.compare(a.docId(), b.docId()); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java index 86f2dc620..489ebc520 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -590,7 +590,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); Map hit1DetailsForHit2 = hit1Details.get(1); - assertEquals(0.6666667, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.666, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); assertEquals("normalized scores: [1.0, 0.0, 1.0] combined to a final score: 0.6666667", hit1DetailsForHit2.get("description")); assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); @@ -627,7 +627,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { assertEquals(0, ((List) hit1DetailsForHit4.get("details")).size()); Map hit2DetailsForHit4 = hit4Details.get(1); - assertEquals(0.6666667, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.666, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); assertEquals("normalized scores: [0.0, 1.0, 1.0] combined to a final score: 0.6666667", hit2DetailsForHit4.get("description")); assertEquals(0, ((List) hit2DetailsForHit4.get("details")).size()); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 1d666d788..5761c5569 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ 
b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -49,7 +49,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.processor.NormalizationProcessor; -import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; import org.opensearch.search.sort.SortBuilder; @@ -1203,7 +1203,7 @@ protected void createSearchPipeline( if (addExplainResponseProcessor) { stringBuilderForContentBody.append(", \"response_processors\": [ ") .append("{\"") - .append(ExplainResponseProcessor.TYPE) + .append(ExplanationResponseProcessor.TYPE) .append("\": {}}") .append("]"); } From 72c0ac35a603bbb80131da6368c6632309961ca6 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 4 Nov 2024 09:52:17 -0800 Subject: [PATCH 07/11] Refactor classes and methods Signed-off-by: Martin Gaievski --- .../ExplanationResponseProcessor.java | 16 ++-- .../processor/NormalizationProcessor.java | 6 +- .../NormalizationProcessorWorkflow.java | 22 +++--- ...ithmeticMeanScoreCombinationTechnique.java | 2 +- ...eometricMeanScoreCombinationTechnique.java | 2 +- ...HarmonicMeanScoreCombinationTechnique.java | 2 +- .../processor/combination/ScoreCombiner.java | 47 +++++++----- .../explain/CombinedExplainDetails.java | 4 +- .../explain/ExplainableTechnique.java | 2 +- ...plainUtils.java => ExplainationUtils.java} | 12 +-- ...inDetails.java => ExplanationDetails.java} | 4 +- ...nResponse.java => ExplanationPayload.java} | 8 +- .../L2ScoreNormalizationTechnique.java | 6 +- .../MinMaxScoreNormalizationTechnique.java | 74 ++++++++++--------- .../normalization/ScoreNormalizer.java | 11 ++- ... 
=> ExplanationPayloadProcessorTests.java} | 54 +++++++------- 16 files changed, 148 insertions(+), 124 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/explain/{ExplainUtils.java => ExplainationUtils.java} (90%) rename src/main/java/org/opensearch/neuralsearch/processor/explain/{ExplainDetails.java => ExplanationDetails.java} (74%) rename src/main/java/org/opensearch/neuralsearch/processor/explain/{ExplanationResponse.java => ExplanationPayload.java} (72%) rename src/test/java/org/opensearch/neuralsearch/processor/{ExplanationResponseProcessorTests.java => ExplanationPayloadProcessorTests.java} (90%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 4cfaf9837..33e971080 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -10,7 +10,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; -import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; @@ -24,7 +24,7 @@ import java.util.Objects; import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; -import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR; +import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR; /** * Processor to add explanation details to search response @@ -40,19 +40,21 @@ public class ExplanationResponseProcessor implements SearchResponseProcessor { private final boolean ignoreFailure; @Override - public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + public SearchResponse processResponse(SearchRequest request, SearchResponse response) { return processResponse(request, response, null); } @Override public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) { - if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) { + if (Objects.isNull(requestContext) + || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY))) + || requestContext.getAttribute(EXPLAIN_RESPONSE_KEY) instanceof ExplanationPayload == false) { return response; } - ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY); - Map explainPayload = explanationResponse.getExplainPayload(); + ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY); + Map explainPayload = explanationPayload.getExplainPayload(); if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { - Explanation processorExplanation = explanationResponse.getExplanation(); + Explanation processorExplanation = explanationPayload.getExplanation(); if (Objects.isNull(processorExplanation)) { return response; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java 
b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 7f0314ef7..d2008ae97 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -66,9 +66,9 @@ public void process( */ @Override public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext, - final PipelineProcessingContext requestContext + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext ) { prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 118d0a25c..f5fb794a7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -26,9 +26,9 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; @@ -42,7 +42,7 @@ import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.topLevelExpalantionForCombinedScore; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.topLevelExpalantionForCombinedScore; import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria; /** @@ -123,11 +123,11 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs); - Map normalizationExplain = scoreNormalizer.explain( + Map normalizationExplain = scoreNormalizer.explain( queryTopDocs, (ExplainableTechnique) request.getNormalizationTechnique() ); - Map> combinationExplain = scoreCombiner.explain( + Map> combinationExplain = scoreCombiner.explain( queryTopDocs, request.getCombinationTechnique(), sortForQuery @@ -135,24 +135,24 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< Map> combinedExplain = new HashMap<>(); combinationExplain.forEach((searchShard, explainDetails) -> { - for (ExplainDetails explainDetail : explainDetails) { + for (ExplanationDetails explainDetail : explainDetails) { DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), searchShard); - ExplainDetails normalizedExplainDetails = 
normalizationExplain.get(docIdAtSearchShard); + ExplanationDetails normalizedExplanationDetails = normalizationExplain.get(docIdAtSearchShard); CombinedExplainDetails combinedExplainDetails = CombinedExplainDetails.builder() - .normalizationExplain(normalizedExplainDetails) + .normalizationExplain(normalizedExplanationDetails) .combinationExplain(explainDetail) .build(); combinedExplain.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(combinedExplainDetails); } }); - ExplanationResponse explanationResponse = ExplanationResponse.builder() + ExplanationPayload explanationPayload = ExplanationPayload.builder() .explanation(topLevelExplanationForTechniques) - .explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain)) + .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplain)) .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); - pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse); + pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationPayload); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 1d31b4c31..4055d0377 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on arithmetic mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index 0dcd5c39c..7de4e0499 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on geometrical mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index 4fd112bc5..f6c68bc7e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on harmonic mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 8194ecf74..87f8135b9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -27,9 +27,9 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.processor.SearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getScoreCombinationExplainDetailsForDocument; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getScoreCombinationExplainDetailsForDocument; /** * Abstracts combination of scores in query search results. @@ -318,40 +318,47 @@ private TotalHits getTotalHits(final List topDocsPerSubQuery, final lon return new TotalHits(maxHits, totalHits); } - public Map> explain( + /** + * Explain the score combination technique for each document in the given queryTopDocs. 
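+     * When the same shard key appears more than once across the compound top docs, only the
+     * first list of per-document details is kept (duplicate keys are ignored).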
+ * @param queryTopDocs + * @param combinationTechnique + * @param sort + * @return a map of SearchShard and List of ExplainationDetails for each document + */ + public Map> explain( final List queryTopDocs, final ScoreCombinationTechnique combinationTechnique, final Sort sort ) { // In case of duplicate keys, keep the first value - HashMap> explanations = new HashMap<>(); + Map> explanations = new HashMap<>(); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { - for (Map.Entry> docIdAtSearchShardExplainDetailsEntry : explainByShard( - combinationTechnique, - compoundQueryTopDocs, - sort - ).entrySet()) { - explanations.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue()); - } + explanations.putIfAbsent( + compoundQueryTopDocs.getSearchShard(), + explainByShard(combinationTechnique, compoundQueryTopDocs, sort) + ); } return explanations; } - private Map> explainByShard( + private List explainByShard( final ScoreCombinationTechnique scoreCombinationTechnique, final CompoundTopDocs compoundQueryTopDocs, - Sort sort + final Sort sort ) { if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) { - return Map.of(); + return List.of(); } - // - create map of normalized scores results returned from the single shard + // create map of normalized scores results returned from the single shard Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs()); + // combine scores Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); + // sort combined scores as per sorting criteria - either score desc or field sorting Collection sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId); - List listOfExplainsForShard = sortedDocsIds.stream() + + List listOfExplanations = sortedDocsIds.stream() .map( docId -> getScoreCombinationExplainDetailsForDocument( docId, @@ -360,13 +367,13 @@ private Map> explainByShard( ) ) .toList(); - return Map.of(compoundQueryTopDocs.getSearchShard(), listOfExplainsForShard); + return listOfExplanations; } private Collection getSortedDocsIds( - CompoundTopDocs compoundQueryTopDocs, - Sort sort, - Map combinedNormalizedScoresByDocId + final CompoundTopDocs compoundQueryTopDocs, + final Sort sort, + final Map combinedNormalizedScoresByDocId ) { Collection sortedDocsIds; if (sort != null) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java index 4a9793fd4..dd8a2a9c3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java @@ -15,6 +15,6 @@ @Builder @Getter public class CombinedExplainDetails { - private ExplainDetails normalizationExplain; - private ExplainDetails combinationExplain; + private ExplanationDetails normalizationExplain; + private ExplanationDetails combinationExplain; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java index 6f8dfcf1e..cc2fab6c6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java @@ -28,7 +28,7 @@ default String describe() { * @param queryTopDocs collection of CompoundTopDocs for each shard result * @return map of document per shard and corresponding explanation object */ - default Map explain(final List queryTopDocs) { + default Map explain(final List queryTopDocs) { return Map.of(); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java similarity index 90% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java index 1eca4232f..9e9fd4c3a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java @@ -17,7 +17,7 @@ /** * Utility class for explain functionality */ -public class ExplainUtils { +public class ExplainationUtils { /** * Creates map of DocIdAtQueryPhase to String containing source and normalized scores @@ -25,16 +25,16 @@ public class ExplainUtils { * @param sourceScores map of DocIdAtQueryPhase to source scores * @return map of DocIdAtQueryPhase to String containing source and normalized scores */ - public static Map getDocIdAtQueryForNormalization( + public static Map getDocIdAtQueryForNormalization( final Map> normalizedScores, final Map> sourceScores ) { - Map explain = sourceScores.entrySet() + Map explain = sourceScores.entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> { List srcScores = entry.getValue(); List normScores = normalizedScores.get(entry.getKey()); - return new ExplainDetails( + return new ExplanationDetails( normScores.stream().reduce(0.0f, Float::max), String.format(Locale.ROOT, "source scores: %s normalized to scores: %s", srcScores, normScores) ); @@ -49,13 +49,13 @@ public static Map getDocIdAtQueryForNormaliz * @param normalizedScoresPerDoc * @return */ - public static ExplainDetails getScoreCombinationExplainDetailsForDocument( + public static ExplanationDetails getScoreCombinationExplainDetailsForDocument( final Integer docId, final Map combinedNormalizedScoresByDocId, final float[] normalizedScoresPerDoc ) { float combinedScore = combinedNormalizedScoresByDocId.get(docId); - return new ExplainDetails( + return new ExplanationDetails( docId, combinedScore, String.format( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java similarity index 74% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java index ca83cc436..594bc4299 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java @@ -10,8 +10,8 @@ * @param value * @param description */ -public record ExplainDetails(int docId, float value, String description) { - public ExplainDetails(float value, String description) { +public record ExplanationDetails(int docId, float value, String description) { + public ExplanationDetails(float value, String description) { this(-1, value, description); } } diff --git 
a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java similarity index 72% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java index f14050214..a1206a1a1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java @@ -17,11 +17,11 @@ @AllArgsConstructor @Builder @Getter -public class ExplanationResponse { - Explanation explanation; - Map explainPayload; +public class ExplanationPayload { + private final Explanation explanation; + private final Map explainPayload; - public enum ExplanationType { + public enum PayloadType { NORMALIZATION_PROCESSOR } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 5c5436564..ca68bf563 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -17,10 +17,10 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization; /** * Abstracts normalization of scores based on L2 method @@ -64,7 +64,7 @@ public String describe() { } @Override - public Map explain(List queryTopDocs) { + public Map explain(List queryTopDocs) { Map> normalizedScores = new HashMap<>(); Map> sourceScores = new HashMap<>(); List normsPerSubquery = getL2Norm(queryTopDocs); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 63efb4332..e3487cbcb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -20,10 +20,10 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getDocIdAtQueryForNormalization; +import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization; /** * Abstracts normalization of scores based on min-max method @@ -44,19 +44,7 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech */ @Override public void normalize(final List queryTopDocs) { - 
int numOfSubqueries = queryTopDocs.stream() - .filter(Objects::nonNull) - .filter(topDocs -> topDocs.getTopDocs().size() > 0) - .findAny() - .get() - .getTopDocs() - .size(); - // get min scores for each sub query - float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); - - // get max scores for each sub query - float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); - + MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs); // do normalization using actual score and min and max scores for corresponding sub query for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { @@ -66,35 +54,36 @@ public void normalize(final List queryTopDocs) { for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { - scoreDoc.score = normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j]); + scoreDoc.score = normalizeSingleScore( + scoreDoc.score, + minMaxScores.minScoresPerSubquery()[j], + minMaxScores.maxScoresPerSubquery()[j] + ); } } } } + private MinMaxScores getMinMaxScoresResult(final List queryTopDocs) { + int numOfSubqueries = getNumOfSubqueries(queryTopDocs); + // get min scores for each sub query + float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); + // get max scores for each sub query + float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); + return new MinMaxScores(minScoresPerSubquery, maxScoresPerSubquery); + } + @Override public String describe() { return String.format(Locale.ROOT, "normalization [%s]", TECHNIQUE_NAME); } @Override - public Map explain(final List queryTopDocs) { + public Map explain(final List queryTopDocs) { + MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs); + Map> normalizedScores = new HashMap<>(); Map> sourceScores = new HashMap<>(); - - int numOfSubqueries = queryTopDocs.stream() - .filter(Objects::nonNull) - .filter(topDocs -> !topDocs.getTopDocs().isEmpty()) - .findAny() - .get() - .getTopDocs() - .size(); - // get min scores for each sub query - float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); - - // get max scores for each sub query - float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); - for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { continue; @@ -104,17 +93,30 @@ public Map explain(final List new ArrayList<>()).add(normalizedScore); sourceScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(scoreDoc.score); scoreDoc.score = normalizedScore; } } } - return getDocIdAtQueryForNormalization(normalizedScores, sourceScores); } + private int getNumOfSubqueries(final List queryTopDocs) { + return queryTopDocs.stream() + .filter(Objects::nonNull) + .filter(topDocs -> !topDocs.getTopDocs().isEmpty()) + .findAny() + .get() + .getTopDocs() + .size(); + } + private float[] getMaxScores(final List queryTopDocs, final int numOfSubqueries) { float[] maxScores = new float[numOfSubqueries]; Arrays.fill(maxScores, Float.MIN_VALUE); @@ -165,4 +167,10 @@ private float normalizeSingleScore(final float score, final float minScore, fina float normalizedScore = (score - minScore) / (maxScore - minScore); return normalizedScore == 0.0f ? 
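// a score equal to the subquery minimum would normalize to exactly 0.0f; the technique
// substitutes the MIN_SCORE constant (assumed to be a small positive floor) instead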
MIN_SCORE : normalizedScore; } + + /** + * Result class to hold min and max scores for each sub query + */ + private record MinMaxScores(float[] minScoresPerSubquery, float[] maxScoresPerSubquery) { + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index 2dcf5f768..67a17fda2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -10,7 +10,7 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; public class ScoreNormalizer { @@ -30,7 +30,14 @@ private boolean canQueryResultsBeNormalized(final List queryTop return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getTopDocs().size() > 0); } - public Map explain( + /** + * Explain normalized scores based on input normalization technique. Does not mutate input object. + * @param queryTopDocs original query results from multiple shards and multiple sub-queries + * @param queryTopDocs + * @param scoreNormalizationTechnique + * @return map of doc id to explanation details + */ + public Map explain( final List queryTopDocs, final ExplainableTechnique scoreNormalizationTechnique ) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java similarity index 90% rename from src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java index ce2df0b13..fe0099c87 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java @@ -16,8 +16,8 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; -import org.opensearch.neuralsearch.processor.explain.ExplainDetails; -import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -38,7 +38,7 @@ import static org.mockito.Mockito.mock; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; -public class ExplanationResponseProcessorTests extends OpenSearchTestCase { +public class ExplanationPayloadProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; @@ -130,27 +130,27 @@ public void testParsingOfExplanations_whenResponseHasExplanations_thenSuccessful SearchShard.createSearchShard(searchHitArray[0].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new 
ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) - .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) .build() ), SearchShard.createSearchShard(searchHitArray[1].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) - .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) .build() ) ); - Map explainPayload = Map.of( - ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplainDetails ); - ExplanationResponse explanationResponse = ExplanationResponse.builder() + ExplanationPayload explanationPayload = ExplanationPayload.builder() .explanation(generalExplanation) .explainPayload(explainPayload) .build(); - pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload); // Act SearchResponse processedResponse = explanationResponseProcessor.processResponse( @@ -216,27 +216,27 @@ public void testParsingOfExplanations_whenFieldSortingAndExplanations_thenSucces SearchShard.createSearchShard(searchHitArray[0].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) - .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) .build() ), SearchShard.createSearchShard(searchHitArray[1].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) - .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) .build() ) ); - Map explainPayload = Map.of( - ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplainDetails ); - ExplanationResponse explanationResponse = ExplanationResponse.builder() + ExplanationPayload explanationPayload = ExplanationPayload.builder() .explanation(generalExplanation) .explainPayload(explainPayload) .build(); - pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + 
pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload); // Act SearchResponse processedResponse = explanationResponseProcessor.processResponse( @@ -300,27 +300,27 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces SearchShard.createSearchShard(searchHitArray[0].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) - .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) .build() ), SearchShard.createSearchShard(searchHitArray[1].getShard()), List.of( CombinedExplainDetails.builder() - .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) - .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) .build() ) ); - Map explainPayload = Map.of( - ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplainDetails ); - ExplanationResponse explanationResponse = ExplanationResponse.builder() + ExplanationPayload explanationPayload = ExplanationPayload.builder() .explanation(generalExplanation) .explainPayload(explainPayload) .build(); - pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload); // Act SearchResponse processedResponse = explanationResponseProcessor.processResponse( From 9830ab3c1ccd4384b77b4fb61293a0f1a13d52fd Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 12 Nov 2024 18:34:35 -0800 Subject: [PATCH 08/11] Change response format, switch to hierarchical structure Signed-off-by: Martin Gaievski --- .../neuralsearch/plugin/NeuralSearch.java | 2 +- .../ExplanationResponseProcessor.java | 57 +- .../NormalizationProcessorWorkflow.java | 39 +- ...ithmeticMeanScoreCombinationTechnique.java | 2 +- ...eometricMeanScoreCombinationTechnique.java | 2 +- ...HarmonicMeanScoreCombinationTechnique.java | 2 +- .../processor/combination/ScoreCombiner.java | 15 +- ...s.java => CombinedExplanationDetails.java} | 6 +- .../processor/explain/ExplainationUtils.java | 99 ---- .../processor/explain/ExplanationDetails.java | 15 +- .../processor/explain/ExplanationPayload.java | 2 - .../processor/explain/ExplanationUtils.java | 53 ++ .../L2ScoreNormalizationTechnique.java | 8 +- .../MinMaxScoreNormalizationTechnique.java | 8 +- .../neuralsearch/query/HybridQueryWeight.java | 2 +- .../ExplanationPayloadProcessorTests.java | 308 ++++------ .../query/HybridQueryExplainIT.java | 559 ++++++++++++++++++ .../neuralsearch/query/HybridQueryIT.java | 306 ---------- .../neuralsearch/query/HybridQuerySortIT.java | 146 +++-- .../neuralsearch/BaseNeuralSearchIT.java | 4 + .../neuralsearch/util/TestUtils.java | 6 +- 21 files 
changed, 875 insertions(+), 766 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/explain/{CombinedExplainDetails.java => CombinedExplanationDetails.java} (68%) delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 8deabd141..1350a7963 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -82,7 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); - public static final String EXPLAIN_RESPONSE_KEY = "explain_response"; + public static final String EXPLANATION_RESPONSE_KEY = "explanation_response"; @Override public Collection createComponents( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 33e971080..74ae0621a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -9,7 +9,8 @@ import org.apache.lucene.search.Explanation; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; +import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -23,7 +24,7 @@ import java.util.Map; import java.util.Objects; -import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; +import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY; import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR; /** @@ -45,19 +46,19 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } @Override - public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) { + public SearchResponse processResponse( + final SearchRequest request, + final SearchResponse response, + final PipelineProcessingContext requestContext + ) { if (Objects.isNull(requestContext) - || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY))) - || requestContext.getAttribute(EXPLAIN_RESPONSE_KEY) instanceof ExplanationPayload == false) { + || (Objects.isNull(requestContext.getAttribute(EXPLANATION_RESPONSE_KEY))) + || requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) { return response; } - ExplanationPayload explanationPayload = 
(ExplanationPayload) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
+        ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY);
         Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
         if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
-            Explanation processorExplanation = explanationPayload.getExplanation();
-            if (Objects.isNull(processorExplanation)) {
-                return response;
-            }
             SearchHits searchHits = response.getHits();
             SearchHit[] searchHitsArray = searchHits.getHits();
             // create a map of searchShard and list of indexes of search hit objects in search hits array
@@ -73,29 +74,33 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
             }
             if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map) {
                 @SuppressWarnings("unchecked")
-                Map<SearchShard, List<CombinedExplainDetails>> combinedExplainDetails = (Map<
-                    SearchShard,
-                    List<CombinedExplainDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
+                Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map<
+                    SearchShard,
+                    List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
                 for (SearchHit searchHit : searchHitsArray) {
                     SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
                     int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
-                    CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
-                    Explanation normalizedExplanation = Explanation.match(
-                        combinedExplainDetail.getNormalizationExplain().value(),
-                        combinedExplainDetail.getNormalizationExplain().description()
-                    );
-                    Explanation combinedExplanation = Explanation.match(
-                        combinedExplainDetail.getCombinationExplain().value(),
-                        combinedExplainDetail.getCombinationExplain().description()
-                    );
-
+                    CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
+                    Explanation queryLevelExplanation = searchHit.getExplanation();
+                    ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
+                    ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
+                    Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
+                    for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
+                        normalizedExplanation[i] = Explanation.match(
+                            // normalized score
+                            normalizationExplanation.scoreDetails().get(i).getKey(),
+                            // description of normalized score
+                            normalizationExplanation.scoreDetails().get(i).getValue(),
+                            // shard level details
+                            queryLevelExplanation.getDetails()[i]
+                        );
+                    }
                     Explanation finalExplanation = Explanation.match(
                         searchHit.getScore(),
-                        processorExplanation.getDescription(),
-                        normalizedExplanation,
-                        combinedExplanation,
-                        searchHit.getExplanation()
+                        // combination level explanation is always a single detail
+                        combinationExplanation.scoreDetails().get(0).getValue(),
+                        normalizedExplanation
                     );
                     searchHit.explanation(finalExplanation);
                     explainsByShardCount.put(searchShard, explanationIndexByShard);
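For clarity, the loop above assembles a three-level Explanation tree per hit: the combination description at the root, one normalization node per sub-query, and the original query-level details as leaves. A minimal, self-contained sketch of that shape follows; all scores and descriptions are illustrative, not actual plugin output:

import org.apache.lucene.search.Explanation;

class ExplanationTreeSketch {
    public static void main(String[] args) {
        // leaves: query-level details, one per sub-query of the hybrid query
        Explanation subQuery1 = Explanation.match(5.2f, "weight(field:hello ...)");
        Explanation subQuery2 = Explanation.match(0.9f, "within top 10");
        // middle level: one node per sub-query carrying its normalized score
        Explanation normalized1 = Explanation.match(1.0f, "min_max normalization of:", subQuery1);
        Explanation normalized2 = Explanation.match(0.6f, "min_max normalization of:", subQuery2);
        // root: the hit's final score plus the combination technique description
        Explanation root = Explanation.match(0.8f, "arithmetic_mean combination of:", normalized1, normalized2);
        System.out.println(root);
    }
}

As in the code above, the root value comes from searchHit.getScore() and the root description from the single combination-level score detail.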
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index f5fb794a7..1a958676a 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -4,7 +4,6 @@
 package org.opensearch.neuralsearch.processor;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -14,7 +13,6 @@
 import java.util.Optional;
 import java.util.stream.Collectors;
-import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.Sort;
@@ -24,7 +22,7 @@
 import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
-import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
+import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
@@ -40,9 +38,8 @@
 import lombok.AllArgsConstructor;
 import lombok.extern.log4j.Log4j2;
-import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
+import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY;
 import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND;
-import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.topLevelExpalantionForCombinedScore;
 import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria;
 /**
@@ -113,16 +110,9 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<CompoundTopDocs> queryTopDocs) {
         if (!request.isExplain()) {
             return;
         }
-        Explanation topLevelExplanationForTechniques = topLevelExpalantionForCombinedScore(
-            (ExplainableTechnique) request.getNormalizationTechnique(),
-            (ExplainableTechnique) request.getCombinationTechnique()
-        );
-
         // build final result object with all explain related information
         if (Objects.nonNull(request.getPipelineProcessingContext())) {
-
             Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs);
-
             Map<DocIdAtSearchShard, ExplanationDetails> normalizationExplain = scoreNormalizer.explain(
                 queryTopDocs,
                 (ExplainableTechnique) request.getNormalizationTechnique()
             );
@@ -132,27 +122,22 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<CompoundTopDocs> queryTopDocs) {
             Map<SearchShard, List<ExplanationDetails>> combinationExplain = scoreCombiner.explain(
                 queryTopDocs,
                 request.getCombinationTechnique(),
                 sortForQuery
             );
-            Map<SearchShard, List<CombinedExplainDetails>> combinedExplain = new HashMap<>();
-
-            combinationExplain.forEach((searchShard, explainDetails) -> {
-                for (ExplanationDetails explainDetail : explainDetails) {
-                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), searchShard);
-                    ExplanationDetails normalizedExplanationDetails = normalizationExplain.get(docIdAtSearchShard);
-                    CombinedExplainDetails combinedExplainDetails = CombinedExplainDetails.builder()
-                        .normalizationExplain(normalizedExplanationDetails)
-                        .combinationExplain(explainDetail)
+            Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = combinationExplain.entrySet()
+                .stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> {
+                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), entry.getKey());
+                    return CombinedExplanationDetails.builder()
+                        .normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
+                        .combinationExplanations(explainDetail)
+                        .build();
-                        .build();
-                    combinedExplain.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(combinedExplainDetails);
-                }
-            });
+                }).collect(Collectors.toList())));
             ExplanationPayload
explanationPayload = ExplanationPayload.builder() - .explanation(topLevelExplanationForTechniques) - .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplain)) + .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplanations)) .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); - pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationPayload); + pipelineProcessingContext.setAttribute(EXPLANATION_RESPONSE_KEY, explanationPayload); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 4055d0377..5ad79e75a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on arithmetic mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index 7de4e0499..b5bdabb43 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on geometrical mean method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index f6c68bc7e..eeb5950f1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique; +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on harmonic mean 
method diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 87f8135b9..cbc3f485b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -8,6 +8,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.Objects; @@ -16,6 +17,7 @@ import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.ScoreDoc; @@ -27,10 +29,9 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; -import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getScoreCombinationExplainDetailsForDocument; - /** * Abstracts combination of scores in query search results. */ @@ -360,10 +361,14 @@ private List explainByShard( List listOfExplanations = sortedDocsIds.stream() .map( - docId -> getScoreCombinationExplainDetailsForDocument( + docId -> new ExplanationDetails( docId, - combinedNormalizedScoresByDocId, - normalizedScoresPerDoc.get(docId) + List.of( + Pair.of( + combinedNormalizedScoresByDocId.get(docId), + String.format(Locale.ROOT, "%s combination of:", ((ExplainableTechnique) scoreCombinationTechnique).describe()) + ) + ) ) ) .toList(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplanationDetails.java similarity index 68% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplanationDetails.java index dd8a2a9c3..c2e1b61e5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplanationDetails.java @@ -14,7 +14,7 @@ @AllArgsConstructor @Builder @Getter -public class CombinedExplainDetails { - private ExplanationDetails normalizationExplain; - private ExplanationDetails combinationExplain; +public class CombinedExplanationDetails { + private ExplanationDetails normalizationExplanations; + private ExplanationDetails combinationExplanations; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java deleted file mode 100644 index 9e9fd4c3a..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainationUtils.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.processor.explain; - -import org.apache.lucene.search.Explanation; - -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; - -/** - * Utility 
class for explain functionality - */ -public class ExplainationUtils { - - /** - * Creates map of DocIdAtQueryPhase to String containing source and normalized scores - * @param normalizedScores map of DocIdAtQueryPhase to normalized scores - * @param sourceScores map of DocIdAtQueryPhase to source scores - * @return map of DocIdAtQueryPhase to String containing source and normalized scores - */ - public static Map getDocIdAtQueryForNormalization( - final Map> normalizedScores, - final Map> sourceScores - ) { - Map explain = sourceScores.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> { - List srcScores = entry.getValue(); - List normScores = normalizedScores.get(entry.getKey()); - return new ExplanationDetails( - normScores.stream().reduce(0.0f, Float::max), - String.format(Locale.ROOT, "source scores: %s normalized to scores: %s", srcScores, normScores) - ); - })); - return explain; - } - - /** - * Return the detailed score combination explain for the single document - * @param docId - * @param combinedNormalizedScoresByDocId - * @param normalizedScoresPerDoc - * @return - */ - public static ExplanationDetails getScoreCombinationExplainDetailsForDocument( - final Integer docId, - final Map combinedNormalizedScoresByDocId, - final float[] normalizedScoresPerDoc - ) { - float combinedScore = combinedNormalizedScoresByDocId.get(docId); - return new ExplanationDetails( - docId, - combinedScore, - String.format( - Locale.ROOT, - "normalized scores: %s combined to a final score: %s", - Arrays.toString(normalizedScoresPerDoc), - combinedScore - ) - ); - } - - /** - * Creates a string describing the combination technique and its parameters - * @param techniqueName the name of the combination technique - * @param weights the weights used in the combination technique - * @return a string describing the combination technique and its parameters - */ - public static String describeCombinationTechnique(final String techniqueName, final List weights) { - return String.format(Locale.ROOT, "combination [%s] with optional parameter [%s]: %s", techniqueName, PARAM_NAME_WEIGHTS, weights); - } - - /** - * Creates an Explanation object for the top-level explanation of the combined score - * @param explainableNormalizationTechnique the normalization technique used - * @param explainableCombinationTechnique the combination technique used - * @return an Explanation object for the top-level explanation of the combined score - */ - public static Explanation topLevelExpalantionForCombinedScore( - final ExplainableTechnique explainableNormalizationTechnique, - final ExplainableTechnique explainableCombinationTechnique - ) { - String explanationDetailsMessage = String.format( - Locale.ROOT, - "combined score with techniques: %s, %s", - explainableNormalizationTechnique.describe(), - explainableCombinationTechnique.describe() - ); - - return Explanation.match(0.0f, explanationDetailsMessage); - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java index 594bc4299..fe009f383 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java @@ -4,14 +4,19 @@ */ package org.opensearch.neuralsearch.processor.explain; +import org.apache.commons.lang3.tuple.Pair; + +import java.util.List; + /** * DTO class to store value and description for 
explain details.
 * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
- * @param value
- * @param description
+ * @param docId iterator-based id of the document
+ * @param scoreDetails list of score details for the document; each Pair holds a score and the description of that score
 */
-public record ExplanationDetails(int docId, float value, String description) {
-    public ExplanationDetails(float value, String description) {
-        this(-1, value, description);
+public record ExplanationDetails(int docId, List<Pair<Float, String>> scoreDetails) {
+
+    public ExplanationDetails(List<Pair<Float, String>> scoreDetails) {
+        this(-1, scoreDetails);
     }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java
index a1206a1a1..708f655c0 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java
@@ -7,7 +7,6 @@
 import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.Getter;
-import org.apache.lucene.search.Explanation;
 import java.util.Map;
@@ -18,7 +17,6 @@
 @Builder
 @Getter
 public class ExplanationPayload {
-    private final Explanation explanation;
     private final Map<PayloadType, Object> explainPayload;
     public enum PayloadType {
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
new file mode 100644
index 000000000..499ce77cf
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor.explain;
+
+import org.apache.commons.lang3.tuple.Pair;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Utility class for explain functionality
+ */
+public class ExplanationUtils {
+
+    /**
+     * Creates a map of DocIdAtSearchShard to ExplanationDetails, with one normalized score and its description per sub-query
+     * @param normalizedScores map of DocIdAtSearchShard to that document's normalized scores, one per sub-query
+     * @param technique the normalization technique, used to build the score descriptions
+     * @return map of DocIdAtSearchShard to ExplanationDetails
+     */
+    public static Map<DocIdAtSearchShard, ExplanationDetails> getDocIdAtQueryForNormalization(
+        final Map<DocIdAtSearchShard, List<Float>> normalizedScores,
+        final ExplainableTechnique technique
+    ) {
+        Map<DocIdAtSearchShard, ExplanationDetails> explain = normalizedScores.entrySet()
+            .stream()
+            .collect(Collectors.toMap(Map.Entry::getKey, entry -> {
+                List<Float> normScores = normalizedScores.get(entry.getKey());
+                List<Pair<Float, String>> explanations = normScores.stream()
+                    .map(score -> Pair.of(score, String.format(Locale.ROOT, "%s normalization of:", technique.describe())))
+                    .collect(Collectors.toList());
+                return new ExplanationDetails(explanations);
+            }));
+        return explain;
+    }
+
+    /**
+     * Creates a string describing the combination technique and its parameters
+     * @param techniqueName the name of the combination technique
+     * @param weights the weights used in the combination technique
+     * @return a string describing the combination technique and its parameters
+     */
+    public static String describeCombinationTechnique(final String techniqueName, final List<Float> weights) {
+        return Optional.ofNullable(weights)
+            .filter(w -> !w.isEmpty())
+            .map(w -> String.format(Locale.ROOT, "%s, weights %s", techniqueName, weights))
+            .orElse(String.format(Locale.ROOT, "%s", techniqueName));
+    }
+}
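A short usage sketch for the two helpers above. The index, shard, node id and scores are invented for illustration; the technique is the plugin's own min-max implementation, whose describe() returns "min_max" after this change:

import java.util.List;
import java.util.Map;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;
import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique;

class ExplanationUtilsSketch {
    public static void main(String[] args) {
        // one document on one shard, with one normalized score per sub-query
        Map<DocIdAtSearchShard, List<Float>> normalizedScores = Map.of(
            new DocIdAtSearchShard(3, new SearchShard("my-index", 0, "node-1")),
            List.of(1.0f, 0.4f)
        );
        Map<DocIdAtSearchShard, ExplanationDetails> explain = ExplanationUtils.getDocIdAtQueryForNormalization(
            normalizedScores,
            new MinMaxScoreNormalizationTechnique()
        );
        // resulting scoreDetails for the document:
        // [(1.0, "min_max normalization of:"), (0.4, "min_max normalization of:")]

        // describeCombinationTechnique folds the optional weights into the description:
        // ("arithmetic_mean", List.of(0.3f, 0.7f)) -> "arithmetic_mean, weights [0.3, 0.7]"
        // ("arithmetic_mean", null)                -> "arithmetic_mean"
    }
}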
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
index ca68bf563..e7fbf658c 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
@@ -20,7 +20,7 @@
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
-import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization;
+import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
 /**
  * Abstracts normalization of scores based on L2 method
@@ -60,13 +60,12 @@ public void normalize(final List<CompoundTopDocs> queryTopDocs) {
     @Override
     public String describe() {
-        return String.format(Locale.ROOT, "normalization [%s]", TECHNIQUE_NAME);
+        return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME);
     }
     @Override
     public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
         Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
-        Map<DocIdAtSearchShard, List<Float>> sourceScores = new HashMap<>();
         List<Float> normsPerSubquery = getL2Norm(queryTopDocs);
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
@@ -80,12 +79,11 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
                 TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
                 for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
                     DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                     float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
                     normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
-                    sourceScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(scoreDoc.score);
                     scoreDoc.score = normalizedScore;
                 }
             }
         }
-        return getDocIdAtQueryForNormalization(normalizedScores, sourceScores);
+        return getDocIdAtQueryForNormalization(normalizedScores, this);
     }
     private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
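The L2 technique above and the min-max technique below share this plumbing and differ only in the per-score formula. A rough sketch of the two formulas; the real methods are private to their classes, and the degenerate-range handling in min-max is simplified here (the assumed floor mirrors the MIN_SCORE constant in the real class):

class NormalizationFormulaSketch {
    // L2: score / norm, where norm = sqrt(sum of squared scores within the same sub-query)
    static float l2(float score, float l2NormOfSubquery) {
        return l2NormOfSubquery == 0.0f ? 0.0f : score / l2NormOfSubquery;
    }

    // min-max: (score - min) / (max - min) over the scores of the same sub-query
    static float minMax(float score, float min, float max) {
        if (Float.compare(max, min) == 0) {
            return 0.001f; // assumed floor for a degenerate range where all scores are equal
        }
        return (score - min) / (max - min);
    }
}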
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
index e3487cbcb..3ca538f4e 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
@@ -23,7 +23,7 @@
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
-import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getDocIdAtQueryForNormalization;
+import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
 /**
  * Abstracts normalization of scores based on min-max method
@@ -75,7 +75,7 @@ private MinMaxScores getMinMaxScoresResult(final List<CompoundTopDocs> queryTopD
     @Override
     public String describe() {
-        return String.format(Locale.ROOT, "normalization [%s]", TECHNIQUE_NAME);
+        return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME);
     }
     @Override
@@ -83,7 +83,6 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
         MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs);
         Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
-        Map<DocIdAtSearchShard, List<Float>> sourceScores = new HashMap<>();
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
             if (Objects.isNull(compoundQueryTopDocs)) {
                 continue;
@@ -99,12 +98,11 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
                     DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                     float normalizedScore = normalizeSingleScore(
                         scoreDoc.score,
                         minMaxScores.getMinScoresPerSubquery()[j],
                         minMaxScores.getMaxScoresPerSubquery()[j]
                     );
                     normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
-                    sourceScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(scoreDoc.score);
                     scoreDoc.score = normalizedScore;
                 }
             }
         }
-        return getDocIdAtQueryForNormalization(normalizedScores, sourceScores);
+        return getDocIdAtQueryForNormalization(normalizedScores, this);
     }
     private int getNumOfSubqueries(final List<CompoundTopDocs> queryTopDocs) {
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java
index c0c91be6e..bad8fda74 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java
@@ -164,7 +164,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
             }
         }
         if (match) {
-            final String desc = "base scores from subqueries:";
+            final String desc = "combined score of:";
             return Explanation.match(max, desc, subsOnMatch);
         } else {
             return Explanation.noMatch("no matching clause", subsOnNoMatch);
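With the new description, the query-level explanation that the response processor later re-parents is rooted at "combined score of:", and its value is the maximum over the matching sub-queries (that is what max holds in the method above). A small sketch of the shape, with an invented sub-query detail:

import org.apache.lucene.search.Explanation;

class HybridQueryExplainSketch {
    public static void main(String[] args) {
        Explanation subQuery = Explanation.match(2.2f, "weight(test-text-field-1:hello ...)");
        // the maximum over matching sub-query scores becomes the value at this level
        Explanation hybrid = Explanation.match(2.2f, "combined score of:", subQuery);
        System.out.println(hybrid);
    }
}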
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java
index fe0099c87..e47ea43d2 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java
@@ -6,6 +6,7 @@
 import lombok.SneakyThrows;
 import org.apache.commons.lang3.Range;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TotalHits;
@@ -15,7 +16,7 @@
 import org.opensearch.action.search.ShardSearchFailure;
 import org.opensearch.core.index.Index;
 import org.opensearch.core.index.shard.ShardId;
-import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
+import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
 import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
 import org.opensearch.search.DocValueFormat;
@@ -85,72 +86,20 @@ public void testParsingOfExplanations_whenResponseHasExplanations_thenSuccessful
         // Setup
         ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
         SearchRequest searchRequest = mock(SearchRequest.class);
-
-        int numResponses = 1;
-        int numIndices = 2;
-        Iterator<Map.Entry<String, Index[]>> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator();
-        Map.Entry<String, Index[]> entry = indicesIterator.next();
-        String clusterAlias = entry.getKey();
-        Index[] indices = entry.getValue();
-
-        int requestedSize = 2;
-        PriorityQueue<SearchHit> priorityQueue = new PriorityQueue<>(new SearchHitComparator(null));
-        TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values());
-        TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation);
-
-        final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value;
-        int scoreFactor = randomIntBetween(1, numResponses);
-        float maxScore = numDocs * scoreFactor;
-
-        SearchHit[] searchHitArray = randomSearchHitArray(
-            numDocs,
-            numResponses,
-            clusterAlias,
-            indices,
-            maxScore,
-            scoreFactor,
-            null,
-            priorityQueue
-        );
-        for (SearchHit searchHit : searchHitArray) {
-            Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]"));
-            searchHit.explanation(explanation);
-        }
-
-        SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(numResponses, TotalHits.Relation.EQUAL_TO), 1.0f);
-
+        float maxScore = 1.0f;
+        SearchHits searchHits = getSearchHits(maxScore);
         SearchResponse searchResponse = getSearchResponse(searchHits);
         PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext();
-        Explanation generalExplanation = Explanation.match(
-            maxScore,
-            "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]"
-        );
-        Map<SearchShard, List<CombinedExplainDetails>> combinedExplainDetails = Map.of(
-            SearchShard.createSearchShard(searchHitArray[0].getShard()),
-            List.of(
-                CombinedExplainDetails.builder()
-                    .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]"))
-                    .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5"))
-                    .build()
-            ),
-            SearchShard.createSearchShard(searchHitArray[1].getShard()),
-            List.of(
-                CombinedExplainDetails.builder()
-                    .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]"))
-                    .combinationExplain(new ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25"))
-                    .build()
-            )
-        );
+        Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = getCombinedExplainDetails(searchHits);
         Map<ExplanationPayload.PayloadType, Object> explainPayload = Map.of(
             ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR,
             combinedExplainDetails
         );
-        ExplanationPayload explanationPayload = ExplanationPayload.builder()
-            .explanation(generalExplanation)
-            .explainPayload(explainPayload)
-            .build();
-        pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload);
+        ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build();
+        pipelineProcessingContext.setAttribute(
+            org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY,
+            explanationPayload
+        );
         // Act
         SearchResponse processedResponse = explanationResponseProcessor.processResponse(
             searchRequest,
             searchResponse,
             pipelineProcessingContext
         );
@@ -169,74 +118,32 @@ public void testParsingOfExplanations_whenFieldSortingAndExplanations_thenSucces
         ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
         SearchRequest searchRequest = mock(SearchRequest.class);
-        int numResponses = 1;
-        int numIndices = 2;
-        Iterator<Map.Entry<String, Index[]>> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator();
-        Map.Entry<String, Index[]> entry = indicesIterator.next();
-        String clusterAlias = entry.getKey();
-        Index[] indices = entry.getValue();
+        float maxScore = 1.0f;
+        SearchHits searchHitsWithoutSorting = getSearchHits(maxScore);
+        for (SearchHit searchHit : searchHitsWithoutSorting.getHits()) {
+            Explanation explanation = Explanation.match(1.0f, "combined score of:", Explanation.match(1.0f, "field1:[0 TO 100]"));
+            searchHit.explanation(explanation);
+        }
+        TotalHits.Relation totalHitsRelation =
randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); final SortField[] sortFields = new SortField[] { new SortField("random-text-field-1", SortField.Type.INT, randomBoolean()), new SortField("random-text-field-2", SortField.Type.STRING, randomBoolean()) }; - - int requestedSize = 2; - PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields)); - TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); - TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); - - final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; - int scoreFactor = randomIntBetween(1, numResponses); - float maxScore = Float.NaN; - - SearchHit[] searchHitArray = randomSearchHitArray( - numDocs, - numResponses, - clusterAlias, - indices, - maxScore, - scoreFactor, - sortFields, - priorityQueue - ); - for (SearchHit searchHit : searchHitArray) { - Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]")); - searchHit.explanation(explanation); - } - - SearchHits searchHits = new SearchHits(searchHitArray, totalHits, maxScore, sortFields, null, null); + SearchHits searchHits = new SearchHits(searchHitsWithoutSorting.getHits(), totalHits, maxScore, sortFields, null, null); SearchResponse searchResponse = getSearchResponse(searchHits); PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); - Explanation generalExplanation = Explanation.match( - maxScore, - "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]" - ); - Map> combinedExplainDetails = Map.of( - SearchShard.createSearchShard(searchHitArray[0].getShard()), - List.of( - CombinedExplainDetails.builder() - .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) - .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) - .build() - ), - SearchShard.createSearchShard(searchHitArray[1].getShard()), - List.of( - CombinedExplainDetails.builder() - .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) - .combinationExplain(new ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) - .build() - ) - ); + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); Map explainPayload = Map.of( ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplainDetails ); - ExplanationPayload explanationPayload = ExplanationPayload.builder() - .explanation(generalExplanation) - .explainPayload(explainPayload) - .build(); - pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + pipelineProcessingContext.setAttribute( + org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, + explanationPayload + ); // Act SearchResponse processedResponse = explanationResponseProcessor.processResponse( @@ -255,22 +162,51 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); SearchRequest 
searchRequest = mock(SearchRequest.class); + float maxScore = 1.0f; + + SearchHits searchHits = getSearchHits(maxScore); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + pipelineProcessingContext.setAttribute( + org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, + explanationPayload + ); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + private static SearchHits getSearchHits(float maxScore) { int numResponses = 1; int numIndices = 2; Iterator> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); Map.Entry entry = indicesIterator.next(); String clusterAlias = entry.getKey(); Index[] indices = entry.getValue(); - final SortField[] sortFields = new SortField[] { SortField.FIELD_SCORE }; int requestedSize = 2; - PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields)); + PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(null)); TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; int scoreFactor = randomIntBetween(1, numResponses); - float maxScore = Float.NaN; SearchHit[] searchHitArray = randomSearchHitArray( numDocs, @@ -279,58 +215,16 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces indices, maxScore, scoreFactor, - sortFields, + null, priorityQueue ); for (SearchHit searchHit : searchHitArray) { - Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]")); + Explanation explanation = Explanation.match(1.0f, "combined score of:", Explanation.match(1.0f, "field1:[0 TO 100]")); searchHit.explanation(explanation); } - SearchHits searchHits = new SearchHits(searchHitArray, totalHits, maxScore, sortFields, null, null); - - SearchResponse searchResponse = getSearchResponse(searchHits); - - PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); - Explanation generalExplanation = Explanation.match( - maxScore, - "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]" - ); - Map> combinedExplainDetails = Map.of( - SearchShard.createSearchShard(searchHitArray[0].getShard()), - List.of( - CombinedExplainDetails.builder() - .normalizationExplain(new ExplanationDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) - .combinationExplain(new ExplanationDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) - .build() - ), - SearchShard.createSearchShard(searchHitArray[1].getShard()), - List.of( - CombinedExplainDetails.builder() - .normalizationExplain(new ExplanationDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) - .combinationExplain(new 
ExplanationDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) - .build() - ) - ); - Map explainPayload = Map.of( - ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, - combinedExplainDetails - ); - ExplanationPayload explanationPayload = ExplanationPayload.builder() - .explanation(generalExplanation) - .explainPayload(explainPayload) - .build(); - pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationPayload); - - // Act - SearchResponse processedResponse = explanationResponseProcessor.processResponse( - searchRequest, - searchResponse, - pipelineProcessingContext - ); - - // Assert - assertOnExplanationResults(processedResponse, maxScore); + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(numResponses, TotalHits.Relation.EQUAL_TO), maxScore); + return searchHits; } private static SearchResponse getSearchResponse(SearchHits searchHits) { @@ -356,45 +250,67 @@ private static SearchResponse getSearchResponse(SearchHits searchHits) { return searchResponse; } + private static Map> getCombinedExplainDetails(SearchHits searchHits) { + Map> combinedExplainDetails = Map.of( + SearchShard.createSearchShard(searchHits.getHits()[0].getShard()), + List.of( + CombinedExplanationDetails.builder() + .normalizationExplanations(new ExplanationDetails(List.of(Pair.of(1.0f, "min_max normalization of:")))) + .combinationExplanations(new ExplanationDetails(List.of(Pair.of(0.5f, "arithmetic_mean combination of:")))) + .build() + ), + SearchShard.createSearchShard(searchHits.getHits()[1].getShard()), + List.of( + CombinedExplanationDetails.builder() + .normalizationExplanations(new ExplanationDetails(List.of(Pair.of(0.5f, "min_max normalization of:")))) + .combinationExplanations(new ExplanationDetails(List.of(Pair.of(0.25f, "arithmetic_mean combination of:")))) + .build() + ) + ); + return combinedExplainDetails; + } + private static void assertOnExplanationResults(SearchResponse processedResponse, float maxScore) { assertNotNull(processedResponse); - Explanation explanationHit1 = processedResponse.getHits().getHits()[0].getExplanation(); - assertNotNull(explanationHit1); - assertEquals( - "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]", - explanationHit1.getDescription() - ); - assertEquals(maxScore, (float) explanationHit1.getValue(), DELTA_FOR_SCORE_ASSERTION); + Explanation hit1TopLevelExplanation = processedResponse.getHits().getHits()[0].getExplanation(); + assertNotNull(hit1TopLevelExplanation); + assertEquals("arithmetic_mean combination of:", hit1TopLevelExplanation.getDescription()); + assertEquals(maxScore, (float) hit1TopLevelExplanation.getValue(), DELTA_FOR_SCORE_ASSERTION); + + Explanation[] hit1SecondLevelDetails = hit1TopLevelExplanation.getDetails(); + assertEquals(1, hit1SecondLevelDetails.length); + assertEquals("min_max normalization of:", hit1SecondLevelDetails[0].getDescription()); + assertEquals(1.0f, (float) hit1SecondLevelDetails[0].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertNotNull(hit1SecondLevelDetails[0].getDetails()); + assertEquals(1, hit1SecondLevelDetails[0].getDetails().length); + Explanation hit1ShardLevelExplanation = hit1SecondLevelDetails[0].getDetails()[0]; + + assertEquals(1.0f, (float) hit1ShardLevelExplanation.getValue(), DELTA_FOR_SCORE_ASSERTION); + assertEquals("field1:[0 TO 100]", hit1ShardLevelExplanation.getDescription()); - Explanation[] detailsHit1 = 
explanationHit1.getDetails(); - assertEquals(3, detailsHit1.length); - assertEquals("source scores: [1.0] normalized to scores: [0.5]", detailsHit1[0].getDescription()); - assertEquals(1.0f, (float) detailsHit1[0].getValue(), DELTA_FOR_SCORE_ASSERTION); + Explanation hit2TopLevelExplanation = processedResponse.getHits().getHits()[1].getExplanation(); + assertNotNull(hit2TopLevelExplanation); + assertEquals("arithmetic_mean combination of:", hit2TopLevelExplanation.getDescription()); + assertEquals(0.0f, (float) hit2TopLevelExplanation.getValue(), DELTA_FOR_SCORE_ASSERTION); - assertEquals("normalized scores: [0.5] combined to a final score: 0.5", detailsHit1[1].getDescription()); - assertEquals(0.5f, (float) detailsHit1[1].getValue(), DELTA_FOR_SCORE_ASSERTION); + Explanation[] hit2SecondLevelDetails = hit2TopLevelExplanation.getDetails(); + assertEquals(1, hit2SecondLevelDetails.length); + assertEquals("min_max normalization of:", hit2SecondLevelDetails[0].getDescription()); + assertEquals(.5f, (float) hit2SecondLevelDetails[0].getValue(), DELTA_FOR_SCORE_ASSERTION); - assertEquals("base scores from subqueries:", detailsHit1[2].getDescription()); - assertEquals(1.0f, (float) detailsHit1[2].getValue(), DELTA_FOR_SCORE_ASSERTION); + assertNotNull(hit2SecondLevelDetails[0].getDetails()); + assertEquals(1, hit2SecondLevelDetails[0].getDetails().length); + Explanation hit2ShardLevelExplanation = hit2SecondLevelDetails[0].getDetails()[0]; + + assertEquals(1.0f, (float) hit2ShardLevelExplanation.getValue(), DELTA_FOR_SCORE_ASSERTION); + assertEquals("field1:[0 TO 100]", hit2ShardLevelExplanation.getDescription()); Explanation explanationHit2 = processedResponse.getHits().getHits()[1].getExplanation(); assertNotNull(explanationHit2); - assertEquals( - "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]", - explanationHit2.getDescription() - ); + assertEquals("arithmetic_mean combination of:", explanationHit2.getDescription()); assertTrue(Range.of(0.0f, maxScore).contains((float) explanationHit2.getValue())); - Explanation[] detailsHit2 = explanationHit2.getDetails(); - assertEquals(3, detailsHit2.length); - assertEquals("source scores: [0.5] normalized to scores: [0.25]", detailsHit2[0].getDescription()); - assertEquals(.5f, (float) detailsHit2[0].getValue(), DELTA_FOR_SCORE_ASSERTION); - - assertEquals("normalized scores: [0.25] combined to a final score: 0.25", detailsHit2[1].getDescription()); - assertEquals(.25f, (float) detailsHit2[1].getValue(), DELTA_FOR_SCORE_ASSERTION); - - assertEquals("base scores from subqueries:", detailsHit2[2].getDescription()); - assertEquals(1.0f, (float) detailsHit2[2].getValue(), DELTA_FOR_SCORE_ASSERTION); } private static Map randomRealisticIndices(int numIndices, int numClusters) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java new file mode 100644 index 000000000..a7656912c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -0,0 +1,559 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import com.google.common.primitives.Floats; +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import 
org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.IntStream; + +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.getMaxScore; +import static org.opensearch.neuralsearch.util.TestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.TestUtils.getTotalHits; + +public class HybridQueryExplainIT extends BaseNeuralSearchIT { + private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-hybrid-vector-doc-field-index"; + private static final String TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME = "test-hybrid-multi-doc-nested-fields-index"; + private static final String TEST_MULTI_DOC_INDEX_NAME = "test-hybrid-multi-doc-index"; + + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT5 = "welcome"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user"; + public static final String NORMALIZATION_TECHNIQUE_L2 = "l2"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + @SneakyThrows + public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + // create search pipeline with both normalization processor and explain response processor + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + BoolQueryBuilder boolQueryBuilder = new 
BoolQueryBuilder(); + boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); + + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // search hits + assertEquals(3, getHitCount(searchResponseAsMap1)); + + List> hitsNestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(3, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain + Map searchHit1 = hitsNestedList.get(0); + Map topLevelExplanationsHit1 = (Map) searchHit1.get("_explanation"); + assertNotNull(topLevelExplanationsHit1); + assertEquals((double) searchHit1.get("_score"), (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "arithmetic_mean combination of:"; + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); + List> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); + assertEquals(1, normalizationExplanationHit1.size()); + Map hit1DetailsForHit1 = normalizationExplanationHit1.get(0); + assertEquals(1.0, hit1DetailsForHit1.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("sum of:", explanationsHit1.get("description")); + assertEquals(0.754f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1, ((List) explanationsHit1.get("details")).size()); + + // search hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map topLevelExplanationsHit2 = (Map) searchHit2.get("_explanation"); + assertNotNull(topLevelExplanationsHit2); + assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); + List> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); + assertEquals(1, normalizationExplanationHit2.size()); + + Map hit1DetailsForHit2 = normalizationExplanationHit2.get(0); + assertEquals(1.0, hit1DetailsForHit2.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit2.get("description")); + assertEquals(1, getListOfValues(hit1DetailsForHit2, "details").size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals(0.287f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit2.get("description")); + 
assertEquals(1, getListOfValues(explanationsHit2, "details").size()); + + Map explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); + assertEquals(0.287f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit2Details.get("description")); + assertEquals(3, getListOfValues(explanationsHit2Details, "details").size()); + + // search hit 3 + Map searchHit3 = hitsNestedList.get(1); + Map topLevelExplanationsHit3 = (Map) searchHit3.get("_explanation"); + assertNotNull(topLevelExplanationsHit3); + assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit3.get("description")); + List> normalizationExplanationHit3 = getListOfValues(topLevelExplanationsHit3, "details"); + assertEquals(1, normalizationExplanationHit3.size()); + + Map hit1DetailsForHit3 = normalizationExplanationHit3.get(0); + assertEquals(1.0, hit1DetailsForHit3.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit3.get("description")); + assertEquals(1, getListOfValues(hit1DetailsForHit3, "details").size()); + + Map explanationsHit3 = getListOfValues(hit1DetailsForHit3, "details").get(0); + assertEquals(0.287f, (double) explanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit3.get("description")); + assertEquals(1, getListOfValues(explanationsHit3, "details").size()); + + Map explanationsHit3Details = getListOfValues(explanationsHit3, "details").get(0); + assertEquals(0.287f, (double) explanationsHit3Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); + assertEquals(3, getListOfValues(explanationsHit3Details, "details").size()); + } finally { + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + NORMALIZATION_TECHNIQUE_L2, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), + true + ); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .vector(createRandomVector(TEST_DIMENSION)) + .k(10) + .build(); + hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); + hybridQueryBuilder.add(knnQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // basic sanity check for search hits + assertEquals(4, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + float actualMaxScore = getMaxScore(searchResponseAsMap).get(); + assertTrue(actualMaxScore > 0); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain, hit 1 + 
List> hitsNestedList = getNestedHits(searchResponseAsMap); + Map searchHit1 = hitsNestedList.get(0); + Map explanationForHit1 = (Map) searchHit1.get("_explanation"); + assertNotNull(explanationForHit1); + assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "arithmetic_mean, weights [0.3, 0.7] combination of:"; + assertEquals(expectedTopLevelDescription, explanationForHit1.get("description")); + List> hit1Details = getListOfValues(explanationForHit1, "details"); + assertEquals(2, hit1Details.size()); + // two sub-queries meaning we do have two detail objects with separate query level details + Map hit1DetailsForHit1 = hit1Details.get(0); + assertTrue((double) hit1DetailsForHit1.get("value") > 0.5f); + assertEquals("l2 normalization of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit1.get("description")); + assertTrue((double) explanationsHit1.get("value") > 0.5f); + assertEquals(0, ((List) explanationsHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertTrue((double) hit1DetailsForHit2.get("value") > 0.0f); + assertEquals("l2 normalization of:", hit1DetailsForHit2.get("description")); + assertEquals(1, ((List) hit1DetailsForHit2.get("details")).size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals("within top 10", explanationsHit2.get("description")); + assertTrue((double) explanationsHit2.get("value") > 0.0f); + assertEquals(0, ((List) explanationsHit2.get("details")).size()); + + // hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map explanationForHit2 = (Map) searchHit2.get("_explanation"); + assertNotNull(explanationForHit2); + assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit2.get("description")); + List> hit2Details = getListOfValues(explanationForHit2, "details"); + assertEquals(2, hit2Details.size()); + + Map hit2DetailsForHit1 = hit2Details.get(0); + assertTrue((double) hit2DetailsForHit1.get("value") > 0.5f); + assertEquals("l2 normalization of:", hit2DetailsForHit1.get("description")); + assertEquals(1, ((List) hit2DetailsForHit1.get("details")).size()); + + Map hit2DetailsForHit2 = hit2Details.get(1); + assertTrue((double) hit2DetailsForHit2.get("value") > 0.0f); + assertEquals("l2 normalization of:", hit2DetailsForHit2.get("description")); + assertEquals(1, ((List) hit2DetailsForHit2.get("details")).size()); + + // hit 3 + Map searchHit3 = hitsNestedList.get(2); + Map explanationForHit3 = (Map) searchHit3.get("_explanation"); + assertNotNull(explanationForHit3); + assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit3.get("description")); + List> hit3Details = getListOfValues(explanationForHit3, "details"); + assertEquals(1, hit3Details.size()); + + Map hit3DetailsForHit1 = hit3Details.get(0); + assertTrue((double) hit3DetailsForHit1.get("value") > 0.5f); + assertEquals("l2 normalization of:", hit3DetailsForHit1.get("description")); + assertEquals(1, ((List) 
hit3DetailsForHit1.get("details")).size()); + + Map explanationsHit3 = getListOfValues(hit3DetailsForHit1, "details").get(0); + assertEquals("within top 10", explanationsHit3.get("description")); + assertEquals(0, getListOfValues(explanationsHit3, "details").size()); + assertTrue((double) explanationsHit3.get("value") > 0.0f); + + // hit 4 + Map searchHit4 = hitsNestedList.get(3); + Map explanationForHit4 = (Map) searchHit4.get("_explanation"); + assertNotNull(explanationForHit4); + assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit4.get("description")); + List> hit4Details = getListOfValues(explanationForHit4, "details"); + assertEquals(1, hit4Details.size()); + + Map hit4DetailsForHit1 = hit4Details.get(0); + assertTrue((double) hit4DetailsForHit1.get("value") > 0.5f); + assertEquals("l2 normalization of:", hit4DetailsForHit1.get("description")); + assertEquals(1, ((List) hit4DetailsForHit1.get("details")).size()); + + Map explanationsHit4 = getListOfValues(hit4DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit4.get("description")); + assertEquals(0, getListOfValues(explanationsHit4, "details").size()); + assertTrue((double) explanationsHit4.get("value") > 0.0f); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenResponseHasQueryExplanations() { + try { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + // create search pipeline with normalization processor, no explanation response processor + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); + + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // search hits + assertEquals(3, getHitCount(searchResponseAsMap1)); + + List> hitsNestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(3, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain + Map searchHit1 = 
hitsNestedList.get(0); + Map topLevelExplanationsHit1 = (Map) searchHit1.get("_explanation"); + assertNotNull(topLevelExplanationsHit1); + assertEquals(0.754f, (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "combined score of:"; + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); + List> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); + assertEquals(1, normalizationExplanationHit1.size()); + Map hit1DetailsForHit1 = normalizationExplanationHit1.get(0); + assertEquals(0.754f, (double) hit1DetailsForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("sum of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("weight(test-text-field-1:place in 0) [PerFieldSimilarity], result of:", explanationsHit1.get("description")); + assertEquals(0.754f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1, ((List) explanationsHit1.get("details")).size()); + + Map explanationsHit1Details = getListOfValues(explanationsHit1, "details").get(0); + assertEquals(0.754f, (double) explanationsHit1Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit1Details.get("description")); + assertEquals(3, getListOfValues(explanationsHit1Details, "details").size()); + + Map explanationsDetails1Hit1Details = getListOfValues(explanationsHit1Details, "details").get(0); + assertEquals(2.2f, (double) explanationsDetails1Hit1Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("boost", explanationsDetails1Hit1Details.get("description")); + assertEquals(0, getListOfValues(explanationsDetails1Hit1Details, "details").size()); + + Map explanationsDetails2Hit1Details = getListOfValues(explanationsHit1Details, "details").get(1); + assertEquals(0.693f, (double) explanationsDetails2Hit1Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:", explanationsDetails2Hit1Details.get("description")); + assertFalse(getListOfValues(explanationsDetails2Hit1Details, "details").isEmpty()); + + Map explanationsDetails3Hit1Details = getListOfValues(explanationsHit1Details, "details").get(2); + assertEquals(0.495f, (double) explanationsDetails3Hit1Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals( + "tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", + explanationsDetails3Hit1Details.get("description") + ); + assertFalse(getListOfValues(explanationsDetails3Hit1Details, "details").isEmpty()); + + // search hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map topLevelExplanationsHit2 = (Map) searchHit2.get("_explanation"); + assertNotNull(topLevelExplanationsHit2); + assertEquals(0.287f, (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); + List> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); + assertEquals(1, normalizationExplanationHit2.size()); + + Map hit1DetailsForHit2 = normalizationExplanationHit2.get(0); + assertEquals(0.287f, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) 
[PerFieldSimilarity], result of:", hit1DetailsForHit2.get("description")); + assertEquals(1, getListOfValues(hit1DetailsForHit2, "details").size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals(0.287f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit2.get("description")); + assertEquals(3, getListOfValues(explanationsHit2, "details").size()); + + Map explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); + assertEquals(2.2f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("boost", explanationsHit2Details.get("description")); + assertEquals(0, getListOfValues(explanationsHit2Details, "details").size()); + + // search hit 3 + Map searchHit3 = hitsNestedList.get(2); + Map topLevelExplanationsHit3 = (Map) searchHit3.get("_explanation"); + assertNotNull(topLevelExplanationsHit3); + assertEquals(0.287f, (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit3.get("description")); + List> normalizationExplanationHit3 = getListOfValues(topLevelExplanationsHit3, "details"); + assertEquals(1, normalizationExplanationHit3.size()); + + Map hit1DetailsForHit3 = normalizationExplanationHit3.get(0); + assertEquals(0.287f, (double) hit1DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", hit1DetailsForHit3.get("description")); + assertEquals(1, getListOfValues(hit1DetailsForHit3, "details").size()); + + Map explanationsHit3 = getListOfValues(hit1DetailsForHit3, "details").get(0); + assertEquals(0.287f, (double) explanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3.get("description")); + assertEquals(3, getListOfValues(explanationsHit3, "details").size()); + + Map explanationsHit3Details = getListOfValues(explanationsHit3, "details").get(0); + assertEquals(2.2f, (double) explanationsHit3Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("boost", explanationsHit3Details.get("description")); + assertEquals(0, getListOfValues(explanationsHit3Details, "details").size()); + } finally { + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + private void initializeIndexIfNotExist(String indexName) { + if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) { + prepareKnnIndex( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + List.of( + new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE), + new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_2, TEST_DIMENSION, TEST_SPACE_TYPE) + ) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "1", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector1).toArray(), Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "2", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector2).toArray(), Floats.asList(testVector2).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + 
Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "3", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector3).toArray(), Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + assertEquals(3, getDocCount(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)); + } + + if (TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), + 1 + ), + "" + ); + addDocsToIndex(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME); + } + + if (TEST_MULTI_DOC_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of(), + 1 + ), + "" + ); + addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME); + } + } + + private void addDocsToIndex(final String testMultiDocIndexName) { + addKnnDoc( + testMultiDocIndexName, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + testMultiDocIndexName, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + testMultiDocIndexName, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + testMultiDocIndexName, + "4", + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + assertEquals(4, getDocCount(testMultiDocIndexName)); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index ce75d9fba..610e08dd0 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -12,13 +12,9 @@ import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; -import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -28,7 +24,6 @@ import org.apache.commons.lang.RandomStringUtils; import org.apache.commons.lang.math.RandomUtils; -import org.apache.commons.lang3.Range; import org.apache.lucene.search.join.ScoreMode; import org.junit.Before; import 
org.opensearch.client.ResponseException; @@ -39,7 +34,6 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.primitives.Floats; @@ -88,7 +82,6 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final int INTEGER_FIELD_PRICE_6_VALUE = 350; protected static final int SINGLE_SHARD = 1; protected static final int MULTIPLE_SHARDS = 3; - public static final String NORMALIZATION_TECHNIQUE_L2 = "l2"; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -840,305 +833,6 @@ public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { } } - @SneakyThrows - public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { - try { - initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); - // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); - - TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); - TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); - - HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); - hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); - hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); - - Map searchResponseAsMap1 = search( - TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, - hybridQueryBuilderNeuralThenTerm, - null, - 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) - ); - // Assert - // search hits - assertEquals(3, getHitCount(searchResponseAsMap1)); - - List> hitsNestedList = getNestedHits(searchResponseAsMap1); - List ids = new ArrayList<>(); - List scores = new ArrayList<>(); - for (Map oneHit : hitsNestedList) { - ids.add((String) oneHit.get("_id")); - scores.add((Double) oneHit.get("_score")); - } - - assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); - assertEquals(Set.copyOf(ids).size(), ids.size()); - - Map total = getTotalHits(searchResponseAsMap1); - assertNotNull(total.get("value")); - assertEquals(3, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - - // explain - Map searchHit1 = hitsNestedList.get(0); - Map explanationForHit1 = (Map) searchHit1.get("_explanation"); - assertNotNull(explanationForHit1); - assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); - String expectedGeneralCombineScoreDescription = - "combined score with techniques: normalization [min_max], combination [arithmetic_mean] with optional parameter [weights]: []"; - assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); - List> hit1Details = (List>) 
explanationForHit1.get("details"); - assertEquals(3, hit1Details.size()); - Map hit1DetailsForHit1 = hit1Details.get(0); - assertEquals(1.0, hit1DetailsForHit1.get("value")); - assertTrue( - ((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[1\\.0\\]") - ); - assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); - - Map hit1DetailsForHit2 = hit1Details.get(1); - assertEquals(0.5, hit1DetailsForHit2.get("value")); - assertEquals("normalized scores: [0.0, 1.0] combined to a final score: 0.5", hit1DetailsForHit2.get("description")); - assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); - - Map hit1DetailsForHit3 = hit1Details.get(2); - double actualHit1ScoreHit3 = ((double) hit1DetailsForHit3.get("value")); - assertTrue(actualHit1ScoreHit3 > 0.0); - assertEquals("base scores from subqueries:", hit1DetailsForHit3.get("description")); - assertEquals(1, ((List) hit1DetailsForHit3.get("details")).size()); - - Map hit1SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(0); - assertEquals(actualHit1ScoreHit3, ((double) hit1SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); - assertEquals("sum of:", hit1SubDetailsForHit3.get("description")); - assertEquals(1, ((List) hit1SubDetailsForHit3.get("details")).size()); - - // search hit 2 - Map searchHit2 = hitsNestedList.get(1); - Map explanationForHit2 = (Map) searchHit2.get("_explanation"); - assertNotNull(explanationForHit2); - assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals(expectedGeneralCombineScoreDescription, explanationForHit2.get("description")); - List> hit2Details = (List>) explanationForHit2.get("details"); - assertEquals(3, hit2Details.size()); - Map hit2DetailsForHit1 = hit2Details.get(0); - assertEquals(1.0, hit2DetailsForHit1.get("value")); - assertTrue( - ((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[1\\.0\\]") - ); - assertEquals(0, ((List) hit2DetailsForHit1.get("details")).size()); - - Map hit2DetailsForHit2 = hit2Details.get(1); - assertEquals(0.5, hit2DetailsForHit2.get("value")); - assertEquals("normalized scores: [1.0, 0.0] combined to a final score: 0.5", hit2DetailsForHit2.get("description")); - assertEquals(0, ((List) hit2DetailsForHit2.get("details")).size()); - - Map hit2DetailsForHit3 = hit2Details.get(2); - double actualHit2ScoreHit3 = ((double) hit2DetailsForHit3.get("value")); - assertTrue(actualHit2ScoreHit3 > 0.0); - assertEquals("base scores from subqueries:", hit2DetailsForHit3.get("description")); - assertEquals(1, ((List) hit2DetailsForHit3.get("details")).size()); - - Map hit2SubDetailsForHit3 = (Map) ((List) hit2DetailsForHit3.get("details")).get(0); - assertEquals(actualHit2ScoreHit3, ((double) hit2SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); - assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", hit2SubDetailsForHit3.get("description")); - assertEquals(1, ((List) hit2SubDetailsForHit3.get("details")).size()); - - // search hit 3 - Map searchHit3 = hitsNestedList.get(2); - Map explanationForHit3 = (Map) searchHit3.get("_explanation"); - assertNotNull(explanationForHit3); - assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals(expectedGeneralCombineScoreDescription, explanationForHit3.get("description")); - List> hit3Details = 
(List>) explanationForHit3.get("details"); - assertEquals(3, hit3Details.size()); - Map hit3DetailsForHit1 = hit3Details.get(0); - assertEquals(0.001, hit3DetailsForHit1.get("value")); - assertTrue( - ((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[0\\.001\\]") - ); - assertEquals(0, ((List) hit3DetailsForHit1.get("details")).size()); - - Map hit3DetailsForHit2 = hit3Details.get(1); - assertEquals(0.0005, hit3DetailsForHit2.get("value")); - assertEquals("normalized scores: [0.0, 0.001] combined to a final score: 5.0E-4", hit3DetailsForHit2.get("description")); - assertEquals(0, ((List) hit3DetailsForHit2.get("details")).size()); - - Map hit3DetailsForHit3 = hit3Details.get(2); - double actualHit3ScoreHit3 = ((double) hit3DetailsForHit3.get("value")); - assertTrue(actualHit3ScoreHit3 > 0.0); - assertEquals("base scores from subqueries:", hit3DetailsForHit3.get("description")); - assertEquals(1, ((List) hit3DetailsForHit3.get("details")).size()); - - Map hit3SubDetailsForHit3 = (Map) ((List) hit3DetailsForHit3.get("details")).get(0); - assertEquals(actualHit3ScoreHit3, ((double) hit3SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); - assertEquals("sum of:", hit3SubDetailsForHit3.get("description")); - assertEquals(1, ((List) hit3SubDetailsForHit3.get("details")).size()); - } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); - } - } - - @SneakyThrows - public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() { - try { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); - createSearchPipeline( - SEARCH_PIPELINE, - NORMALIZATION_TECHNIQUE_L2, - DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), - true - ); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() - .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) - .vector(createRandomVector(TEST_DIMENSION)) - .k(10) - .build(); - hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); - hybridQueryBuilder.add(knnQueryBuilder); - - Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_NAME, - hybridQueryBuilder, - null, - 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) - ); - // Assert - // basic sanity check for search hits - assertEquals(4, getHitCount(searchResponseAsMap)); - assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - float actualMaxScore = getMaxScore(searchResponseAsMap).get(); - assertTrue(actualMaxScore > 0); - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(4, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - - // explain - List> hitsNestedList = getNestedHits(searchResponseAsMap); - Map searchHit1 = hitsNestedList.get(0); - Map explanationForHit1 = (Map) searchHit1.get("_explanation"); - assertNotNull(explanationForHit1); - assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); - String expectedGeneralCombineScoreDescription = - "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameter [weights]: " - + Arrays.toString(new float[] { 0.3f, 0.7f }); - assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); - List> hit1Details = (List>) 
explanationForHit1.get("details"); - assertEquals(3, hit1Details.size()); - Map hit1DetailsForHit1 = hit1Details.get(0); - assertTrue((double) hit1DetailsForHit1.get("value") > 0.5f); - assertTrue( - ((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\] normalized to scores: \\[.*, .*\\]") - ); - assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); - - Map hit1DetailsForHit2 = hit1Details.get(1); - assertEquals(actualMaxScore, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue( - ((String) hit1DetailsForHit2.get("description")).matches("normalized scores: \\[.*, .*\\] combined to a final score: .*") - ); - assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); - - Map hit1DetailsForHit3 = hit1Details.get(2); - assertEquals(1.0, (double) hit1DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit1DetailsForHit3.get("description")).matches("base scores from subqueries:")); - assertEquals(2, ((List) hit1DetailsForHit3.get("details")).size()); - - // hit 2 - Map searchHit2 = hitsNestedList.get(1); - Map explanationForHit2 = (Map) searchHit2.get("_explanation"); - assertNotNull(explanationForHit2); - assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - - assertEquals(expectedGeneralCombineScoreDescription, explanationForHit2.get("description")); - List> hit2Details = (List>) explanationForHit2.get("details"); - assertEquals(3, hit2Details.size()); - Map hit2DetailsForHit1 = hit2Details.get(0); - assertTrue((double) hit2DetailsForHit1.get("value") > 0.5f); - assertTrue( - ((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\] normalized to scores: \\[.*, .*\\]") - ); - assertEquals(0, ((List) hit2DetailsForHit1.get("details")).size()); - - Map hit2DetailsForHit2 = hit2Details.get(1); - assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit2DetailsForHit2.get("value"))); - assertTrue( - ((String) hit2DetailsForHit2.get("description")).matches("normalized scores: \\[.*, .*\\] combined to a final score: .*") - ); - assertEquals(0, ((List) hit2DetailsForHit2.get("details")).size()); - - Map hit2DetailsForHit3 = hit2Details.get(2); - assertEquals(1.0, (double) hit2DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit2DetailsForHit3.get("description")).matches("base scores from subqueries:")); - assertEquals(2, ((List) hit2DetailsForHit3.get("details")).size()); - - // hit 3 - Map searchHit3 = hitsNestedList.get(2); - Map explanationForHit3 = (Map) searchHit3.get("_explanation"); - assertNotNull(explanationForHit3); - assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - - assertEquals(expectedGeneralCombineScoreDescription, explanationForHit3.get("description")); - List> hit3Details = (List>) explanationForHit3.get("details"); - assertEquals(3, hit3Details.size()); - Map hit3DetailsForHit1 = hit3Details.get(0); - assertTrue((double) hit3DetailsForHit1.get("value") > 0.5f); - assertTrue(((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[.*\\]")); - assertEquals(0, ((List) hit3DetailsForHit1.get("details")).size()); - - Map hit3DetailsForHit2 = hit3Details.get(1); - assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit3DetailsForHit2.get("value"))); - assertTrue( - ((String) 
hit3DetailsForHit2.get("description")).matches("normalized scores: \\[0.0, .*\\] combined to a final score: .*") - ); - assertEquals(0, ((List) hit3DetailsForHit2.get("details")).size()); - - Map hit3DetailsForHit3 = hit3Details.get(2); - assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit3DetailsForHit3.get("value"))); - assertTrue(((String) hit3DetailsForHit3.get("description")).matches("base scores from subqueries:")); - assertEquals(1, ((List) hit3DetailsForHit3.get("details")).size()); - - // hit 4 - Map searchHit4 = hitsNestedList.get(3); - Map explanationForHit4 = (Map) searchHit4.get("_explanation"); - assertNotNull(explanationForHit4); - assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); - - assertEquals(expectedGeneralCombineScoreDescription, explanationForHit4.get("description")); - List> hit4Details = (List>) explanationForHit4.get("details"); - assertEquals(3, hit4Details.size()); - Map hit4DetailsForHit1 = hit4Details.get(0); - assertTrue((double) hit4DetailsForHit1.get("value") > 0.5f); - assertTrue(((String) hit4DetailsForHit1.get("description")).matches("source scores: \\[1.0\\] normalized to scores: \\[.*\\]")); - assertEquals(0, ((List) hit4DetailsForHit1.get("details")).size()); - - Map hit4DetailsForHit2 = hit4Details.get(1); - assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit4DetailsForHit2.get("value"))); - assertTrue( - ((String) hit4DetailsForHit2.get("description")).matches("normalized scores: \\[.*, 0.0\\] combined to a final score: .*") - ); - assertEquals(0, ((List) hit4DetailsForHit2.get("details")).size()); - - Map hit4DetailsForHit3 = hit4Details.get(2); - assertEquals(1.0, (double) hit4DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit4DetailsForHit3.get("description")).matches("base scores from subqueries:")); - assertEquals(1, ((List) hit4DetailsForHit3.get("details")).size()); - } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); - } - } - @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java index 489ebc520..875c66310 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -23,7 +23,6 @@ import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQueryWhenSortIsEnabled; import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import org.opensearch.search.sort.SortOrder; import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; @@ -575,77 +574,75 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { Map explanationForHit1 = (Map) searchHit1.get("_explanation"); assertNotNull(explanationForHit1); assertNull(searchHit1.get("_score")); - String expectedGeneralCombineScoreDescription = - "combined score with techniques: normalization [min_max], combination [arithmetic_mean] with optional parameter [weights]: []"; + String 
expectedGeneralCombineScoreDescription = "arithmetic_mean combination of:"; assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); - List> hit1Details = (List>) explanationForHit1.get("details"); - assertEquals(3, hit1Details.size()); + List> hit1Details = getListOfValues(explanationForHit1, "details"); + assertEquals(2, hit1Details.size()); Map hit1DetailsForHit1 = hit1Details.get(0); assertEquals(1.0, hit1DetailsForHit1.get("value")); - assertTrue( - ((String) hit1DetailsForHit1.get("description")).matches( - "source scores: \\[0.4700036, 1.0\\] normalized to scores: \\[1.0, 1.0\\]" - ) + assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); + List> hit1DetailsForHit1Details = getListOfValues(hit1DetailsForHit1, "details"); + assertEquals(1, hit1DetailsForHit1Details.size()); + + Map hit1DetailsForHit1DetailsForHit1 = hit1DetailsForHit1Details.get(0); + assertEquals("weight(name:mission in 0) [PerFieldSimilarity], result of:", hit1DetailsForHit1DetailsForHit1.get("description")); + assertTrue((double) hit1DetailsForHit1DetailsForHit1.get("value") > 0.0f); + assertEquals(1, getListOfValues(hit1DetailsForHit1DetailsForHit1, "details").size()); + + Map hit1DetailsForHit1DetailsForHit1DetailsForHit1 = getListOfValues( + hit1DetailsForHit1DetailsForHit1, + "details" + ).get(0); + assertEquals( + "score(freq=1.0), computed as boost * idf * tf from:", + hit1DetailsForHit1DetailsForHit1DetailsForHit1.get("description") + ); + assertTrue((double) hit1DetailsForHit1DetailsForHit1DetailsForHit1.get("value") > 0.0f); + assertEquals(3, getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").size()); + + assertEquals("boost", getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(0).get("description")); + assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(0).get("value") > 0.0f); + assertEquals( + "idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:", + getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(1).get("description") + ); + assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(1).get("value") > 0.0f); + assertEquals( + "tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", + getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(2).get("description") ); - assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); - - Map hit1DetailsForHit2 = hit1Details.get(1); - assertEquals(0.666, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("normalized scores: [1.0, 0.0, 1.0] combined to a final score: 0.6666667", hit1DetailsForHit2.get("description")); - assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); - - Map hit1DetailsForHit3 = hit1Details.get(2); - double actualHit1ScoreHit3 = ((double) hit1DetailsForHit3.get("value")); - assertTrue(actualHit1ScoreHit3 > 0.0); - assertEquals("base scores from subqueries:", hit1DetailsForHit3.get("description")); - assertEquals(2, ((List) hit1DetailsForHit3.get("details")).size()); - - Map hit1SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(0); - assertEquals(0.47, ((double) hit1SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); - assertEquals("weight(name:mission in 0) [PerFieldSimilarity], result of:", hit1SubDetailsForHit3.get("description")); - assertEquals(1, ((List) 
hit1SubDetailsForHit3.get("details")).size()); - - Map hit2SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(1); - assertEquals(1.0f, ((double) hit2SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); - assertEquals("stock:[20 TO 400]", hit2SubDetailsForHit3.get("description")); - assertEquals(0, ((List) hit2SubDetailsForHit3.get("details")).size()); + assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(2).get("value") > 0.0f); + // hit 4 Map searchHit4 = nestedHits.get(3); Map explanationForHit4 = (Map) searchHit4.get("_explanation"); assertNotNull(explanationForHit4); assertNull(searchHit4.get("_score")); assertEquals(expectedGeneralCombineScoreDescription, explanationForHit4.get("description")); - List> hit4Details = (List>) explanationForHit4.get("details"); - assertEquals(3, hit4Details.size()); + List> hit4Details = getListOfValues(explanationForHit4, "details"); + assertEquals(2, hit4Details.size()); Map hit1DetailsForHit4 = hit4Details.get(0); assertEquals(1.0, hit1DetailsForHit4.get("value")); - assertTrue( - ((String) hit1DetailsForHit4.get("description")).matches( - "source scores: \\[0.9808291, 1.0\\] normalized to scores: \\[1.0, 1.0\\]" - ) + assertEquals("min_max normalization of:", hit1DetailsForHit4.get("description")); + assertEquals(1, ((List) hit1DetailsForHit4.get("details")).size()); + List> hit1DetailsForHit4Details = getListOfValues(hit1DetailsForHit4, "details"); + assertEquals(1, hit1DetailsForHit4Details.size()); + + Map hit1DetailsForHit1DetailsForHit4 = hit1DetailsForHit4Details.get(0); + assertEquals("weight(name:part in 0) [PerFieldSimilarity], result of:", hit1DetailsForHit1DetailsForHit4.get("description")); + assertTrue((double) hit1DetailsForHit1DetailsForHit4.get("value") > 0.0f); + assertEquals(1, getListOfValues(hit1DetailsForHit1DetailsForHit4, "details").size()); + + Map hit1DetailsForHit1DetailsForHit1DetailsForHit4 = getListOfValues( + hit1DetailsForHit1DetailsForHit4, + "details" + ).get(0); + assertEquals( + "score(freq=1.0), computed as boost * idf * tf from:", + hit1DetailsForHit1DetailsForHit1DetailsForHit4.get("description") ); - assertEquals(0, ((List) hit1DetailsForHit4.get("details")).size()); - - Map hit2DetailsForHit4 = hit4Details.get(1); - assertEquals(0.666, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("normalized scores: [0.0, 1.0, 1.0] combined to a final score: 0.6666667", hit2DetailsForHit4.get("description")); - assertEquals(0, ((List) hit2DetailsForHit4.get("details")).size()); - - Map hit3DetailsForHit4 = hit4Details.get(2); - double actualHit3ScoreHit4 = ((double) hit3DetailsForHit4.get("value")); - assertTrue(actualHit3ScoreHit4 > 0.0); - assertEquals("base scores from subqueries:", hit3DetailsForHit4.get("description")); - assertEquals(2, ((List) hit3DetailsForHit4.get("details")).size()); - - Map hit1SubDetailsForHit4 = (Map) ((List) hit3DetailsForHit4.get("details")).get(0); - assertEquals(0.98, ((double) hit1SubDetailsForHit4.get("value")), DELTA_FOR_SCORE_ASSERTION); - assertEquals("weight(name:part in 0) [PerFieldSimilarity], result of:", hit1SubDetailsForHit4.get("description")); - assertEquals(1, ((List) hit1SubDetailsForHit4.get("details")).size()); - - Map hit2SubDetailsForHit4 = (Map) ((List) hit3DetailsForHit4.get("details")).get(1); - assertEquals(1.0f, ((double) hit2SubDetailsForHit4.get("value")), DELTA_FOR_SCORE_ASSERTION); - assertEquals("stock:[20 TO 400]", 
hit2SubDetailsForHit4.get("description")); - assertEquals(0, ((List) hit2SubDetailsForHit4.get("details")).size()); + assertTrue((double) hit1DetailsForHit1DetailsForHit1DetailsForHit4.get("value") > 0.0f); + assertEquals(3, getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit4, "details").size()); // hit 6 Map searchHit6 = nestedHits.get(5); @@ -653,28 +650,19 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { assertNotNull(explanationForHit6); assertNull(searchHit6.get("_score")); assertEquals(expectedGeneralCombineScoreDescription, explanationForHit6.get("description")); - List> hit6Details = (List>) explanationForHit6.get("details"); - assertEquals(3, hit6Details.size()); + List> hit6Details = getListOfValues(explanationForHit6, "details"); + assertEquals(1, hit6Details.size()); Map hit1DetailsForHit6 = hit6Details.get(0); assertEquals(1.0, hit1DetailsForHit6.get("value")); - assertEquals("source scores: [1.0] normalized to scores: [1.0]", hit1DetailsForHit6.get("description")); - assertEquals(0, ((List) hit1DetailsForHit6.get("details")).size()); - - Map hit2DetailsForHit6 = hit6Details.get(1); - assertEquals(0.333, (double) hit2DetailsForHit6.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("normalized scores: [0.0, 0.0, 1.0] combined to a final score: 0.33333334", hit2DetailsForHit6.get("description")); - assertEquals(0, ((List) hit2DetailsForHit6.get("details")).size()); - - Map hit3DetailsForHit6 = hit6Details.get(2); - double actualHit3ScoreHit6 = ((double) hit3DetailsForHit6.get("value")); - assertTrue(actualHit3ScoreHit6 > 0.0); - assertEquals("base scores from subqueries:", hit3DetailsForHit6.get("description")); - assertEquals(1, ((List) hit3DetailsForHit6.get("details")).size()); - - Map hit1SubDetailsForHit6 = (Map) ((List) hit3DetailsForHit6.get("details")).get(0); - assertEquals(1.0, ((double) hit1SubDetailsForHit6.get("value")), DELTA_FOR_SCORE_ASSERTION); - assertEquals("stock:[20 TO 400]", hit1SubDetailsForHit6.get("description")); - assertEquals(0, ((List) hit1SubDetailsForHit6.get("details")).size()); + assertEquals("min_max normalization of:", hit1DetailsForHit6.get("description")); + assertEquals(1, ((List) hit1DetailsForHit6.get("details")).size()); + List> hit1DetailsForHit6Details = getListOfValues(hit1DetailsForHit6, "details"); + assertEquals(1, hit1DetailsForHit6Details.size()); + + Map hit1DetailsForHit1DetailsForHit6 = hit1DetailsForHit6Details.get(0); + assertEquals("stock:[20 TO 400]", hit1DetailsForHit1DetailsForHit6.get("description")); + assertTrue((double) hit1DetailsForHit1DetailsForHit6.get("value") > 0.0f); + assertEquals(0, getListOfValues(hit1DetailsForHit1DetailsForHit6, "details").size()); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 5761c5569..4f154e78b 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -866,6 +866,10 @@ protected List getNormalizationScoreList(final Map searc return scores; } + protected List> getListOfValues(Map searchResponseAsMap, String key) { + return (List>) searchResponseAsMap.get(key); + } + /** * Create a k-NN index from a list of KNNFieldConfigs * diff --git 
a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java index c10380e87..af9b37b14 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java @@ -352,17 +352,17 @@ public static void assertHitResultsFromQueryWhenSortIsEnabled( assertEquals(RELATION_EQUAL_TO, total.get("relation")); } - private static List> getNestedHits(Map searchResponseAsMap) { + public static List> getNestedHits(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return (List>) hitsMap.get("hits"); } - private static Map getTotalHits(Map searchResponseAsMap) { + public static Map getTotalHits(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return (Map) hitsMap.get("total"); } - private static Optional getMaxScore(Map searchResponseAsMap) { + public static Optional getMaxScore(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); } From 7a9508740922c1cacc89edef4fcd24cec71eab67 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 13 Nov 2024 11:41:12 -0800 Subject: [PATCH 09/11] Convert record to lombok value, add unit tests Signed-off-by: Martin Gaievski --- .../ExplanationResponseProcessor.java | 6 +- .../NormalizationProcessorWorkflow.java | 2 +- .../neuralsearch/processor/SearchShard.java | 9 +- .../processor/explain/DocIdAtSearchShard.java | 8 +- .../processor/explain/ExplanationDetails.java | 10 +- .../processor/explain/ExplanationUtils.java | 4 + .../explain/ExplanationUtilsTests.java | 115 ++++++++++++++++++ ...lanationResponseProcessorFactoryTests.java | 112 +++++++++++++++++ 8 files changed, 255 insertions(+), 11 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 74ae0621a..01c1516d2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -89,9 +89,9 @@ public SearchResponse processResponse( for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) { normalizedExplanation[i] = Explanation.match( // normalized score - normalizationExplanation.scoreDetails().get(i).getKey(), + normalizationExplanation.getScoreDetails().get(i).getKey(), // description of normalized score - normalizationExplanation.scoreDetails().get(i).getValue(), + normalizationExplanation.getScoreDetails().get(i).getValue(), // shard level details queryLevelExplanation.getDetails()[i] ); @@ -99,7 +99,7 @@ public SearchResponse processResponse( Explanation finalExplanation = Explanation.match( searchHit.getScore(), // combination level explanation is always a single detail - combinationExplanation.scoreDetails().get(0).getValue(), + combinationExplanation.getScoreDetails().get(0).getValue(), normalizedExplanation ); searchHit.explanation(finalExplanation); diff --git 
a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 1a958676a..078c68aff 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -125,7 +125,7 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< Map> combinedExplanations = combinationExplain.entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> { - DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), entry.getKey()); + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.getDocId(), entry.getKey()); return CombinedExplanationDetails.builder() .normalizationExplanations(normalizationExplain.get(docIdAtSearchShard)) .combinationExplanations(explainDetail) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java index 505b19ae0..c875eab55 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java @@ -4,12 +4,19 @@ */ package org.opensearch.neuralsearch.processor; +import lombok.AllArgsConstructor; +import lombok.Value; import org.opensearch.search.SearchShardTarget; /** * DTO class to store index, shardId and nodeId for a search shard. */ -public record SearchShard(String index, int shardId, String nodeId) { +@Value +@AllArgsConstructor +public class SearchShard { + String index; + int shardId; + String nodeId; /** * Create SearchShard from SearchShardTarget diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java index 9ce4ebf97..51550e523 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java @@ -4,13 +4,15 @@ */ package org.opensearch.neuralsearch.processor.explain; +import lombok.Value; import org.opensearch.neuralsearch.processor.SearchShard; /** * DTO class to store docId and search shard for a query. * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. - * @param docId - * @param searchShard */ -public record DocIdAtSearchShard(int docId, SearchShard searchShard) { +@Value +public class DocIdAtSearchShard { + int docId; + SearchShard searchShard; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java index fe009f383..e577e6f43 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java @@ -4,6 +4,8 @@ */ package org.opensearch.neuralsearch.processor.explain; +import lombok.AllArgsConstructor; +import lombok.Value; import org.apache.commons.lang3.tuple.Pair; import java.util.List; @@ -11,10 +13,12 @@ /** * DTO class to store value and description for explain details. 
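* Each scoreDetails entry pairs a score with the description of the technique that produced it; for illustration, * a min_max-normalized sub-query score could be stored as Pair.of(0.75f, "min_max normalization of:").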
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. - * @param docId iterator based id of the document - * @param scoreDetails list of score details for the document, each Pair object contains score and description of the score */ -public record ExplanationDetails(int docId, List> scoreDetails) { +@Value +@AllArgsConstructor +public class ExplanationDetails { + int docId; + List> scoreDetails; public ExplanationDetails(List> scoreDetails) { this(-1, scoreDetails); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java index 499ce77cf..b4c5cd557 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java @@ -9,6 +9,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; @@ -45,6 +46,9 @@ public static Map getDocIdAtQueryForNorm * @return a string describing the combination technique and its parameters */ public static String describeCombinationTechnique(final String techniqueName, final List weights) { + if (Objects.isNull(techniqueName)) { + throw new IllegalArgumentException("combination technique name cannot be null"); + } return Optional.ofNullable(weights) .filter(w -> !w.isEmpty()) .map(w -> String.format(Locale.ROOT, "%s, weights %s", techniqueName, weights)) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java new file mode 100644 index 000000000..becab3860 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Before; + +import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ExplanationUtilsTests extends OpenSearchQueryTestCase { + + private DocIdAtSearchShard docId1; + private DocIdAtSearchShard docId2; + private Map> normalizedScores; + private final MinMaxScoreNormalizationTechnique MIN_MAX_TECHNIQUE = new MinMaxScoreNormalizationTechnique(); + + @Before + public void setUp() throws Exception { + super.setUp(); + SearchShard searchShard = new SearchShard("test_index", 0, "abcdefg"); + docId1 = new DocIdAtSearchShard(1, searchShard); + docId2 = new DocIdAtSearchShard(2, searchShard); + normalizedScores = new HashMap<>(); + } + + public void testGetDocIdAtQueryForNormalization() { + // Setup + normalizedScores.put(docId1, Arrays.asList(1.0f, 0.5f)); + normalizedScores.put(docId2, Arrays.asList(0.8f)); + // Act + Map result = ExplanationUtils.getDocIdAtQueryForNormalization( + normalizedScores, + MIN_MAX_TECHNIQUE + ); + // Assert + assertNotNull(result); + assertEquals(2, result.size()); + + // Assert first document + ExplanationDetails details1 = result.get(docId1); + 
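// docId1 was registered with two sub-query scores (1.0f and 0.5f), so two normalization detail entries are expected +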
assertNotNull(details1); + List> explanations1 = details1.getScoreDetails(); + assertEquals(2, explanations1.size()); + assertEquals(1.0f, explanations1.get(0).getLeft(), 0.001); + assertEquals(0.5f, explanations1.get(1).getLeft(), 0.001); + assertEquals("min_max normalization of:", explanations1.get(0).getRight()); + assertEquals("min_max normalization of:", explanations1.get(1).getRight()); + + // Assert second document + ExplanationDetails details2 = result.get(docId2); + assertNotNull(details2); + List> explanations2 = details2.getScoreDetails(); + assertEquals(1, explanations2.size()); + assertEquals(0.8f, explanations2.get(0).getLeft(), 0.001); + assertEquals("min_max normalization of:", explanations2.get(0).getRight()); + } + + public void testGetDocIdAtQueryForNormalizationWithEmptyScores() { + // Setup + // Using empty normalizedScores from setUp + // Act + Map result = ExplanationUtils.getDocIdAtQueryForNormalization( + normalizedScores, + MIN_MAX_TECHNIQUE + ); + // Assert + assertNotNull(result); + assertTrue(result.isEmpty()); + } + + public void testDescribeCombinationTechniqueWithWeights() { + // Setup + String techniqueName = "test_technique"; + List weights = Arrays.asList(0.3f, 0.7f); + // Act + String result = ExplanationUtils.describeCombinationTechnique(techniqueName, weights); + // Assert + assertEquals("test_technique, weights [0.3, 0.7]", result); + } + + public void testDescribeCombinationTechniqueWithoutWeights() { + // Setup + String techniqueName = "test_technique"; + // Act + String result = ExplanationUtils.describeCombinationTechnique(techniqueName, null); + // Assert + assertEquals("test_technique", result); + } + + public void testDescribeCombinationTechniqueWithEmptyWeights() { + // Setup + String techniqueName = "test_technique"; + List weights = Arrays.asList(); + // Act + String result = ExplanationUtils.describeCombinationTechnique(techniqueName, weights); + // Assert + assertEquals("test_technique", result); + } + + public void testDescribeCombinationTechniqueWithNullTechnique() { + // Setup + List weights = Arrays.asList(1.0f); + // Act & Assert + expectThrows(IllegalArgumentException.class, () -> ExplanationUtils.describeCombinationTechnique(null, weights)); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java new file mode 100644 index 000000000..453cc471c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import lombok.SneakyThrows; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +public class ExplanationResponseProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testDefaults_whenNoParams_thenSuccessful() { + // Setup + ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory(); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String 
description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + // Act + SearchResponseProcessor responseProcessor = explanationResponseProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + // Assert + assertProcessor(responseProcessor, tag, description, ignoreFailure); + } + + @SneakyThrows + public void testInvalidInput_whenParamsPassedToFactory_thenSuccessful() { + // Setup + ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory(); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + // create map of random parameters + Map config = new HashMap<>(); + for (int i = 0; i < randomInt(1_000); i++) { + config.put(randomAlphaOfLength(10) + i, randomAlphaOfLength(100)); + } + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + // Act + SearchResponseProcessor responseProcessor = explanationResponseProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + // Assert + assertProcessor(responseProcessor, tag, description, ignoreFailure); + } + + @SneakyThrows + public void testNewInstanceCreation_whenCreateMultipleTimes_thenNewInstanceReturned() { + // Setup + ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory(); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + // Act + SearchResponseProcessor responseProcessorOne = explanationResponseProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + + SearchResponseProcessor responseProcessorTwo = explanationResponseProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + + // Assert + assertNotEquals(responseProcessorOne, responseProcessorTwo); + } + + private static void assertProcessor(SearchResponseProcessor responseProcessor, String tag, String description, boolean ignoreFailure) { + assertNotNull(responseProcessor); + assertTrue(responseProcessor instanceof ExplanationResponseProcessor); + ExplanationResponseProcessor explanationResponseProcessor = (ExplanationResponseProcessor) responseProcessor; + assertEquals("explanation_response_processor", explanationResponseProcessor.getType()); + assertEquals(tag, explanationResponseProcessor.getTag()); + assertEquals(description, explanationResponseProcessor.getDescription()); + assertEquals(ignoreFailure, explanationResponseProcessor.isIgnoreFailure()); + } +} From 90987810004b22ce8bb65b1d726d14a9dd49e47b Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 14 Nov 2024 21:35:04 -0800 Subject: [PATCH 10/11] Address review comments Signed-off-by: Martin Gaievski --- .../ExplanationResponseProcessor.java | 17 +++- .../NormalizationProcessorWorkflow.java | 20 ++-- .../processor/combination/ScoreCombiner.java | 26 ++--- .../processor/explain/ExplanationDetails.java | 1 + .../processor/explain/ExplanationUtils.java | 23 +++-- .../query/HybridQueryExplainIT.java | 97 +++++++++++++++++++ 6 
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
index 01c1516d2..01cdfcb0d 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
@@ -40,11 +40,17 @@ public class ExplanationResponseProcessor implements SearchResponseProcessor {
     private final String tag;
     private final boolean ignoreFailure;

+    /**
+     * Adds explanation details to the search response if they are present in the request context
+     */
     @Override
     public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
         return processResponse(request, response, null);
     }

+    /**
+     * Combines explanations from the processor with search-hit-level explanations and adds them to the search response
+     */
     @Override
     public SearchResponse processResponse(
         final SearchRequest request,
@@ -56,15 +62,20 @@ public SearchResponse processResponse(
             || requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) {
             return response;
         }
+        // Extract explanation payload from context
         ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY);
         Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
         if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
+            // for score normalization, processor level explanations will be sorted in scope of each shard,
+            // and we merge them with the query level explanations into a single sorted list
             SearchHits searchHits = response.getHits();
             SearchHit[] searchHitsArray = searchHits.getHits();
             // create a map of searchShard and list of indexes of search hit objects in search hits array
             // the list will keep original order of sorting as per final search results
             Map<SearchShard, List<Integer>> searchHitsByShard = new HashMap<>();
+            // for each shard we keep an index, which is the position of the next explanation in that shard's list
             Map<SearchShard, Integer> explainsByShardCount = new HashMap<>();
+            // Build initial shard mappings
             for (int i = 0; i < searchHitsArray.length; i++) {
                 SearchHit searchHit = searchHitsArray[i];
                 SearchShardTarget searchShardTarget = searchHit.getShard();
@@ -72,19 +83,22 @@ public SearchResponse processResponse(
                 searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i);
                 explainsByShardCount.putIfAbsent(searchShard, -1);
             }
+            // Process normalization details if available in correct format
             if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map) {
                 @SuppressWarnings("unchecked")
                 Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map<
                     SearchShard,
                     List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
-
+                // Process each search hit to add processor level explanations
                 for (SearchHit searchHit : searchHitsArray) {
                     SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
                     int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
                     CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
+                    // Extract various explanation components
                     Explanation queryLevelExplanation = searchHit.getExplanation();
                     ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
                     ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
+                    // Create normalized explanations for each detail
                     Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
                    for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
                        normalizedExplanation[i] = Explanation.match(
@@ -96,6 +110,7 @@ public SearchResponse processResponse(
                             queryLevelExplanation.getDetails()[i]
                         );
                     }
+                    // Create and set final explanation combining all components
                     Explanation finalExplanation = Explanation.match(
                         searchHit.getScore(),
                         // combination level explanation is always a single detail
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index 078c68aff..f2699d967 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -4,6 +4,7 @@
  */
 package org.opensearch.neuralsearch.processor;

+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -106,6 +107,10 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
         updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
     }

+    /**
+     * Collects explanations from normalization and combination techniques and saves them into the pipeline context. Later this
+     * information is read by the response processor and added to the search response
+     */
     private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<CompoundTopDocs> queryTopDocs) {
         if (!request.isExplain()) {
             return;
@@ -122,15 +127,19 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
             request.getCombinationTechnique(),
             sortForQuery
         );
-        Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = combinationExplain.entrySet()
-            .stream()
-            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> {
+        Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = new HashMap<>();
+        for (Map.Entry<SearchShard, List<ExplanationDetails>> entry : combinationExplain.entrySet()) {
+            List<CombinedExplanationDetails> combinedDetailsList = new ArrayList<>();
+            for (ExplanationDetails explainDetail : entry.getValue()) {
                 DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.getDocId(), entry.getKey());
-                return CombinedExplanationDetails.builder()
+                CombinedExplanationDetails combinedDetail = CombinedExplanationDetails.builder()
                     .normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
                     .combinationExplanations(explainDetail)
                     .build();
-            }).collect(Collectors.toList())));
+                combinedDetailsList.add(combinedDetail);
+            }
+            combinedExplanations.put(entry.getKey(), combinedDetailsList);
+        }

         ExplanationPayload explanationPayload = ExplanationPayload.builder()
             .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplanations))
@@ -139,7 +148,6 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
         PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
         pipelineProcessingContext.setAttribute(EXPLANATION_RESPONSE_KEY, explanationPayload);
     }
-
 }

 /**
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
index cbc3f485b..1779f20f7 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
@@ -359,19 +359,19 @@ private List<ExplanationDetails> explainByShard(
         // sort combined scores as per sorting criteria - either score desc or field sorting
         Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);
-        List<ExplanationDetails> listOfExplanations = sortedDocsIds.stream()
-            .map(
-                docId -> new ExplanationDetails(
-                    docId,
-                    List.of(
-                        Pair.of(
-                            combinedNormalizedScoresByDocId.get(docId),
-                            String.format(Locale.ROOT, "%s combination of:", ((ExplainableTechnique) scoreCombinationTechnique).describe())
-                        )
-                    )
-                )
-            )
-            .toList();
+        List<ExplanationDetails> listOfExplanations = new ArrayList<>();
+        String combinationDescription = String.format(
+            Locale.ROOT,
+            "%s combination of:",
+            ((ExplainableTechnique) scoreCombinationTechnique).describe()
+        );
+        for (int docId : sortedDocsIds) {
+            ExplanationDetails explanation = new ExplanationDetails(
+                docId,
+                List.of(Pair.of(combinedNormalizedScoresByDocId.get(docId), combinationDescription))
+            );
+            listOfExplanations.add(explanation);
+        }
         return listOfExplanations;
     }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
index e577e6f43..2816a348b 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
@@ -21,6 +21,7 @@ public class ExplanationDetails {
     List<Pair<Float, String>> scoreDetails;

     public ExplanationDetails(List<Pair<Float, String>> scoreDetails) {
+        // pass docId as -1 to match docId in SearchHit https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchHit.java#L170
         this(-1, scoreDetails);
     }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
index b4c5cd557..c6ac0500b 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
@@ -6,12 +6,13 @@

 import org.apache.commons.lang3.tuple.Pair;

+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
-import java.util.stream.Collectors;

 /**
  * Utility class for explain functionality
@@ -27,15 +28,17 @@ public static Map<DocIdAtSearchShard, ExplanationDetails> getDocIdAtQueryForNorm(
         final Map<DocIdAtSearchShard, List<Float>> normalizedScores,
         final ExplainableTechnique technique
     ) {
-        Map<DocIdAtSearchShard, ExplanationDetails> explain = normalizedScores.entrySet()
-            .stream()
-            .collect(Collectors.toMap(Map.Entry::getKey, entry -> {
-                List<Float> normScores = normalizedScores.get(entry.getKey());
-                List<Pair<Float, String>> explanations = normScores.stream()
-                    .map(score -> Pair.of(score, String.format(Locale.ROOT, "%s normalization of:", technique.describe())))
-                    .collect(Collectors.toList());
-                return new ExplanationDetails(explanations);
-            }));
+        Map<DocIdAtSearchShard, ExplanationDetails> explain = new HashMap<>();
+        for (Map.Entry<DocIdAtSearchShard, List<Float>> entry : normalizedScores.entrySet()) {
+            List<Float> normScores = normalizedScores.get(entry.getKey());
+            List<Pair<Float, String>> explanations = new ArrayList<>();
+            for (float score : normScores) {
+                String description = String.format(Locale.ROOT, "%s normalization of:", technique.describe());
+                explanations.add(Pair.of(score, description));
+            }
+            explain.put(entry.getKey(), new ExplanationDetails(explanations));
+        }
+
         return explain;
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
index a7656912c..3b1d6cfba 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
@@ -17,6 +17,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.IntStream;
@@ -37,6 +38,7 @@ public class HybridQueryExplainIT extends BaseNeuralSearchIT {
     private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-hybrid-vector-doc-field-index";
     private static final String TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME = "test-hybrid-multi-doc-nested-fields-index";
     private static final String TEST_MULTI_DOC_INDEX_NAME = "test-hybrid-multi-doc-index";
+    private static final String TEST_LARGE_DOCS_INDEX_NAME = "test-hybrid-large-docs-index";
     private static final String TEST_QUERY_TEXT3 = "hello";
     private static final String TEST_QUERY_TEXT4 = "place";
@@ -459,6 +461,64 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe
         }
     }

+    @SneakyThrows
+    public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() {
+        try {
+            initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME);
+            // create search pipeline with both normalization processor and explain response processor
+            createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
+
+            TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
+            HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+            hybridQueryBuilder.add(termQueryBuilder);
+
+            Map<String, Object> searchResponseAsMap = search(
+                TEST_LARGE_DOCS_INDEX_NAME,
+                hybridQueryBuilder,
+                null,
+                1000,
+                Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
+            );
+
+            List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
+            assertNotNull(hitsNestedList);
+            assertFalse(hitsNestedList.isEmpty());
+
+            // Verify total hits
+            Map<String, Object> total = getTotalHits(searchResponseAsMap);
+            assertNotNull(total.get("value"));
+            assertTrue((int) total.get("value") > 0);
+            assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+
+            // Sanity checks for each hit's explanation
+            for (Map<String, Object> hit : hitsNestedList) {
+                // Verify score is positive
+                double score = (double) hit.get("_score");
+                assertTrue("Score should be positive", score > 0.0);
+
+                // Basic explanation structure checks
+                Map<String, Object> explanation = (Map<String, Object>) hit.get("_explanation");
+                assertNotNull(explanation);
+                assertEquals("arithmetic_mean combination of:", explanation.get("description"));
+                Map<String, Object> hitDetailsForHit = getListOfValues(explanation, "details").get(0);
+                assertTrue((double) hitDetailsForHit.get("value") > 0.0f);
+                assertEquals("min_max normalization of:", hitDetailsForHit.get("description"));
+                Map<String, Object> subQueryDetailsForHit = getListOfValues(hitDetailsForHit, "details").get(0);
+                assertTrue((double) subQueryDetailsForHit.get("value") > 0.0f);
+                assertFalse(subQueryDetailsForHit.get("description").toString().isEmpty());
+                assertEquals(1, getListOfValues(subQueryDetailsForHit, "details").size());
+            }
+            // Verify scores are properly ordered
+            List<Double> scores = new ArrayList<>();
+            for (Map<String, Object> hit : hitsNestedList) {
+                scores.add((Double) hit.get("_score"));
+            }
+            assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1)));
+        } finally {
+            wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE);
+        }
+    }
+
     @SneakyThrows
     private void initializeIndexIfNotExist(String indexName) {
         if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) {
@@ -521,6 +581,43 @@ private void initializeIndexIfNotExist(String indexName) {
             );
             addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME);
         }
+
+        if (TEST_LARGE_DOCS_INDEX_NAME.equals(indexName) && !indexExists(TEST_LARGE_DOCS_INDEX_NAME)) {
+            prepareKnnIndex(
+                TEST_LARGE_DOCS_INDEX_NAME,
+                List.of(
+                    new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE),
+                    new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_2, TEST_DIMENSION, TEST_SPACE_TYPE)
+                )
+            );
+
+            // Index 1000 documents
+            for (int i = 0; i < 1000; i++) {
+                String docText;
+                if (i % 5 == 0) {
+                    docText = TEST_DOC_TEXT1; // "Hello world"
+                } else if (i % 7 == 0) {
+                    docText = TEST_DOC_TEXT2; // "Hi to this place"
+                } else if (i % 11 == 0) {
+                    docText = TEST_DOC_TEXT3; // "We would like to welcome everyone"
+                } else {
+                    docText = String.format(Locale.ROOT, "Document %d with random content", i);
+                }
+
+                addKnnDoc(
+                    TEST_LARGE_DOCS_INDEX_NAME,
+                    String.valueOf(i),
+                    List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2),
+                    List.of(
+                        Floats.asList(createRandomVector(TEST_DIMENSION)).toArray(),
+                        Floats.asList(createRandomVector(TEST_DIMENSION)).toArray()
+                    ),
+                    Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
+                    Collections.singletonList(docText)
+                );
+            }
+            assertEquals(1000, getDocCount(TEST_LARGE_DOCS_INDEX_NAME));
+        }
     }

     private void addDocsToIndex(final String testMultiDocIndexName) {

From e21d4eecd2e1315cc5a42218788affafb72bb9f9 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Fri, 15 Nov 2024 10:41:39 -0800
Subject: [PATCH 11/11] Add and refactor integ tests

Signed-off-by: Martin Gaievski
---
 .../processor/explain/ExplanationDetails.java |  3 +-
 .../query/HybridQueryExplainIT.java           | 98 ++++++++++++++++---
 .../neuralsearch/util/TestUtils.java          |  9 +-
 3 files changed, 91 insertions(+), 19 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
index 2816a348b..c55db4426 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
@@ -21,7 +21,8 @@ public class ExplanationDetails {
     List<Pair<Float, String>> scoreDetails;

     public ExplanationDetails(List<Pair<Float, String>> scoreDetails) {
-        // pass docId as -1 to match docId in SearchHit https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchHit.java#L170
+        // pass docId as -1 to match docId in SearchHit
+        // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchHit.java#L170
         this(-1, scoreDetails);
     }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
index 3b1d6cfba..b7e4f753a 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
@@ -20,6 +20,7 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Supplier;
 import java.util.stream.IntStream;

 import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;
@@ -33,6 +34,7 @@
 import static org.opensearch.neuralsearch.util.TestUtils.getMaxScore;
 import static org.opensearch.neuralsearch.util.TestUtils.getNestedHits;
 import static org.opensearch.neuralsearch.util.TestUtils.getTotalHits;
+import static org.opensearch.neuralsearch.util.TestUtils.getValueByKey;

 public class HybridQueryExplainIT extends BaseNeuralSearchIT {
     private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-hybrid-vector-doc-field-index";
@@ -49,13 +51,17 @@ public class HybridQueryExplainIT extends BaseNeuralSearchIT {
     private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1";
     private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2";
     private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1";
+    private static final String TEST_TEXT_FIELD_NAME_2 = "test-text-field-2";
     private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user";
-    public static final String NORMALIZATION_TECHNIQUE_L2 = "l2";
+    private static final String NORMALIZATION_TECHNIQUE_L2 = "l2";
+    private static final int MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX = 2_000;
     private final float[] testVector1 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
     private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline";
+    static final Supplier<float[]> TEST_VECTOR_SUPPLIER = () -> new float[768];
+
     @Before
     public void setUp() throws Exception {
         super.setUp();
@@ -114,7 +120,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() {

         // explain
         Map<String, Object> searchHit1 = hitsNestedList.get(0);
-        Map<String, Object> topLevelExplanationsHit1 = (Map<String, Object>) searchHit1.get("_explanation");
+        Map<String, Object> topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation");
         assertNotNull(topLevelExplanationsHit1);
         assertEquals((double) searchHit1.get("_score"), (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION);
         String expectedTopLevelDescription = "arithmetic_mean combination of:";
@@ -133,7 +139,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() {

         // search hit 2
         Map<String, Object> searchHit2 = hitsNestedList.get(1);
-        Map<String, Object> topLevelExplanationsHit2 = (Map<String, Object>) searchHit2.get("_explanation");
+        Map<String, Object> topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation");
         assertNotNull(topLevelExplanationsHit2);
         assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION);
@@ -158,7 +164,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() {

         // search hit 3
         Map<String, Object> searchHit3 = hitsNestedList.get(1);
-        Map<String, Object> topLevelExplanationsHit3 = (Map<String, Object>) searchHit3.get("_explanation");
+        Map<String, Object> topLevelExplanationsHit3 = getValueByKey(searchHit3, "_explanation");
         assertNotNull(topLevelExplanationsHit3);
         assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION);
@@ -228,7 +234,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful()

         // explain, hit 1
         List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
         Map<String, Object> searchHit1 = hitsNestedList.get(0);
-        Map<String, Object> explanationForHit1 = (Map<String, Object>) searchHit1.get("_explanation");
+        Map<String, Object> explanationForHit1 = getValueByKey(searchHit1, "_explanation");
         assertNotNull(explanationForHit1);
         assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION);
         String expectedTopLevelDescription = "arithmetic_mean, weights [0.3, 0.7] combination of:";
@@ -258,7 +264,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful()

         // hit 2
         Map<String, Object> searchHit2 = hitsNestedList.get(1);
-        Map<String, Object> explanationForHit2 = (Map<String, Object>) searchHit2.get("_explanation");
+        Map<String, Object> explanationForHit2 = getValueByKey(searchHit2, "_explanation");
         assertNotNull(explanationForHit2);
         assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION);
@@ -278,7 +284,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful()

         // hit 3
         Map<String, Object> searchHit3 = hitsNestedList.get(2);
-        Map<String, Object> explanationForHit3 = (Map<String, Object>) searchHit3.get("_explanation");
+        Map<String, Object> explanationForHit3 = getValueByKey(searchHit3, "_explanation");
         assertNotNull(explanationForHit3);
         assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION);
@@ -298,7 +304,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful()

         // hit 4
         Map<String, Object> searchHit4 = hitsNestedList.get(3);
-        Map<String, Object> explanationForHit4 = (Map<String, Object>) searchHit4.get("_explanation");
+        Map<String, Object> explanationForHit4 = getValueByKey(searchHit4, "_explanation");
         assertNotNull(explanationForHit4);
         assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION);
@@ -367,7 +373,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe

         // explain
         Map<String, Object> searchHit1 = hitsNestedList.get(0);
-        Map<String, Object> topLevelExplanationsHit1 = (Map<String, Object>) searchHit1.get("_explanation");
+        Map<String, Object> topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation");
         assertNotNull(topLevelExplanationsHit1);
         assertEquals(0.754f, (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION);
         String expectedTopLevelDescription = "combined score of:";
@@ -409,7 +415,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe

         // search hit 2
         Map<String, Object> searchHit2 = hitsNestedList.get(1);
-        Map<String, Object> topLevelExplanationsHit2 = (Map<String, Object>) searchHit2.get("_explanation");
+        Map<String, Object> topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation");
         assertNotNull(topLevelExplanationsHit2);
         assertEquals(0.287f, (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION);
@@ -434,7 +440,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe

         // search hit 3
         Map<String, Object> searchHit3 = hitsNestedList.get(1);
-        Map<String, Object> topLevelExplanationsHit3 = (Map<String, Object>) searchHit3.get("_explanation");
+        Map<String, Object> topLevelExplanationsHit3 = getValueByKey(searchHit3, "_explanation");
         assertNotNull(topLevelExplanationsHit3);
         assertEquals(0.287f, (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION);
@@ -476,7 +482,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() {
             TEST_LARGE_DOCS_INDEX_NAME,
             hybridQueryBuilder,
             null,
-            1000,
+            MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX,
             Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
         );
@@ -497,7 +503,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() {
             assertTrue("Score should be positive", score > 0.0);

             // Basic explanation structure checks
-            Map<String, Object> explanation = (Map<String, Object>) hit.get("_explanation");
+            Map<String, Object> explanation = getValueByKey(hit, "_explanation");
             assertNotNull(explanation);
             assertEquals("arithmetic_mean combination of:", explanation.get("description"));
             Map<String, Object> hitDetailsForHit = getListOfValues(explanation, "details").get(0);
@@ -519,6 +525,66 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() {
         }
     }

+    @SneakyThrows
+    public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() {
+        try {
+            initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME);
+            // create search pipeline with both normalization processor and explain response processor
+            createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
+
+            HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+            hybridQueryBuilder.add(QueryBuilders.multiMatchQuery(TEST_QUERY_TEXT3, TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2));
+            hybridQueryBuilder.add(
+                KNNQueryBuilder.builder().k(10).fieldName(TEST_KNN_VECTOR_FIELD_NAME_1).vector(TEST_VECTOR_SUPPLIER.get()).build()
+            );
+
+            Map<String, Object> searchResponseAsMap = search(
+                TEST_LARGE_DOCS_INDEX_NAME,
+                hybridQueryBuilder,
+                null,
+                MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX,
+                Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
+            );
+
+            List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
+            assertNotNull(hitsNestedList);
+            assertFalse(hitsNestedList.isEmpty());
+
+            // Verify total hits
+            Map<String, Object> total = getTotalHits(searchResponseAsMap);
+            assertNotNull(total.get("value"));
+            assertTrue((int) total.get("value") > 0);
+            assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+
+            // Sanity checks for each hit's explanation
+            for (Map<String, Object> hit : hitsNestedList) {
+                // Verify score is positive
+                double score = (double) hit.get("_score");
+                assertTrue("Score should be positive", score > 0.0);
+
+                // Basic explanation structure checks
+                Map<String, Object> explanation = getValueByKey(hit, "_explanation");
+                assertNotNull(explanation);
+                assertEquals("arithmetic_mean combination of:", explanation.get("description"));
+                Map<String, Object> hitDetailsForHit = getListOfValues(explanation, "details").get(0);
+                assertTrue((double) hitDetailsForHit.get("value") > 0.0f);
+                assertEquals("min_max normalization of:", hitDetailsForHit.get("description"));
+                Map<String, Object> subQueryDetailsForHit = getListOfValues(hitDetailsForHit, "details").get(0);
+                assertTrue((double) subQueryDetailsForHit.get("value") > 0.0f);
+                assertFalse(subQueryDetailsForHit.get("description").toString().isEmpty());
+                assertNotNull(getListOfValues(subQueryDetailsForHit, "details"));
+            }
+            // Verify scores are properly ordered
+            List<Double> scores = new ArrayList<>();
+            for (Map<String, Object> hit : hitsNestedList) {
+                scores.add((Double) hit.get("_score"));
+            }
+            assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1)));
+        } finally {
+            wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE);
+        }
+    }
+
     @SneakyThrows
     private void initializeIndexIfNotExist(String indexName) {
         if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) {
@@ -591,8 +657,8 @@ private void initializeIndexIfNotExist(String indexName) {
             )
         );

-            // Index 1000 documents
+            // Index large number of documents
-            for (int i = 0; i < 1000; i++) {
+            for (int i = 0; i < MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX; i++) {
                 String docText;
                 if (i % 5 == 0) {
                     docText = TEST_DOC_TEXT1; // "Hello world"
                 } else if (i % 7 == 0) {
                     docText = TEST_DOC_TEXT2; // "Hi to this place"
                 } else if (i % 11 == 0) {
                     docText = TEST_DOC_TEXT3; // "We would like to welcome everyone"
                 } else {
                     docText = String.format(Locale.ROOT, "Document %d with random content", i);
                 }
@@ -616,7 +682,7 @@ private void initializeIndexIfNotExist(String indexName) {
                     Collections.singletonList(docText)
                 );
             }
-            assertEquals(1000, getDocCount(TEST_LARGE_DOCS_INDEX_NAME));
+            assertEquals(MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, getDocCount(TEST_LARGE_DOCS_INDEX_NAME));
         }
     }
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java
index af9b37b14..bb072ab55 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java
@@ -9,8 +9,6 @@
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
-import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits;
-import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits;
 import static org.opensearch.test.OpenSearchTestCase.randomFloat;

 import java.util.ArrayList;
@@ -383,6 +381,13 @@ public static String getModelId(Map<String, Object> pipeline, String processor)
         return modelId;
     }

+    @SuppressWarnings("unchecked")
+    public static <T> T getValueByKey(Map<String, Object> map, String key) {
+        assertNotNull(map);
+        Object value = map.get(key);
+        return (T) value;
+    }
+
     public static String generateModelId() {
         return "public_model_" + RandomizedTest.randomAsciiAlphanumOfLength(8);
     }
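A note on the explanation tree these patches assemble and that the integration tests above walk via "description" and "details": the top-level "arithmetic_mean combination of:" node wraps one "min_max normalization of:" node per sub-query, which in turn wraps the original query-level explanation. Below is a minimal, self-contained sketch of that nesting using only Lucene's Explanation API. The class name, the single sub-query, and the raw/min/max score values are illustrative assumptions, not values produced by the plugin, and the min_max edge cases (min == max) handled by the real technique are omitted.

import org.apache.lucene.search.Explanation;

public class ExplanationTreeSketch {
    public static void main(String[] args) {
        // query-level explanation, as a shard would produce it (raw, unnormalized score);
        // the description here is a stand-in for Lucene's real scoring breakdown
        float rawScore = 5.78f;
        Explanation queryLevel = Explanation.match(rawScore, "weight(test-text-field-1:hello) [illustrative]");

        // min_max normalization rescales the raw score into [0, 1]:
        // normalized = (score - min) / (max - min), assuming min < max
        float minScore = 0.5f;
        float maxScore = 5.78f;
        float normalized = (rawScore - minScore) / (maxScore - minScore);
        Explanation normalization = Explanation.match(normalized, "min_max normalization of:", queryLevel);

        // arithmetic_mean combination averages the normalized sub-query scores;
        // with a single sub-query the combined score equals the normalized score
        float combined = normalized;
        Explanation topLevel = Explanation.match(combined, "arithmetic_mean combination of:", normalization);

        // prints the nested tree that corresponds to the "_explanation" map in the responses above
        System.out.println(topLevel);
    }
}

With multiple sub-queries the combination node would carry one normalization child per sub-query, which is why the tests index hits by shard and advance a per-shard counter when stitching processor-level details onto each SearchHit.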