diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..a4b44d388 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features ### Enhancements +- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 5391faa14..e5272ab4b 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -32,7 +32,7 @@ import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; -import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; +import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; @@ -82,7 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); - public static final String PROCESSOR_EXPLAIN = "processor_explain"; + public static final String EXPLAIN_RESPONSE_KEY = "explain_response"; @Override public Collection createComponents( @@ -185,7 +185,7 @@ public Map explainPayload = processorExplainDto.getExplainPayload(); if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { Explanation processorExplanation = processorExplainDto.getExplanation(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 084a94bf7..5b6fa158c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -40,7 +40,7 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; -import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN; +import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND; import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.topLevelExpalantionForCombinedScore; import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria; @@ -152,7 +152,7 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); - pipelineProcessingContext.setAttribute(PROCESSOR_EXPLAIN, processorExplainDto); + pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, processorExplainDto); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 539eaa081..1d31b4c31 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -11,6 +11,7 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** @@ -20,7 +21,6 @@ public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "arithmetic_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index 38f98c018..0dcd5c39c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -11,6 +11,7 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** @@ -20,7 +21,6 @@ public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "geometric_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index ddfbb5223..4fd112bc5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -11,6 +11,7 @@ import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique; /** @@ -20,7 +21,6 @@ public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "harmonic_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index a915057df..5f18baf09 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -23,8 +23,8 @@ * Collection of utility methods for score combination technique classes */ @Log4j2 -class ScoreCombinationUtil { - private static final String PARAM_NAME_WEIGHTS = "weights"; +public class ScoreCombinationUtil { + public static final String PARAM_NAME_WEIGHTS = "weights"; private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 7314c3383..8194ecf74 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -5,11 +5,9 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.Objects; @@ -31,6 +29,8 @@ import org.opensearch.neuralsearch.processor.SearchShard; import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getScoreCombinationExplainDetailsForDocument; + /** * Abstracts combination of scores in query search results. */ @@ -64,7 +64,7 @@ public class ScoreCombiner { * Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score", * other steps are same for all techniques. * - * @param combineScoresDTO contains details of query top docs, score combination technique and sort is enabled or disabled. + * @param combineScoresDTO contains details of query top docs, score combination technique and sort is enabled or disabled. */ public void combineScores(final CombineScoresDto combineScoresDTO) { // iterate over results from each shard. Every CompoundTopDocs object has results from @@ -106,7 +106,6 @@ private void combineShardScores( updateQueryTopDocsWithCombinedScores( compoundQueryTopDocs, topDocsPerSubQuery, - normalizedScoresPerDoc, combinedNormalizedScoresByDocId, sortedDocsIds, getDocIdSortFieldsMap(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sort), @@ -129,7 +128,7 @@ private boolean isSortOrderByScore(Sort sort) { } /** - * @param sort sort criteria + * @param sort sort criteria * @param topDocsPerSubQuery top docs per subquery * @return list of top field docs which is deduced by typcasting top docs to top field docs. */ @@ -149,9 +148,9 @@ private List getTopFieldDocs(final Sort sort, final List } /** - * @param compoundTopDocs top docs that represent on shard + * @param compoundTopDocs top docs that represent on shard * @param combinedNormalizedScoresByDocId docId to normalized scores map - * @param sort sort criteria + * @param sort sort criteria * @return map of docId and sort fields if sorting is enabled. */ private Map getDocIdSortFieldsMap( @@ -290,7 +289,6 @@ private Map combineScoresAndGetCombinedNormalizedScoresPerDocume private void updateQueryTopDocsWithCombinedScores( final CompoundTopDocs compoundQueryTopDocs, final List topDocsPerSubQuery, - Map normalizedScoresPerDoc, final Map combinedNormalizedScoresByDocId, final Collection sortedScores, Map docIdSortFieldMap, @@ -322,21 +320,21 @@ private TotalHits getTotalHits(final List topDocsPerSubQuery, final lon public Map> explain( final List queryTopDocs, - ScoreCombinationTechnique combinationTechnique, - Sort sort + final ScoreCombinationTechnique combinationTechnique, + final Sort sort ) { // In case of duplicate keys, keep the first value - HashMap> map = new HashMap<>(); + HashMap> explanations = new HashMap<>(); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { for (Map.Entry> docIdAtSearchShardExplainDetailsEntry : explainByShard( combinationTechnique, compoundQueryTopDocs, sort ).entrySet()) { - map.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue()); + explanations.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue()); } } - return map; + return explanations; } private Map> explainByShard( @@ -349,31 +347,20 @@ private Map> explainByShard( } // - create map of normalized scores results returned from the single shard Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs()); - Map> explainsForShard = new HashMap<>(); Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); Collection sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId); - SearchShard searchShard = compoundQueryTopDocs.getSearchShard(); - explainsForShard.put(searchShard, new ArrayList<>()); - for (Integer docId : sortedDocsIds) { - float combinedScore = combinedNormalizedScoresByDocId.get(docId); - explainsForShard.get(searchShard) - .add( - new ExplainDetails( - combinedScore, - String.format( - Locale.ROOT, - "source scores: %s, combined score %s", - Arrays.toString(normalizedScoresPerDoc.get(docId)), - combinedScore - ), - docId - ) - ); - } - - return explainsForShard; + List listOfExplainsForShard = sortedDocsIds.stream() + .map( + docId -> getScoreCombinationExplainDetailsForDocument( + docId, + combinedNormalizedScoresByDocId, + normalizedScoresPerDoc.get(docId) + ) + ) + .toList(); + return Map.of(compoundQueryTopDocs.getSearchShard(), listOfExplainsForShard); } private Collection getSortedDocsIds( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java index 0de386ca7..a66acb786 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java @@ -5,16 +5,15 @@ package org.opensearch.neuralsearch.processor.explain; import org.apache.lucene.search.Explanation; -import org.opensearch.neuralsearch.processor.SearchShard; -import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import java.util.Arrays; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + /** * Utility class for explain functionality */ @@ -37,65 +36,36 @@ public static Map getDocIdAtQueryForNormaliz List normScores = normalizedScores.get(entry.getKey()); return new ExplainDetails( normScores.stream().reduce(0.0f, Float::max), - String.format(Locale.ROOT, "source scores: %s, normalized scores: %s", srcScores, normScores) + String.format(Locale.ROOT, "source scores: %s normalized to scores: %s", srcScores, normScores) ); })); return explain; } /** - * Creates map of DocIdAtQueryPhase to String containing source scores and combined score - * @param scoreCombinationTechnique the combination technique used - * @param normalizedScoresPerDoc map of DocIdAtQueryPhase to normalized scores - * @param searchShard the search shard - * @return map of DocIdAtQueryPhase to String containing source scores and combined score + * Return the detailed score combination explain for the single document + * @param docId + * @param combinedNormalizedScoresByDocId + * @param normalizedScoresPerDoc + * @return */ - public static Map getDocIdAtQueryForCombination( - ScoreCombinationTechnique scoreCombinationTechnique, - Map normalizedScoresPerDoc, - SearchShard searchShard - ) { - Map explain = new HashMap<>(); - // - create map of combined scores per doc id - Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); - normalizedScoresPerDoc.forEach((key, srcScores) -> { - float combinedScore = combinedNormalizedScoresByDocId.get(key); - explain.put( - new DocIdAtSearchShard(key, searchShard), - new ExplainDetails( - combinedScore, - String.format(Locale.ROOT, "source scores: %s, combined score %s", Arrays.toString(srcScores), combinedScore) - ) - ); - }); - return explain; - } - - /* public static Map> getExplainsByShardForCombination( - ScoreCombinationTechnique scoreCombinationTechnique, - Map normalizedScoresPerDoc, - SearchShard searchShard + public static ExplainDetails getScoreCombinationExplainDetailsForDocument( + final Integer docId, + final Map combinedNormalizedScoresByDocId, + final float[] normalizedScoresPerDoc ) { - Map explain = new HashMap<>(); - // - create map of combined scores per doc id - Map combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); - normalizedScoresPerDoc.forEach((key, srcScores) -> { - float combinedScore = combinedNormalizedScoresByDocId.get(key); - explain.put( - new DocIdAtSearchShard(key, searchShard), - new ExplainDetails( - combinedScore, - String.format("source scores: %s, combined score %s", Arrays.toString(srcScores), combinedScore) - ) - ); - }); - return explain; + float combinedScore = combinedNormalizedScoresByDocId.get(docId); + return new ExplainDetails( + combinedScore, + String.format( + Locale.ROOT, + "normalized scores: %s combined to a final score: %s", + Arrays.toString(normalizedScoresPerDoc), + combinedScore + ), + docId + ); } - */ /** * Creates a string describing the combination technique and its parameters @@ -104,7 +74,7 @@ public static Map getDocIdAtQueryForCombinat * @return a string describing the combination technique and its parameters */ public static String describeCombinationTechnique(final String techniqueName, final List weights) { - return String.format(Locale.ROOT, "combination technique [%s] with optional parameters [%s]", techniqueName, weights); + return String.format(Locale.ROOT, "combination [%s] with optional parameter [%s]: %s", techniqueName, PARAM_NAME_WEIGHTS, weights); } /** @@ -119,7 +89,7 @@ public static Explanation topLevelExpalantionForCombinedScore( ) { String explanationDetailsMessage = String.format( Locale.ROOT, - "combine score with techniques: %s, %s", + "combined score with techniques: %s, %s", explainableNormalizationTechnique.describe(), explainableCombinationTechnique.describe() ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java index 2633a89ad..cbca8daca 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java @@ -4,7 +4,7 @@ */ package org.opensearch.neuralsearch.processor.factory; -import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; +import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -21,6 +21,6 @@ public SearchResponseProcessor create( Map config, Processor.PipelineContext pipelineContext ) throws Exception { - return new ProcessorExplainPublisher(description, tag, ignoreFailure); + return new ExplainResponseProcessor(description, tag, ignoreFailure); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index e1eb6c56a..5c5436564 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -60,7 +60,7 @@ public void normalize(final List queryTopDocs) { @Override public String describe() { - return String.format(Locale.ROOT, "normalization technique [%s]", TECHNIQUE_NAME); + return String.format(Locale.ROOT, "normalization [%s]", TECHNIQUE_NAME); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 8e0f0f586..63efb4332 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -74,7 +74,7 @@ public void normalize(final List queryTopDocs) { @Override public String describe() { - return String.format(Locale.ROOT, "normalization technique [%s]", TECHNIQUE_NAME); + return String.format(Locale.ROOT, "normalization [%s]", TECHNIQUE_NAME); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index 08393922a..c0c91be6e 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -140,11 +140,10 @@ public boolean isCacheable(LeafReaderContext ctx) { } /** - * Explain is not supported for hybrid query - * + * Returns a shard level {@link Explanation} that describes how the weight and scoring are calculated. * @param context the readers context to create the {@link Explanation} for. - * @param doc the document's id relative to the given context's reader - * @return + * @param doc the document's id relative to the given context's reader + * @return shard level {@link Explanation}, each sub-query explanation is a single nested element * @throws IOException */ @Override @@ -165,7 +164,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio } } if (match) { - final String desc = "combination of:"; + final String desc = "base scores from subqueries:"; return Explanation.match(max, desc, subsOnMatch); } else { return Explanation.noMatch("no matching clause", subsOnNoMatch); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java similarity index 63% rename from src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java index 0f5abfec4..796a899be 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ProcessorExplainPublisherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java @@ -13,21 +13,21 @@ import static org.mockito.Mockito.mock; -public class ProcessorExplainPublisherTests extends OpenSearchTestCase { +public class ExplainResponseProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { - ProcessorExplainPublisher processorExplainPublisher = new ProcessorExplainPublisher(DESCRIPTION, PROCESSOR_TAG, false); + ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); - assertEquals(DESCRIPTION, processorExplainPublisher.getDescription()); - assertEquals(PROCESSOR_TAG, processorExplainPublisher.getTag()); - assertFalse(processorExplainPublisher.isIgnoreFailure()); + assertEquals(DESCRIPTION, explainResponseProcessor.getDescription()); + assertEquals(PROCESSOR_TAG, explainResponseProcessor.getTag()); + assertFalse(explainResponseProcessor.isIgnoreFailure()); } @SneakyThrows public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() { - ProcessorExplainPublisher processorExplainPublisher = new ProcessorExplainPublisher(DESCRIPTION, PROCESSOR_TAG, false); + ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); SearchRequest searchRequest = mock(SearchRequest.class); SearchResponse searchResponse = new SearchResponse( null, @@ -40,14 +40,14 @@ public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProc SearchResponse.Clusters.EMPTY ); - SearchResponse processedResponse = processorExplainPublisher.processResponse(searchRequest, searchResponse); + SearchResponse processedResponse = explainResponseProcessor.processResponse(searchRequest, searchResponse); assertEquals(searchResponse, processedResponse); - SearchResponse processedResponse2 = processorExplainPublisher.processResponse(searchRequest, searchResponse, null); + SearchResponse processedResponse2 = explainResponseProcessor.processResponse(searchRequest, searchResponse, null); assertEquals(searchResponse, processedResponse2); PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); - SearchResponse processedResponse3 = processorExplainPublisher.processResponse( + SearchResponse processedResponse3 = explainResponseProcessor.processResponse( searchRequest, searchResponse, pipelineProcessingContext diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 7d3b3fb61..deac02933 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java index 495e2f4cd..d46705902 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 0c6e1f81d..0cfdeb4c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index e24a4da5b..ce75d9fba 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -891,24 +891,26 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { assertNotNull(explanationForHit1); assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); String expectedGeneralCombineScoreDescription = - "combine score with techniques: normalization technique [min_max], combination technique [arithmetic_mean] with optional parameters [[]]"; + "combined score with techniques: normalization [min_max], combination [arithmetic_mean] with optional parameter [weights]: []"; assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); List> hit1Details = (List>) explanationForHit1.get("details"); assertEquals(3, hit1Details.size()); Map hit1DetailsForHit1 = hit1Details.get(0); assertEquals(1.0, hit1DetailsForHit1.get("value")); - assertTrue(((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[1\\.0\\]")); + assertTrue( + ((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[1\\.0\\]") + ); assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); Map hit1DetailsForHit2 = hit1Details.get(1); assertEquals(0.5, hit1DetailsForHit2.get("value")); - assertEquals("source scores: [0.0, 1.0], combined score 0.5", hit1DetailsForHit2.get("description")); + assertEquals("normalized scores: [0.0, 1.0] combined to a final score: 0.5", hit1DetailsForHit2.get("description")); assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); Map hit1DetailsForHit3 = hit1Details.get(2); double actualHit1ScoreHit3 = ((double) hit1DetailsForHit3.get("value")); assertTrue(actualHit1ScoreHit3 > 0.0); - assertEquals("combination of:", hit1DetailsForHit3.get("description")); + assertEquals("base scores from subqueries:", hit1DetailsForHit3.get("description")); assertEquals(1, ((List) hit1DetailsForHit3.get("details")).size()); Map hit1SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(0); @@ -926,18 +928,20 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { assertEquals(3, hit2Details.size()); Map hit2DetailsForHit1 = hit2Details.get(0); assertEquals(1.0, hit2DetailsForHit1.get("value")); - assertTrue(((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[1\\.0\\]")); + assertTrue( + ((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[1\\.0\\]") + ); assertEquals(0, ((List) hit2DetailsForHit1.get("details")).size()); Map hit2DetailsForHit2 = hit2Details.get(1); assertEquals(0.5, hit2DetailsForHit2.get("value")); - assertEquals("source scores: [1.0, 0.0], combined score 0.5", hit2DetailsForHit2.get("description")); + assertEquals("normalized scores: [1.0, 0.0] combined to a final score: 0.5", hit2DetailsForHit2.get("description")); assertEquals(0, ((List) hit2DetailsForHit2.get("details")).size()); Map hit2DetailsForHit3 = hit2Details.get(2); double actualHit2ScoreHit3 = ((double) hit2DetailsForHit3.get("value")); assertTrue(actualHit2ScoreHit3 > 0.0); - assertEquals("combination of:", hit2DetailsForHit3.get("description")); + assertEquals("base scores from subqueries:", hit2DetailsForHit3.get("description")); assertEquals(1, ((List) hit2DetailsForHit3.get("details")).size()); Map hit2SubDetailsForHit3 = (Map) ((List) hit2DetailsForHit3.get("details")).get(0); @@ -956,19 +960,19 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { Map hit3DetailsForHit1 = hit3Details.get(0); assertEquals(0.001, hit3DetailsForHit1.get("value")); assertTrue( - ((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[0\\.001\\]") + ((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[0\\.001\\]") ); assertEquals(0, ((List) hit3DetailsForHit1.get("details")).size()); Map hit3DetailsForHit2 = hit3Details.get(1); assertEquals(0.0005, hit3DetailsForHit2.get("value")); - assertEquals("source scores: [0.0, 0.001], combined score 5.0E-4", hit3DetailsForHit2.get("description")); + assertEquals("normalized scores: [0.0, 0.001] combined to a final score: 5.0E-4", hit3DetailsForHit2.get("description")); assertEquals(0, ((List) hit3DetailsForHit2.get("details")).size()); Map hit3DetailsForHit3 = hit3Details.get(2); double actualHit3ScoreHit3 = ((double) hit3DetailsForHit3.get("value")); assertTrue(actualHit3ScoreHit3 > 0.0); - assertEquals("combination of:", hit3DetailsForHit3.get("description")); + assertEquals("base scores from subqueries:", hit3DetailsForHit3.get("description")); assertEquals(1, ((List) hit3DetailsForHit3.get("details")).size()); Map hit3SubDetailsForHit3 = (Map) ((List) hit3DetailsForHit3.get("details")).get(0); @@ -1027,27 +1031,28 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertNotNull(explanationForHit1); assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); String expectedGeneralCombineScoreDescription = - "combine score with techniques: normalization technique [l2], combination technique [arithmetic_mean] with optional parameters [" - + Arrays.toString(new float[] { 0.3f, 0.7f }) - + "]"; + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameter [weights]: " + + Arrays.toString(new float[] { 0.3f, 0.7f }); assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); List> hit1Details = (List>) explanationForHit1.get("details"); assertEquals(3, hit1Details.size()); Map hit1DetailsForHit1 = hit1Details.get(0); assertTrue((double) hit1DetailsForHit1.get("value") > 0.5f); assertTrue( - ((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\], normalized scores: \\[.*, .*\\]") + ((String) hit1DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\] normalized to scores: \\[.*, .*\\]") ); assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); Map hit1DetailsForHit2 = hit1Details.get(1); assertEquals(actualMaxScore, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit1DetailsForHit2.get("description")).matches("source scores: \\[.*, .*\\], combined score .*")); + assertTrue( + ((String) hit1DetailsForHit2.get("description")).matches("normalized scores: \\[.*, .*\\] combined to a final score: .*") + ); assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); Map hit1DetailsForHit3 = hit1Details.get(2); assertEquals(1.0, (double) hit1DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit1DetailsForHit3.get("description")).matches("combination of:")); + assertTrue(((String) hit1DetailsForHit3.get("description")).matches("base scores from subqueries:")); assertEquals(2, ((List) hit1DetailsForHit3.get("details")).size()); // hit 2 @@ -1062,18 +1067,20 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() Map hit2DetailsForHit1 = hit2Details.get(0); assertTrue((double) hit2DetailsForHit1.get("value") > 0.5f); assertTrue( - ((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\], normalized scores: \\[.*, .*\\]") + ((String) hit2DetailsForHit1.get("description")).matches("source scores: \\[1.0, .*\\] normalized to scores: \\[.*, .*\\]") ); assertEquals(0, ((List) hit2DetailsForHit1.get("details")).size()); Map hit2DetailsForHit2 = hit2Details.get(1); assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit2DetailsForHit2.get("value"))); - assertTrue(((String) hit2DetailsForHit2.get("description")).matches("source scores: \\[.*, .*\\], combined score .*")); + assertTrue( + ((String) hit2DetailsForHit2.get("description")).matches("normalized scores: \\[.*, .*\\] combined to a final score: .*") + ); assertEquals(0, ((List) hit2DetailsForHit2.get("details")).size()); Map hit2DetailsForHit3 = hit2Details.get(2); assertEquals(1.0, (double) hit2DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit2DetailsForHit3.get("description")).matches("combination of:")); + assertTrue(((String) hit2DetailsForHit3.get("description")).matches("base scores from subqueries:")); assertEquals(2, ((List) hit2DetailsForHit3.get("details")).size()); // hit 3 @@ -1087,17 +1094,19 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertEquals(3, hit3Details.size()); Map hit3DetailsForHit1 = hit3Details.get(0); assertTrue((double) hit3DetailsForHit1.get("value") > 0.5f); - assertTrue(((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\], normalized scores: \\[.*\\]")); + assertTrue(((String) hit3DetailsForHit1.get("description")).matches("source scores: \\[.*\\] normalized to scores: \\[.*\\]")); assertEquals(0, ((List) hit3DetailsForHit1.get("details")).size()); Map hit3DetailsForHit2 = hit3Details.get(1); assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit3DetailsForHit2.get("value"))); - assertTrue(((String) hit3DetailsForHit2.get("description")).matches("source scores: \\[0.0, .*\\], combined score .*")); + assertTrue( + ((String) hit3DetailsForHit2.get("description")).matches("normalized scores: \\[0.0, .*\\] combined to a final score: .*") + ); assertEquals(0, ((List) hit3DetailsForHit2.get("details")).size()); Map hit3DetailsForHit3 = hit3Details.get(2); assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit3DetailsForHit3.get("value"))); - assertTrue(((String) hit3DetailsForHit3.get("description")).matches("combination of:")); + assertTrue(((String) hit3DetailsForHit3.get("description")).matches("base scores from subqueries:")); assertEquals(1, ((List) hit3DetailsForHit3.get("details")).size()); // hit 4 @@ -1111,17 +1120,19 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertEquals(3, hit4Details.size()); Map hit4DetailsForHit1 = hit4Details.get(0); assertTrue((double) hit4DetailsForHit1.get("value") > 0.5f); - assertTrue(((String) hit4DetailsForHit1.get("description")).matches("source scores: \\[1.0\\], normalized scores: \\[.*\\]")); + assertTrue(((String) hit4DetailsForHit1.get("description")).matches("source scores: \\[1.0\\] normalized to scores: \\[.*\\]")); assertEquals(0, ((List) hit4DetailsForHit1.get("details")).size()); Map hit4DetailsForHit2 = hit4Details.get(1); assertTrue(Range.of(0.0, (double) actualMaxScore).contains((double) hit4DetailsForHit2.get("value"))); - assertTrue(((String) hit4DetailsForHit2.get("description")).matches("source scores: \\[.*, 0.0\\], combined score .*")); + assertTrue( + ((String) hit4DetailsForHit2.get("description")).matches("normalized scores: \\[.*, 0.0\\] combined to a final score: .*") + ); assertEquals(0, ((List) hit4DetailsForHit2.get("details")).size()); Map hit4DetailsForHit3 = hit4Details.get(2); assertEquals(1.0, (double) hit4DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertTrue(((String) hit4DetailsForHit3.get("description")).matches("combination of:")); + assertTrue(((String) hit4DetailsForHit3.get("description")).matches("base scores from subqueries:")); assertEquals(1, ((List) hit4DetailsForHit3.get("details")).size()); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java index b5e812780..86f2dc620 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -21,6 +21,9 @@ import org.opensearch.neuralsearch.BaseNeuralSearchIT; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQueryWhenSortIsEnabled; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import org.opensearch.search.sort.SortOrder; import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; @@ -531,6 +534,152 @@ public void testSortingWithRescoreWhenConcurrentSegmentSearchEnabledAndDisabled_ } } + @SneakyThrows + public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { + try { + // Setup + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + // Assert + // scores for search hits + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY + ); + + Map fieldSortOrderMap = new HashMap<>(); + fieldSortOrderMap.put("stock", SortOrder.DESC); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()), + null, + null, + createSortBuilders(fieldSortOrderMap, false), + false, + null, + 0 + ); + List> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6); + assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, true); + + // explain + Map searchHit1 = nestedHits.get(0); + Map explanationForHit1 = (Map) searchHit1.get("_explanation"); + assertNotNull(explanationForHit1); + assertNull(searchHit1.get("_score")); + String expectedGeneralCombineScoreDescription = + "combined score with techniques: normalization [min_max], combination [arithmetic_mean] with optional parameter [weights]: []"; + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); + List> hit1Details = (List>) explanationForHit1.get("details"); + assertEquals(3, hit1Details.size()); + Map hit1DetailsForHit1 = hit1Details.get(0); + assertEquals(1.0, hit1DetailsForHit1.get("value")); + assertTrue( + ((String) hit1DetailsForHit1.get("description")).matches( + "source scores: \\[0.4700036, 1.0\\] normalized to scores: \\[1.0, 1.0\\]" + ) + ); + assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertEquals(0.6666667, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("normalized scores: [1.0, 0.0, 1.0] combined to a final score: 0.6666667", hit1DetailsForHit2.get("description")); + assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); + + Map hit1DetailsForHit3 = hit1Details.get(2); + double actualHit1ScoreHit3 = ((double) hit1DetailsForHit3.get("value")); + assertTrue(actualHit1ScoreHit3 > 0.0); + assertEquals("base scores from subqueries:", hit1DetailsForHit3.get("description")); + assertEquals(2, ((List) hit1DetailsForHit3.get("details")).size()); + + Map hit1SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(0); + assertEquals(0.47, ((double) hit1SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(name:mission in 0) [PerFieldSimilarity], result of:", hit1SubDetailsForHit3.get("description")); + assertEquals(1, ((List) hit1SubDetailsForHit3.get("details")).size()); + + Map hit2SubDetailsForHit3 = (Map) ((List) hit1DetailsForHit3.get("details")).get(1); + assertEquals(1.0f, ((double) hit2SubDetailsForHit3.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("stock:[20 TO 400]", hit2SubDetailsForHit3.get("description")); + assertEquals(0, ((List) hit2SubDetailsForHit3.get("details")).size()); + // hit 4 + Map searchHit4 = nestedHits.get(3); + Map explanationForHit4 = (Map) searchHit4.get("_explanation"); + assertNotNull(explanationForHit4); + assertNull(searchHit4.get("_score")); + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit4.get("description")); + List> hit4Details = (List>) explanationForHit4.get("details"); + assertEquals(3, hit4Details.size()); + Map hit1DetailsForHit4 = hit4Details.get(0); + assertEquals(1.0, hit1DetailsForHit4.get("value")); + assertTrue( + ((String) hit1DetailsForHit4.get("description")).matches( + "source scores: \\[0.9808291, 1.0\\] normalized to scores: \\[1.0, 1.0\\]" + ) + ); + assertEquals(0, ((List) hit1DetailsForHit4.get("details")).size()); + + Map hit2DetailsForHit4 = hit4Details.get(1); + assertEquals(0.6666667, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("normalized scores: [0.0, 1.0, 1.0] combined to a final score: 0.6666667", hit2DetailsForHit4.get("description")); + assertEquals(0, ((List) hit2DetailsForHit4.get("details")).size()); + + Map hit3DetailsForHit4 = hit4Details.get(2); + double actualHit3ScoreHit4 = ((double) hit3DetailsForHit4.get("value")); + assertTrue(actualHit3ScoreHit4 > 0.0); + assertEquals("base scores from subqueries:", hit3DetailsForHit4.get("description")); + assertEquals(2, ((List) hit3DetailsForHit4.get("details")).size()); + + Map hit1SubDetailsForHit4 = (Map) ((List) hit3DetailsForHit4.get("details")).get(0); + assertEquals(0.98, ((double) hit1SubDetailsForHit4.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(name:part in 0) [PerFieldSimilarity], result of:", hit1SubDetailsForHit4.get("description")); + assertEquals(1, ((List) hit1SubDetailsForHit4.get("details")).size()); + + Map hit2SubDetailsForHit4 = (Map) ((List) hit3DetailsForHit4.get("details")).get(1); + assertEquals(1.0f, ((double) hit2SubDetailsForHit4.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("stock:[20 TO 400]", hit2SubDetailsForHit4.get("description")); + assertEquals(0, ((List) hit2SubDetailsForHit4.get("details")).size()); + + // hit 6 + Map searchHit6 = nestedHits.get(5); + Map explanationForHit6 = (Map) searchHit6.get("_explanation"); + assertNotNull(explanationForHit6); + assertNull(searchHit6.get("_score")); + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit6.get("description")); + List> hit6Details = (List>) explanationForHit6.get("details"); + assertEquals(3, hit6Details.size()); + Map hit1DetailsForHit6 = hit6Details.get(0); + assertEquals(1.0, hit1DetailsForHit6.get("value")); + assertEquals("source scores: [1.0] normalized to scores: [1.0]", hit1DetailsForHit6.get("description")); + assertEquals(0, ((List) hit1DetailsForHit6.get("details")).size()); + + Map hit2DetailsForHit6 = hit6Details.get(1); + assertEquals(0.333, (double) hit2DetailsForHit6.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("normalized scores: [0.0, 0.0, 1.0] combined to a final score: 0.33333334", hit2DetailsForHit6.get("description")); + assertEquals(0, ((List) hit2DetailsForHit6.get("details")).size()); + + Map hit3DetailsForHit6 = hit6Details.get(2); + double actualHit3ScoreHit6 = ((double) hit3DetailsForHit6.get("value")); + assertTrue(actualHit3ScoreHit6 > 0.0); + assertEquals("base scores from subqueries:", hit3DetailsForHit6.get("description")); + assertEquals(1, ((List) hit3DetailsForHit6.get("details")).size()); + + Map hit1SubDetailsForHit6 = (Map) ((List) hit3DetailsForHit6.get("details")).get(0); + assertEquals(1.0, ((double) hit1SubDetailsForHit6.get("value")), DELTA_FOR_SCORE_ASSERTION); + assertEquals("stock:[20 TO 400]", hit1SubDetailsForHit6.get("description")); + assertEquals(0, ((List) hit1SubDetailsForHit6.get("details")).size()); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int lte, int gte) { MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 9eb802708..1d666d788 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -49,7 +49,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.processor.NormalizationProcessor; -import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher; +import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; import org.opensearch.search.sort.SortBuilder; @@ -1203,7 +1203,7 @@ protected void createSearchPipeline( if (addExplainResponseProcessor) { stringBuilderForContentBody.append(", \"response_processors\": [ ") .append("{\"") - .append(ProcessorExplainPublisher.TYPE) + .append(ExplainResponseProcessor.TYPE) .append("\": {}}") .append("]"); }