From a19de090ec35db20d7559d466954b7e672ae8982 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 1 Nov 2024 13:58:04 -0700 Subject: [PATCH] Doing some refactoring Signed-off-by: Martin Gaievski --- .../neuralsearch/plugin/NeuralSearch.java | 8 +- .../processor/CompoundTopDocs.java | 6 +- ...java => ExplanationResponseProcessor.java} | 21 +- .../processor/NormalizationProcessor.java | 22 +- .../NormalizationProcessorWorkflow.java | 8 +- ...zationProcessorWorkflowExecuteRequest.java | 3 + .../neuralsearch/processor/SearchShard.java | 10 +- .../explain/CombinedExplainDetails.java | 3 + .../processor/explain/DocIdAtSearchShard.java | 2 +- .../processor/explain/ExplainDetails.java | 7 +- .../processor/explain/ExplainUtils.java | 5 +- ...plainDto.java => ExplanationResponse.java} | 5 +- ... ExplanationResponseProcessorFactory.java} | 9 +- .../ExplainResponseProcessorTests.java | 57 -- .../ExplanationResponseProcessorTests.java | 530 ++++++++++++++++++ .../neuralsearch/query/HybridQuerySortIT.java | 4 +- .../neuralsearch/BaseNeuralSearchIT.java | 4 +- 17 files changed, 601 insertions(+), 103 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/{ExplainResponseProcessor.java => ExplanationResponseProcessor.java} (83%) rename src/main/java/org/opensearch/neuralsearch/processor/explain/{ProcessorExplainDto.java => ExplanationResponse.java} (80%) rename src/main/java/org/opensearch/neuralsearch/processor/factory/{ProcessorExplainPublisherFactory.java => ExplanationResponseProcessorFactory.java} (65%) delete mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index e5272ab4b..8deabd141 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -32,14 +32,14 @@ import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; -import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; -import org.opensearch.neuralsearch.processor.factory.ProcessorExplainPublisherFactory; +import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; @@ -185,8 +185,8 @@ public Map topDocs, boolean isSo public CompoundTopDocs(final QuerySearchResult querySearchResult) { final TopDocs topDocs = querySearchResult.topDocs().topDocs; final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget(); - SearchShard searchShard = new SearchShard( - searchShardTarget.getIndex(), - searchShardTarget.getShardId().id(), - searchShardTarget.getNodeId() - ); + SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget); boolean isSortEnabled = false; if (topDocs instanceof TopFieldDocs) { isSortEnabled = true; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java similarity index 83% rename from src/main/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessor.java rename to src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 4aee8fe6e..4cfaf9837 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -10,7 +10,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; -import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto; +import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; @@ -24,13 +24,16 @@ import java.util.Objects; import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY; -import static org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR; +import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR; +/** + * Processor to add explanation details to search response + */ @Getter @AllArgsConstructor -public class ExplainResponseProcessor implements SearchResponseProcessor { +public class ExplanationResponseProcessor implements SearchResponseProcessor { - public static final String TYPE = "explain_response_processor"; + public static final String TYPE = "explanation_response_processor"; private final String description; private final String tag; @@ -46,10 +49,10 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) { return response; } - ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY); - Map explainPayload = processorExplainDto.getExplainPayload(); + ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY); + Map explainPayload = explanationResponse.getExplainPayload(); if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { - Explanation processorExplanation = processorExplainDto.getExplanation(); + Explanation processorExplanation = explanationResponse.getExplanation(); if (Objects.isNull(processorExplanation)) { return response; } @@ -62,7 +65,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp for (int i = 0; i < searchHitsArray.length; i++) { SearchHit searchHit = searchHitsArray[i]; SearchShardTarget searchShardTarget = searchHit.getShard(); - SearchShard searchShard = SearchShard.create(searchShardTarget); + SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget); searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i); explainsByShardCount.putIfAbsent(searchShard, -1); } @@ -73,7 +76,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp List>) explainPayload.get(NORMALIZATION_PROCESSOR); for (SearchHit searchHit : searchHitsArray) { - SearchShard searchShard = SearchShard.create(searchHit.getShard()); + SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard()); int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1; CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard); Explanation normalizedExplanation = Explanation.match( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 6c6327b3b..7f0314ef7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -44,7 +44,7 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { /** * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor + * are set as part of class constructor. This method is called when there is no pipeline context * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution * @param searchPhaseContext {@link SearchContext} */ @@ -53,19 +53,27 @@ public void process( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext ) { - doProcessStuff(searchPhaseResult, searchPhaseContext, Optional.empty()); + prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty()); } + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline + * @param + */ @Override public void process( - SearchPhaseResults searchPhaseResult, - SearchPhaseContext searchPhaseContext, - PipelineProcessingContext requestContext + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext ) { - doProcessStuff(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); + prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); } - private void doProcessStuff( + private void prepareAndExecuteNormalizationWorkflow( SearchPhaseResults searchPhaseResult, SearchPhaseContext searchPhaseContext, Optional requestContextOptional diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 5b6fa158c..118d0a25c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -28,7 +28,7 @@ import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; import org.opensearch.neuralsearch.processor.explain.ExplainDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto; +import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; @@ -146,13 +146,13 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List< } }); - ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder() + ExplanationResponse explanationResponse = ExplanationResponse.builder() .explanation(topLevelExplanationForTechniques) - .explainPayload(Map.of(ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain)) + .explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain)) .build(); // store explain object to pipeline context PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); - pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, processorExplainDto); + pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java index 8056bd100..ea0b54b9c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java @@ -19,6 +19,9 @@ @Builder @AllArgsConstructor @Getter +/** + * DTO class to hold request parameters for normalization and combination + */ public class NormalizationProcessorWorkflowExecuteRequest { final List querySearchResults; final Optional fetchSearchResultOptional; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java index 61f9d0be5..505b19ae0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java @@ -6,9 +6,17 @@ import org.opensearch.search.SearchShardTarget; +/** + * DTO class to store index, shardId and nodeId for a search shard. + */ public record SearchShard(String index, int shardId, String nodeId) { - public static SearchShard create(SearchShardTarget searchShardTarget) { + /** + * Create SearchShard from SearchShardTarget + * @param searchShardTarget + * @return SearchShard + */ + public static SearchShard createSearchShard(final SearchShardTarget searchShardTarget) { return new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java index 8793f5d53..4a9793fd4 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplainDetails.java @@ -8,6 +8,9 @@ import lombok.Builder; import lombok.Getter; +/** + * DTO class to hold explain details for normalization and combination + */ @AllArgsConstructor @Builder @Getter diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java index 70da8b73c..9ce4ebf97 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java @@ -7,7 +7,7 @@ import org.opensearch.neuralsearch.processor.SearchShard; /** - * Data class to store docId and search shard for a query. + * DTO class to store docId and search shard for a query. * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. * @param docId * @param searchShard diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java index 0bbdea27f..ca83cc436 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainDetails.java @@ -5,14 +5,13 @@ package org.opensearch.neuralsearch.processor.explain; /** - * Data class to store value and description for explain details. + * DTO class to store value and description for explain details. * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. * @param value * @param description */ -public record ExplainDetails(float value, String description, int docId) { - +public record ExplainDetails(int docId, float value, String description) { public ExplainDetails(float value, String description) { - this(value, description, -1); + this(-1, value, description); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java index a66acb786..1eca4232f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainUtils.java @@ -56,14 +56,14 @@ public static ExplainDetails getScoreCombinationExplainDetailsForDocument( ) { float combinedScore = combinedNormalizedScoresByDocId.get(docId); return new ExplainDetails( + docId, combinedScore, String.format( Locale.ROOT, "normalized scores: %s combined to a final score: %s", Arrays.toString(normalizedScoresPerDoc), combinedScore - ), - docId + ) ); } @@ -96,5 +96,4 @@ public static Explanation topLevelExpalantionForCombinedScore( return Explanation.match(0.0f, explanationDetailsMessage); } - } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java similarity index 80% rename from src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java rename to src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java index ed29cd6e3..f14050214 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ProcessorExplainDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationResponse.java @@ -11,10 +11,13 @@ import java.util.Map; +/** + * DTO class to hold explain details for normalization and combination + */ @AllArgsConstructor @Builder @Getter -public class ProcessorExplainDto { +public class ExplanationResponse { Explanation explanation; Map explainPayload; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java similarity index 65% rename from src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java rename to src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java index cbca8daca..ffe1da12f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/ProcessorExplainPublisherFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java @@ -4,13 +4,16 @@ */ package org.opensearch.neuralsearch.processor.factory; -import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; import java.util.Map; -public class ProcessorExplainPublisherFactory implements Processor.Factory { +/** + * Factory class for creating ExplanationResponseProcessor + */ +public class ExplanationResponseProcessorFactory implements Processor.Factory { @Override public SearchResponseProcessor create( @@ -21,6 +24,6 @@ public SearchResponseProcessor create( Map config, Processor.PipelineContext pipelineContext ) throws Exception { - return new ExplainResponseProcessor(description, tag, ignoreFailure); + return new ExplanationResponseProcessor(description, tag, ignoreFailure); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java deleted file mode 100644 index 796a899be..000000000 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplainResponseProcessorTests.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.processor; - -import lombok.SneakyThrows; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.search.pipeline.PipelineProcessingContext; -import org.opensearch.test.OpenSearchTestCase; - -import static org.mockito.Mockito.mock; - -public class ExplainResponseProcessorTests extends OpenSearchTestCase { - private static final String PROCESSOR_TAG = "mockTag"; - private static final String DESCRIPTION = "mockDescription"; - - public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { - ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); - - assertEquals(DESCRIPTION, explainResponseProcessor.getDescription()); - assertEquals(PROCESSOR_TAG, explainResponseProcessor.getTag()); - assertFalse(explainResponseProcessor.isIgnoreFailure()); - } - - @SneakyThrows - public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() { - ExplainResponseProcessor explainResponseProcessor = new ExplainResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); - SearchRequest searchRequest = mock(SearchRequest.class); - SearchResponse searchResponse = new SearchResponse( - null, - null, - 1, - 1, - 0, - 1000, - new ShardSearchFailure[0], - SearchResponse.Clusters.EMPTY - ); - - SearchResponse processedResponse = explainResponseProcessor.processResponse(searchRequest, searchResponse); - assertEquals(searchResponse, processedResponse); - - SearchResponse processedResponse2 = explainResponseProcessor.processResponse(searchRequest, searchResponse, null); - assertEquals(searchResponse, processedResponse2); - - PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); - SearchResponse processedResponse3 = explainResponseProcessor.processResponse( - searchRequest, - searchResponse, - pipelineProcessingContext - ); - assertEquals(searchResponse, processedResponse3); - } -} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java new file mode 100644 index 000000000..ce2df0b13 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java @@ -0,0 +1,530 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.apache.commons.lang3.Range; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationResponse; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.RemoteClusterAware; + +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.TreeMap; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; + +public class ExplanationResponseProcessorTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + + assertEquals(DESCRIPTION, explanationResponseProcessor.getDescription()); + assertEquals(PROCESSOR_TAG, explanationResponseProcessor.getTag()); + assertFalse(explanationResponseProcessor.isIgnoreFailure()); + } + + @SneakyThrows + public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() { + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = new SearchResponse( + null, + null, + 1, + 1, + 0, + 1000, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY + ); + + SearchResponse processedResponse = explanationResponseProcessor.processResponse(searchRequest, searchResponse); + assertEquals(searchResponse, processedResponse); + + SearchResponse processedResponse2 = explanationResponseProcessor.processResponse(searchRequest, searchResponse, null); + assertEquals(searchResponse, processedResponse2); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + SearchResponse processedResponse3 = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + assertEquals(searchResponse, processedResponse3); + } + + @SneakyThrows + public void testParsingOfExplanations_whenResponseHasExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + int numResponses = 1; + int numIndices = 2; + Iterator> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); + Map.Entry entry = indicesIterator.next(); + String clusterAlias = entry.getKey(); + Index[] indices = entry.getValue(); + + int requestedSize = 2; + PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(null)); + TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); + + final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; + int scoreFactor = randomIntBetween(1, numResponses); + float maxScore = numDocs * scoreFactor; + + SearchHit[] searchHitArray = randomSearchHitArray( + numDocs, + numResponses, + clusterAlias, + indices, + maxScore, + scoreFactor, + null, + priorityQueue + ); + for (SearchHit searchHit : searchHitArray) { + Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]")); + searchHit.explanation(explanation); + } + + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(numResponses, TotalHits.Relation.EQUAL_TO), 1.0f); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + Explanation generalExplanation = Explanation.match( + maxScore, + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]" + ); + Map> combinedExplainDetails = Map.of( + SearchShard.createSearchShard(searchHitArray[0].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .build() + ), + SearchShard.createSearchShard(searchHitArray[1].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .build() + ) + ); + Map explainPayload = Map.of( + ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationResponse explanationResponse = ExplanationResponse.builder() + .explanation(generalExplanation) + .explainPayload(explainPayload) + .build(); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + @SneakyThrows + public void testParsingOfExplanations_whenFieldSortingAndExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + int numResponses = 1; + int numIndices = 2; + Iterator> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); + Map.Entry entry = indicesIterator.next(); + String clusterAlias = entry.getKey(); + Index[] indices = entry.getValue(); + final SortField[] sortFields = new SortField[] { + new SortField("random-text-field-1", SortField.Type.INT, randomBoolean()), + new SortField("random-text-field-2", SortField.Type.STRING, randomBoolean()) }; + + int requestedSize = 2; + PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields)); + TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); + + final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; + int scoreFactor = randomIntBetween(1, numResponses); + float maxScore = Float.NaN; + + SearchHit[] searchHitArray = randomSearchHitArray( + numDocs, + numResponses, + clusterAlias, + indices, + maxScore, + scoreFactor, + sortFields, + priorityQueue + ); + for (SearchHit searchHit : searchHitArray) { + Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]")); + searchHit.explanation(explanation); + } + + SearchHits searchHits = new SearchHits(searchHitArray, totalHits, maxScore, sortFields, null, null); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + Explanation generalExplanation = Explanation.match( + maxScore, + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]" + ); + Map> combinedExplainDetails = Map.of( + SearchShard.createSearchShard(searchHitArray[0].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .build() + ), + SearchShard.createSearchShard(searchHitArray[1].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .build() + ) + ); + Map explainPayload = Map.of( + ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationResponse explanationResponse = ExplanationResponse.builder() + .explanation(generalExplanation) + .explainPayload(explainPayload) + .build(); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + @SneakyThrows + public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + int numResponses = 1; + int numIndices = 2; + Iterator> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); + Map.Entry entry = indicesIterator.next(); + String clusterAlias = entry.getKey(); + Index[] indices = entry.getValue(); + final SortField[] sortFields = new SortField[] { SortField.FIELD_SCORE }; + + int requestedSize = 2; + PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields)); + TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); + + final int numDocs = totalHits.value >= requestedSize ? requestedSize : (int) totalHits.value; + int scoreFactor = randomIntBetween(1, numResponses); + float maxScore = Float.NaN; + + SearchHit[] searchHitArray = randomSearchHitArray( + numDocs, + numResponses, + clusterAlias, + indices, + maxScore, + scoreFactor, + sortFields, + priorityQueue + ); + for (SearchHit searchHit : searchHitArray) { + Explanation explanation = Explanation.match(1.0f, "base scores from subqueries:", Explanation.match(1.0f, "field1:[0 TO 100]")); + searchHit.explanation(explanation); + } + + SearchHits searchHits = new SearchHits(searchHitArray, totalHits, maxScore, sortFields, null, null); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + Explanation generalExplanation = Explanation.match( + maxScore, + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]" + ); + Map> combinedExplainDetails = Map.of( + SearchShard.createSearchShard(searchHitArray[0].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(1.0f, "source scores: [1.0] normalized to scores: [0.5]")) + .combinationExplain(new ExplainDetails(0.5f, "normalized scores: [0.5] combined to a final score: 0.5")) + .build() + ), + SearchShard.createSearchShard(searchHitArray[1].getShard()), + List.of( + CombinedExplainDetails.builder() + .normalizationExplain(new ExplainDetails(0.5f, "source scores: [0.5] normalized to scores: [0.25]")) + .combinationExplain(new ExplainDetails(0.25f, "normalized scores: [0.25] combined to a final score: 0.25")) + .build() + ) + ); + Map explainPayload = Map.of( + ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationResponse explanationResponse = ExplanationResponse.builder() + .explanation(generalExplanation) + .explainPayload(explainPayload) + .build(); + pipelineProcessingContext.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY, explanationResponse); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + private static SearchResponse getSearchResponse(SearchHits searchHits) { + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + internalSearchResponse, + null, + 1, + 1, + 0, + 1000, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY + ); + return searchResponse; + } + + private static void assertOnExplanationResults(SearchResponse processedResponse, float maxScore) { + assertNotNull(processedResponse); + Explanation explanationHit1 = processedResponse.getHits().getHits()[0].getExplanation(); + assertNotNull(explanationHit1); + assertEquals( + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]", + explanationHit1.getDescription() + ); + assertEquals(maxScore, (float) explanationHit1.getValue(), DELTA_FOR_SCORE_ASSERTION); + + Explanation[] detailsHit1 = explanationHit1.getDetails(); + assertEquals(3, detailsHit1.length); + assertEquals("source scores: [1.0] normalized to scores: [0.5]", detailsHit1[0].getDescription()); + assertEquals(1.0f, (float) detailsHit1[0].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertEquals("normalized scores: [0.5] combined to a final score: 0.5", detailsHit1[1].getDescription()); + assertEquals(0.5f, (float) detailsHit1[1].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertEquals("base scores from subqueries:", detailsHit1[2].getDescription()); + assertEquals(1.0f, (float) detailsHit1[2].getValue(), DELTA_FOR_SCORE_ASSERTION); + + Explanation explanationHit2 = processedResponse.getHits().getHits()[1].getExplanation(); + assertNotNull(explanationHit2); + assertEquals( + "combined score with techniques: normalization [l2], combination [arithmetic_mean] with optional parameters [[]]", + explanationHit2.getDescription() + ); + assertTrue(Range.of(0.0f, maxScore).contains((float) explanationHit2.getValue())); + + Explanation[] detailsHit2 = explanationHit2.getDetails(); + assertEquals(3, detailsHit2.length); + assertEquals("source scores: [0.5] normalized to scores: [0.25]", detailsHit2[0].getDescription()); + assertEquals(.5f, (float) detailsHit2[0].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertEquals("normalized scores: [0.25] combined to a final score: 0.25", detailsHit2[1].getDescription()); + assertEquals(.25f, (float) detailsHit2[1].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertEquals("base scores from subqueries:", detailsHit2[2].getDescription()); + assertEquals(1.0f, (float) detailsHit2[2].getValue(), DELTA_FOR_SCORE_ASSERTION); + } + + private static Map randomRealisticIndices(int numIndices, int numClusters) { + String[] indicesNames = new String[numIndices]; + for (int i = 0; i < numIndices; i++) { + indicesNames[i] = randomAlphaOfLengthBetween(5, 10); + } + Map indicesPerCluster = new TreeMap<>(); + for (int i = 0; i < numClusters; i++) { + Index[] indices = new Index[indicesNames.length]; + for (int j = 0; j < indices.length; j++) { + String indexName = indicesNames[j]; + String indexUuid = frequently() ? randomAlphaOfLength(10) : indexName; + indices[j] = new Index(indexName, indexUuid); + } + String clusterAlias; + if (frequently() || indicesPerCluster.containsKey(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY)) { + clusterAlias = randomAlphaOfLengthBetween(5, 10); + } else { + clusterAlias = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; + } + indicesPerCluster.put(clusterAlias, indices); + } + return indicesPerCluster; + } + + private static SearchHit[] randomSearchHitArray( + int numDocs, + int numResponses, + String clusterAlias, + Index[] indices, + float maxScore, + int scoreFactor, + SortField[] sortFields, + PriorityQueue priorityQueue + ) { + SearchHit[] hits = new SearchHit[numDocs]; + + int[] sortFieldFactors = new int[sortFields == null ? 0 : sortFields.length]; + for (int j = 0; j < sortFieldFactors.length; j++) { + sortFieldFactors[j] = randomIntBetween(1, numResponses); + } + + for (int j = 0; j < numDocs; j++) { + ShardId shardId = new ShardId(randomFrom(indices), randomIntBetween(0, 10)); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLengthBetween(3, 8), + shardId, + clusterAlias, + OriginalIndices.NONE + ); + SearchHit hit = new SearchHit(randomIntBetween(0, Integer.MAX_VALUE)); + + float score = Float.NaN; + if (!Float.isNaN(maxScore)) { + score = (maxScore - j) * scoreFactor; + hit.score(score); + } + + hit.shard(shardTarget); + if (sortFields != null) { + Object[] rawSortValues = new Object[sortFields.length]; + DocValueFormat[] docValueFormats = new DocValueFormat[sortFields.length]; + for (int k = 0; k < sortFields.length; k++) { + SortField sortField = sortFields[k]; + if (sortField == SortField.FIELD_SCORE) { + hit.score(score); + rawSortValues[k] = score; + } else { + rawSortValues[k] = sortField.getReverse() ? numDocs * sortFieldFactors[k] - j : j; + } + docValueFormats[k] = DocValueFormat.RAW; + } + hit.sortValues(rawSortValues, docValueFormats); + } + hits[j] = hit; + priorityQueue.add(hit); + } + return hits; + } + + private static final class SearchHitComparator implements Comparator { + + private final SortField[] sortFields; + + SearchHitComparator(SortField[] sortFields) { + this.sortFields = sortFields; + } + + @Override + public int compare(SearchHit a, SearchHit b) { + if (sortFields == null) { + int scoreCompare = Float.compare(b.getScore(), a.getScore()); + if (scoreCompare != 0) { + return scoreCompare; + } + } else { + for (int i = 0; i < sortFields.length; i++) { + SortField sortField = sortFields[i]; + if (sortField == SortField.FIELD_SCORE) { + int scoreCompare = Float.compare(b.getScore(), a.getScore()); + if (scoreCompare != 0) { + return scoreCompare; + } + } else { + Integer aSortValue = (Integer) a.getRawSortValues()[i]; + Integer bSortValue = (Integer) b.getRawSortValues()[i]; + final int compare; + if (sortField.getReverse()) { + compare = Integer.compare(bSortValue, aSortValue); + } else { + compare = Integer.compare(aSortValue, bSortValue); + } + if (compare != 0) { + return compare; + } + } + } + } + SearchShardTarget aShard = a.getShard(); + SearchShardTarget bShard = b.getShard(); + int shardIdCompareTo = aShard.getShardId().compareTo(bShard.getShardId()); + if (shardIdCompareTo != 0) { + return shardIdCompareTo; + } + int clusterAliasCompareTo = aShard.getClusterAlias().compareTo(bShard.getClusterAlias()); + if (clusterAliasCompareTo != 0) { + return clusterAliasCompareTo; + } + return Integer.compare(a.docId(), b.docId()); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java index 86f2dc620..489ebc520 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -590,7 +590,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { assertEquals(0, ((List) hit1DetailsForHit1.get("details")).size()); Map hit1DetailsForHit2 = hit1Details.get(1); - assertEquals(0.6666667, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.666, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); assertEquals("normalized scores: [1.0, 0.0, 1.0] combined to a final score: 0.6666667", hit1DetailsForHit2.get("description")); assertEquals(0, ((List) hit1DetailsForHit2.get("details")).size()); @@ -627,7 +627,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { assertEquals(0, ((List) hit1DetailsForHit4.get("details")).size()); Map hit2DetailsForHit4 = hit4Details.get(1); - assertEquals(0.6666667, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.666, (double) hit2DetailsForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); assertEquals("normalized scores: [0.0, 1.0, 1.0] combined to a final score: 0.6666667", hit2DetailsForHit4.get("description")); assertEquals(0, ((List) hit2DetailsForHit4.get("details")).size()); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 1d666d788..5761c5569 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -49,7 +49,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.processor.NormalizationProcessor; -import org.opensearch.neuralsearch.processor.ExplainResponseProcessor; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; import org.opensearch.search.sort.SortBuilder; @@ -1203,7 +1203,7 @@ protected void createSearchPipeline( if (addExplainResponseProcessor) { stringBuilderForContentBody.append(", \"response_processors\": [ ") .append("{\"") - .append(ExplainResponseProcessor.TYPE) + .append(ExplanationResponseProcessor.TYPE) .append("\": {}}") .append("]"); }