From 06532d10ec6beb5323ea42259104549e3e958ee3 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 19 Dec 2024 09:38:16 -0800 Subject: [PATCH] Add case for null/NaN scores and minor refactoring Signed-off-by: Martin Gaievski --- .../ExplanationResponseProcessor.java | 3 +- .../RRFNormalizationTechnique.java | 2 +- ...=> ExplanationResponseProcessorTests.java} | 116 +++++++++++++++++- .../RRFNormalizationTechniqueTests.java | 7 +- 4 files changed, 124 insertions(+), 4 deletions(-) rename src/test/java/org/opensearch/neuralsearch/processor/{ExplanationPayloadProcessorTests.java => ExplanationResponseProcessorTests.java} (76%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 01cdfcb0d..7a61519f8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -111,8 +111,9 @@ public SearchResponse processResponse( ); } // Create and set final explanation combining all components + Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore(); Explanation finalExplanation = Explanation.match( - searchHit.getScore(), + finalScore, // combination level explanation is always a single detail combinationExplanation.getScoreDetails().get(0).getValue(), normalizedExplanation diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java index 4cc773592..80fc65eb3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -71,7 +71,7 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { @Override public String describe() { - return String.format(Locale.ROOT, "%s, rank_constant %s", TECHNIQUE_NAME, rankConstant); + return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java similarity index 76% rename from src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java index e47ea43d2..530753a96 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java @@ -37,9 +37,10 @@ import java.util.TreeMap; import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_FLOATS_ASSERTION; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; -public class ExplanationPayloadProcessorTests extends OpenSearchTestCase { +public class ExplanationResponseProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; @@ -192,6 +193,119 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces assertOnExplanationResults(processedResponse, maxScore); } + @SneakyThrows + public void testProcessResponse_whenNullSearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = getSearchResponse(null); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenEmptySearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits emptyHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + SearchResponse searchResponse = getSearchResponse(emptyHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenNullExplanation_thenSkipProcessing() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + for (SearchHit hit : searchHits.getHits()) { + hit.explanation(null); + } + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenInvalidExplanationPayload_thenHandleGracefully() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Set invalid payload + Map invalidPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + "invalid payload" + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(invalidPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenZeroScore_thenProcessCorrectly() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(0.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + assertEquals(0.0f, processedResponse.getHits().getMaxScore(), DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + public void testProcessResponse_whenScoreIsNaN_thenExplanationUsesZero() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + // Create SearchHits with NaN score + SearchHits searchHits = getSearchHits(Float.NaN); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Setup explanation payload + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + // Process response + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + + // Verify results + assertNotNull(processedResponse); + SearchHit[] hits = processedResponse.getHits().getHits(); + assertNotNull(hits); + assertTrue(hits.length > 0); + + // Verify that the explanation uses 0.0f when input score was NaN + Explanation explanation = hits[0].getExplanation(); + assertNotNull(explanation); + assertEquals(0.0f, (float) explanation.getValue(), DELTA_FOR_FLOATS_ASSERTION); + } + private static SearchHits getSearchHits(float maxScore) { int numResponses = 1; int numIndices = 2; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java index 273d3d25f..da6d37bd7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -30,8 +30,13 @@ public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); public void testDescribe() { + // verify with default values for parameters RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); - assertEquals("rrf, rank_constant 60", normalizationTechnique.describe()); + assertEquals("rrf, rank_constant [60]", normalizationTechnique.describe()); + + // verify when parameter values are set + normalizationTechnique = new RRFNormalizationTechnique(Map.of("rank_constant", 25), scoreNormalizationUtil); + assertEquals("rrf, rank_constant [25]", normalizationTechnique.describe()); } public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() {