From 85c0d049e0bf3411257887d559022abf11fa6873 Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Mon, 27 May 2024 14:50:36 +0200 Subject: [PATCH] Better handling of multiple rescorers clauses with LTR. --- .../action/search/SearchRequest.java | 46 ++++++++++++ .../search/rescore/QueryRescorer.java | 4 +- .../search/rescore/QueryRescorerBuilder.java | 2 +- .../search/rescore/RescoreContext.java | 9 ++- .../search/rescore/RescorePhase.java | 22 ++++++ .../search/rescore/RescorerBuilder.java | 33 ++++++--- .../action/search/SearchRequestTests.java | 60 ++++++++++++++++ .../rescore/QueryRescorerBuilderTests.java | 1 + .../integration/LearningToRankRescorerIT.java | 71 +++++++++++++------ .../inference/ltr/LearningToRankRescorer.java | 31 ++------ .../ltr/LearningToRankRescorerBuilder.java | 6 +- .../ltr/LearningToRankRescorerContext.java | 4 +- ...ningToRankRescorerBuilderRewriteTests.java | 1 + 13 files changed, 222 insertions(+), 68 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index 6a95eadc92139..a8a54fade7601 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -28,6 +28,7 @@ import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; @@ -389,6 +390,9 @@ public ActionRequestValidationException validate() { if (source.aggregations() != null) { validationException = source.aggregations().validate(validationException); } + if (source.rescores() != null) { + validationException = validateRescores(validationException); + } if (source.rankBuilder() != null) { int size = source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size(); if (size == 0) { @@ -486,6 +490,48 @@ public ActionRequestValidationException validate() { return validationException; } + public ActionRequestValidationException validateRescores(ActionRequestValidationException validationException) { + RescorerBuilder nonCombinableRescorer = null; + + if (source.rescores() == null) { + return validationException; + } + + int paginationWindowSize = source.from() + source.size(); + + for (RescorerBuilder currentRescorer: source.rescores()) { + if (nonCombinableRescorer != null && nonCombinableRescorer.windowSize() < currentRescorer.windowSize()) { + validationException = addValidationError( + "unable to add a rescorer with [window_size: " + + currentRescorer.windowSize() + + "] because a rescorer of type [" + + nonCombinableRescorer.getWriteableName() + + "] with a smaller [window_size: " + + nonCombinableRescorer.windowSize() + + "] has been added before", + validationException + ); + } + + if (currentRescorer.canCombineScores() == false) { + if (currentRescorer.windowSize() < paginationWindowSize) { + validationException = addValidationError( + "rescorer [window_size] is too small and should be at least the value of [from + size: " + + paginationWindowSize + + "] but was [" + + currentRescorer.windowSize() + +"]", + validationException + ); + } + + nonCombinableRescorer = currentRescorer; + } + } + + return validationException; + } + /** * Returns the alias of the cluster that this search request is being executed on. A non-null value indicates that this search request * is being executed as part of a locally reduced cross-cluster search request. The cluster alias is used to prefix index names diff --git a/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java b/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java index c873717fe55e7..41214ffd08661 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java @@ -155,8 +155,8 @@ public static class QueryRescoreContext extends RescoreContext { private float rescoreQueryWeight = 1.0f; private QueryRescoreMode scoreMode; - public QueryRescoreContext(int windowSize) { - super(windowSize, QueryRescorer.INSTANCE); + public QueryRescoreContext(int windowSize, boolean canCombineScores) { + super(windowSize, QueryRescorer.INSTANCE, canCombineScores); this.scoreMode = QueryRescoreMode.Total; } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorerBuilder.java b/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorerBuilder.java index 24ba423315687..0b95b938bd639 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorerBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorerBuilder.java @@ -167,7 +167,7 @@ public static QueryRescorerBuilder fromXContent(XContentParser parser) throws IO @Override public QueryRescoreContext innerBuildContext(int windowSize, SearchExecutionContext context) throws IOException { - QueryRescoreContext queryRescoreContext = new QueryRescoreContext(windowSize); + QueryRescoreContext queryRescoreContext = new QueryRescoreContext(windowSize, canCombineScores()); // query is rewritten at this point already queryRescoreContext.setQuery(context.toQuery(queryBuilder)); queryRescoreContext.setQueryWeight(this.queryWeight); diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java b/server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java index 06df6d654d17c..72a0f358a68bb 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java @@ -22,15 +22,18 @@ public class RescoreContext { private final int windowSize; private final Rescorer rescorer; + private final boolean canCombineScores; private Set rescoredDocs; // doc Ids for which rescoring was applied /** * Build the context. * @param rescorer the rescorer actually performing the rescore. + * @param canCombineScores Indicates if the rescorer score can be combined with other scores. */ - public RescoreContext(int windowSize, Rescorer rescorer) { + public RescoreContext(int windowSize, Rescorer rescorer, boolean canCombineScores) { this.windowSize = windowSize; this.rescorer = rescorer; + this.canCombineScores = canCombineScores; } /** @@ -65,4 +68,8 @@ public Set getRescoredDocs() { public List getParsedQueries() { return Collections.emptyList(); } + + public boolean canCombineScores() { + return canCombineScores; + } } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java index 697aa6099ca97..34c71f17079b2 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java @@ -44,6 +44,15 @@ public static void execute(SearchContext context) { } try { for (RescoreContext ctx : context.rescore()) { + if (ctx.canCombineScores() == false) { + /** + * When it is impossible to combine scores from the first-pass query and the rescorer, we truncate the top docs to + * the window size before executing the rescorer. + * + * @see RescorerBuilder#canCombineScores() for more details. + */ + topDocs = topN(topDocs, ctx.getWindowSize()); + } topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx); // It is the responsibility of the rescorer to sort the resulted top docs, // here we only assert that this condition is met. @@ -105,4 +114,17 @@ private static boolean topDocsSortedByScore(TopDocs topDocs) { } return true; } + + /** Returns a new {@link TopDocs} with the topN from the incoming one, or the same TopDocs if the number of hits is already <= + * topN. */ + private static TopDocs topN(TopDocs in, int topN) { + if (in.scoreDocs.length < topN) { + return in; + } + + ScoreDoc[] subset = new ScoreDoc[topN]; + System.arraycopy(in.scoreDocs, 0, subset, 0, topN); + + return new TopDocs(in.totalHits, subset); + } } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java index 4c42daba22b7a..2d9c31a30ac7a 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java @@ -68,6 +68,24 @@ public Integer windowSize() { return windowSize; } + /** + * In some situations (e.g., LTR rescorer), it is impossible to combine scores issued by the rescoring phase those from + * the first-pass query (or previous rescorers) because they are not comparable each others. + * + * In this case: + * + * - we need to ensure that the full topDocs is rescored + * - the topDocs is truncated to the window size before executing the rescorer + * - we prevent subsequent rescorers with a bigger window size + * - we check the window size for the rescorer is at least equals to from + size + * - window size is a required parameter for the rescorer + * + * @return whether it is possible to combine scores issued by the rescoring phase with original scores or not. + */ + public boolean canCombineScores() { + return true; + } + public static RescorerBuilder parseFromXContent(XContentParser parser, Consumer rescorerNameConsumer) throws IOException { String fieldName = null; RescorerBuilder rescorer = null; @@ -100,7 +118,7 @@ public static RescorerBuilder parseFromXContent(XContentParser parser, Consum if (windowSize != null) { rescorer.windowSize(windowSize.intValue()); - } else if (rescorer.isWindowSizeRequired()) { + } else if (rescorer.canCombineScores() == false) { throw new ParsingException(parser.getTokenLocation(), "window_size is required for rescorer of type [" + rescorerType + "]"); } @@ -120,24 +138,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws protected abstract void doXContent(XContentBuilder builder, Params params) throws IOException; - /** - * Indicate if the window_size is a required parameter for the rescorer. - */ - protected boolean isWindowSizeRequired() { - return false; - } - /** * Build the {@linkplain RescoreContext} that will be used to actually * execute the rescore against a particular shard. */ public final RescoreContext buildContext(SearchExecutionContext context) throws IOException { - if (isWindowSizeRequired()) { + if (canCombineScores() == false) { assert windowSize != null; } int finalWindowSize = windowSize == null ? DEFAULT_WINDOW_SIZE : windowSize; - RescoreContext rescoreContext = innerBuildContext(finalWindowSize, context); - return rescoreContext; + + return innerBuildContext(finalWindowSize, context); } /** diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java index d8c7d3e134571..7ed44e23d2d38 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; import org.elasticsearch.search.rank.TestRankBuilder; import org.elasticsearch.search.rescore.QueryRescorerBuilder; +import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.suggest.term.TermSuggestionBuilder; @@ -47,6 +48,8 @@ import static java.util.Collections.emptyMap; import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; public class SearchRequestTests extends AbstractSearchTestCase { @@ -659,4 +662,61 @@ public void testForceSyntheticUnsupported() { Exception e = expectThrows(IllegalArgumentException.class, () -> request.writeTo(out)); assertEquals(e.getMessage(), "force_synthetic_source is not supported before 8.4.0"); } + + public void testRescoreChainValidation() { + { + SearchSourceBuilder source = new SearchSourceBuilder().from(10).size(10) + .addRescorer(createRescorerMock(true, randomIntBetween(2, 10000))) + .addRescorer(createRescorerMock(true, randomIntBetween(2, 10000))) + .addRescorer(createRescorerMock(false, 50)) + .addRescorer(createRescorerMock(true, randomIntBetween(2, 50))) + .addRescorer(createRescorerMock(false, 50)) + .addRescorer(createRescorerMock(false, 20)) + .addRescorer(createRescorerMock(true, randomIntBetween(2, 20))) + .addRescorer(createRescorerMock(true, randomIntBetween(2, 20))); + + SearchRequest searchRequest = new SearchRequest().source(source); + ActionRequestValidationException validationErrors = searchRequest.validate(); + assertNull(validationErrors); + } + + { + RescorerBuilder rescorer = createRescorerMock(false, randomIntBetween(2, 19)); + SearchSourceBuilder source = new SearchSourceBuilder().from(10).size(10).addRescorer(rescorer); + + SearchRequest searchRequest = new SearchRequest().source(source); + ActionRequestValidationException validationErrors = searchRequest.validate(); + assertThat( + validationErrors.validationErrors().get(0), + equalTo( + "rescorer [window_size] is too small and should be at least the value of [from + size: 20] but was [" + rescorer.windowSize() + "]" + ) + ); + } + + { + SearchSourceBuilder source = new SearchSourceBuilder().from(10).size(10) + .addRescorer(createRescorerMock(true, randomIntBetween(2, 10000))) + .addRescorer(createRescorerMock(true, randomIntBetween(2, 10000))) + .addRescorer(createRescorerMock(false, 50)) + .addRescorer(createRescorerMock(randomBoolean(), 60)); + + SearchRequest searchRequest = new SearchRequest().source(source); + ActionRequestValidationException validationErrors = searchRequest.validate(); + assertThat( + validationErrors.validationErrors().get(0), + equalTo( + "unable to add a rescorer with [window_size: 60] because a rescorer of type [not_combinable] with a smaller [window_size: 50] has been added before" + ) + ); + } + } + + private RescorerBuilder createRescorerMock(boolean canCombineScore, int windowSize) { + RescorerBuilder rescorer = mock(RescorerBuilder.class); + doReturn(canCombineScore).when(rescorer).canCombineScores(); + doReturn(windowSize).when(rescorer).windowSize(); + doReturn("not_combinable").when(rescorer).getWriteableName(); + return rescorer; + } } diff --git a/server/src/test/java/org/elasticsearch/search/rescore/QueryRescorerBuilderTests.java b/server/src/test/java/org/elasticsearch/search/rescore/QueryRescorerBuilderTests.java index 3193655b02747..a91c69a821dd0 100644 --- a/server/src/test/java/org/elasticsearch/search/rescore/QueryRescorerBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/rescore/QueryRescorerBuilderTests.java @@ -183,6 +183,7 @@ public MappedFieldType getFieldType(String name) { assertEquals(rescoreBuilder.getQueryWeight(), rescoreContext.queryWeight(), Float.MIN_VALUE); assertEquals(rescoreBuilder.getRescoreQueryWeight(), rescoreContext.rescoreQueryWeight(), Float.MIN_VALUE); assertEquals(rescoreBuilder.getScoreMode(), rescoreContext.scoreMode()); + assertTrue(rescoreContext.canCombineScores()); } } diff --git a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearningToRankRescorerIT.java b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearningToRankRescorerIT.java index f6aca48a3f493..b21e04b721e25 100644 --- a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearningToRankRescorerIT.java +++ b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearningToRankRescorerIT.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.util.List; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; public class LearningToRankRescorerIT extends InferenceTestCase { @@ -241,33 +242,59 @@ public void testLearningToRankRescoreSmallWindow() throws Exception { "learning_to_rank": { "model_id": "ltr-model" } } }"""); - assertThrows( - "Rescore window is too small and should be at least the value of from + size but was [2]", - ResponseException.class, - () -> client().performRequest(request) + + Exception e = assertThrows(ResponseException.class, () -> client().performRequest(request)); + assertThat( + e.getMessage(), + containsString( "rescorer [window_size] is too small and should be at least the value of [from + size: 4] but was [2]") ); } + + public void testLearningToRankRescorerWithChainedRescorers() throws IOException { - Request request = new Request("GET", "store/_search?size=5"); - request.setJsonEntity(""" + + String queryTemplate = """ { - "rescore": [ - { - "window_size": 15, - "query": { "rescore_query" : { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 4" } } } } - }, - { - "window_size": 25, - "learning_to_rank": { "model_id": "ltr-model" } - }, - { - "window_size": 35, - "query": { "rescore_query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 20"} } } } - } - ] - }"""); - assertHitScores(client().performRequest(request), List.of(40.0, 40.0, 37.0, 29.0, 29.0)); + "rescore": [ + { + "window_size": %d, + "query": { "rescore_query" : { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 4" } } } } + }, + { + "window_size": 4, + "learning_to_rank": { "model_id": "ltr-model" } + }, + { + "window_size": %d, + "query": { "rescore_query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 20"} } } } + } + ] + }"""; + + + { + Request request = new Request("GET", "store/_search?size=4"); + request.setJsonEntity(String.format(queryTemplate, randomIntBetween(2, 10000), randomIntBetween(2, 4))); + assertHitScores(client().performRequest(request), List.of(40.0, 40.0, 37.0, 29.0)); + } + + { + int lastRescorerWindowSize = randomIntBetween(5, 10000); + Request request = new Request("GET", "store/_search?size=4"); + request.setJsonEntity(String.format(queryTemplate, randomIntBetween(2, 10000), lastRescorerWindowSize)); + + Exception e = assertThrows(ResponseException.class, () -> client().performRequest(request)); + assertThat( + e.getMessage(), + containsString( + "unable to add a rescorer with [window_size: " + + lastRescorerWindowSize + + "] because a rescorer of type [learning_to_rank]" + +" with a smaller [window_size: 4] has been added before" + ) + ); + } } private void indexData(String data) throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java index 4e3fa3addaf30..ca7ff9a35a9c3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java @@ -56,22 +56,12 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r throw new IllegalStateException("local model reference is null, missing rewriteAndFetch before rescore phase?"); } - if (rescoreContext.getWindowSize() < topDocs.scoreDocs.length) { - throw new IllegalArgumentException( - "Rescore window is too small and should be at least the value of from + size but was [" - + rescoreContext.getWindowSize() - + "]" - ); - } - LocalModel definition = ltrRescoreContext.regressionModelDefinition; - // First take top slice of incoming docs, to be rescored: - TopDocs topNFirstPass = topN(topDocs, rescoreContext.getWindowSize()); // Save doc IDs for which rescoring was applied to be used in score explanation - Set topNDocIDs = Arrays.stream(topNFirstPass.scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(toUnmodifiableSet()); - rescoreContext.setRescoredDocs(topNDocIDs); - ScoreDoc[] hitsToRescore = topNFirstPass.scoreDocs; + Set topDocIDs = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(toUnmodifiableSet()); + rescoreContext.setRescoredDocs(topDocIDs); + ScoreDoc[] hitsToRescore = topDocs.scoreDocs; Arrays.sort(hitsToRescore, Comparator.comparingInt(a -> a.doc)); int hitUpto = 0; int readerUpto = -1; @@ -81,7 +71,7 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r LeafReaderContext currentSegment = null; boolean changedSegment = true; List featureExtractors = ltrRescoreContext.buildFeatureExtractors(searcher); - List> docFeatures = new ArrayList<>(topNDocIDs.size()); + List> docFeatures = new ArrayList<>(topDocIDs.size()); int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum(); while (hitUpto < hitsToRescore.length) { final ScoreDoc hit = hitsToRescore[hitUpto]; @@ -138,17 +128,4 @@ public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreCon // TODO: Call infer again but with individual feature importance values and explaining the model (which features are used, etc.) return null; } - - /** Returns a new {@link TopDocs} with the topN from the incoming one, or the same TopDocs if the number of hits is already <= - * topN. */ - private static TopDocs topN(TopDocs in, int topN) { - if (in.scoreDocs.length < topN) { - return in; - } - - ScoreDoc[] subset = new ScoreDoc[topN]; - System.arraycopy(in.scoreDocs, 0, subset, 0, topN); - - return new TopDocs(in.totalHits, subset); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java index 9aa0e75b944fe..76cc68eaecbb4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java @@ -247,7 +247,7 @@ private RescorerBuilder doSearchRewrite(QueryRewr @Override protected LearningToRankRescorerContext innerBuildContext(int windowSize, SearchExecutionContext context) { rescoreOccurred = true; - return new LearningToRankRescorerContext(windowSize, LearningToRankRescorer.INSTANCE, learningToRankConfig, localModel, context); + return new LearningToRankRescorerContext(windowSize, LearningToRankRescorer.INSTANCE, canCombineScores(), learningToRankConfig, localModel, context); } @Override @@ -262,8 +262,8 @@ public TransportVersion getMinimalSupportedVersion() { } @Override - protected boolean isWindowSizeRequired() { - return true; + public boolean canCombineScores() { + return false; } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java index b1df3a2da7c42..119d2ca02eb4e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java @@ -33,6 +33,7 @@ public class LearningToRankRescorerContext extends RescoreContext { /** * @param windowSize how many documents to rescore * @param rescorer The rescorer to apply + * @param canCombineScores Indicates if the rescorer score can be combined with other scores * @param learningToRankConfig The inference config containing updated and rewritten parameters * @param regressionModelDefinition The local model inference definition, may be null during certain search phases. * @param executionContext The local shard search context @@ -40,11 +41,12 @@ public class LearningToRankRescorerContext extends RescoreContext { public LearningToRankRescorerContext( int windowSize, Rescorer rescorer, + boolean canCombineScores, LearningToRankConfig learningToRankConfig, LocalModel regressionModelDefinition, SearchExecutionContext executionContext ) { - super(windowSize, rescorer); + super(windowSize, rescorer, canCombineScores); this.executionContext = executionContext; this.regressionModelDefinition = regressionModelDefinition; this.learningToRankConfig = learningToRankConfig; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java index 3bfe8aa390d8b..3128878bbeaf4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java @@ -216,6 +216,7 @@ public void testBuildContext() throws Exception { featureExtractors.stream().flatMap(featureExtractor -> featureExtractor.featureNames().stream()).toList(), containsInAnyOrder("feature_1", "feature_2", DOUBLE_FIELD_NAME, INT_FIELD_NAME) ); + assertFalse(rescoreContext.canCombineScores()); } private LearningToRankRescorerBuilder rewriteAndFetch(