Skip to content

Commit

Permalink
Better handling of multiple rescorers clauses with LTR.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed May 27, 2024
1 parent a76a0b3 commit 85c0d04
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortBuilder;
Expand Down Expand Up @@ -389,6 +390,9 @@ public ActionRequestValidationException validate() {
if (source.aggregations() != null) {
validationException = source.aggregations().validate(validationException);
}
if (source.rescores() != null) {
validationException = validateRescores(validationException);
}
if (source.rankBuilder() != null) {
int size = source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size();
if (size == 0) {
Expand Down Expand Up @@ -486,6 +490,48 @@ public ActionRequestValidationException validate() {
return validationException;
}

public ActionRequestValidationException validateRescores(ActionRequestValidationException validationException) {
RescorerBuilder<?> nonCombinableRescorer = null;

if (source.rescores() == null) {
return validationException;
}

int paginationWindowSize = source.from() + source.size();

for (RescorerBuilder<?> currentRescorer: source.rescores()) {
if (nonCombinableRescorer != null && nonCombinableRescorer.windowSize() < currentRescorer.windowSize()) {
validationException = addValidationError(
"unable to add a rescorer with [window_size: "
+ currentRescorer.windowSize()
+ "] because a rescorer of type ["
+ nonCombinableRescorer.getWriteableName()
+ "] with a smaller [window_size: "
+ nonCombinableRescorer.windowSize()
+ "] has been added before",
validationException
);
}

if (currentRescorer.canCombineScores() == false) {
if (currentRescorer.windowSize() < paginationWindowSize) {
validationException = addValidationError(
"rescorer [window_size] is too small and should be at least the value of [from + size: "
+ paginationWindowSize
+ "] but was ["
+ currentRescorer.windowSize()
+"]",
validationException
);
}

nonCombinableRescorer = currentRescorer;
}
}

return validationException;
}

/**
* Returns the alias of the cluster that this search request is being executed on. A non-null value indicates that this search request
* is being executed as part of a locally reduced cross-cluster search request. The cluster alias is used to prefix index names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ public static class QueryRescoreContext extends RescoreContext {
private float rescoreQueryWeight = 1.0f;
private QueryRescoreMode scoreMode;

public QueryRescoreContext(int windowSize) {
super(windowSize, QueryRescorer.INSTANCE);
public QueryRescoreContext(int windowSize, boolean canCombineScores) {
super(windowSize, QueryRescorer.INSTANCE, canCombineScores);
this.scoreMode = QueryRescoreMode.Total;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ public static QueryRescorerBuilder fromXContent(XContentParser parser) throws IO

@Override
public QueryRescoreContext innerBuildContext(int windowSize, SearchExecutionContext context) throws IOException {
QueryRescoreContext queryRescoreContext = new QueryRescoreContext(windowSize);
QueryRescoreContext queryRescoreContext = new QueryRescoreContext(windowSize, canCombineScores());
// query is rewritten at this point already
queryRescoreContext.setQuery(context.toQuery(queryBuilder));
queryRescoreContext.setQueryWeight(this.queryWeight);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@
public class RescoreContext {
private final int windowSize;
private final Rescorer rescorer;
private final boolean canCombineScores;
private Set<Integer> rescoredDocs; // doc Ids for which rescoring was applied

/**
* Build the context.
* @param rescorer the rescorer actually performing the rescore.
* @param canCombineScores Indicates if the rescorer score can be combined with other scores.
*/
public RescoreContext(int windowSize, Rescorer rescorer) {
public RescoreContext(int windowSize, Rescorer rescorer, boolean canCombineScores) {
this.windowSize = windowSize;
this.rescorer = rescorer;
this.canCombineScores = canCombineScores;
}

/**
Expand Down Expand Up @@ -65,4 +68,8 @@ public Set<Integer> getRescoredDocs() {
public List<ParsedQuery> getParsedQueries() {
return Collections.emptyList();
}

public boolean canCombineScores() {
return canCombineScores;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ public static void execute(SearchContext context) {
}
try {
for (RescoreContext ctx : context.rescore()) {
if (ctx.canCombineScores() == false) {
/**
* When it is impossible to combine scores from the first-pass query and the rescorer, we truncate the top docs to
* the window size before executing the rescorer.
*
* @see RescorerBuilder#canCombineScores() for more details.
*/
topDocs = topN(topDocs, ctx.getWindowSize());
}
topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
// It is the responsibility of the rescorer to sort the resulted top docs,
// here we only assert that this condition is met.
Expand Down Expand Up @@ -105,4 +114,17 @@ private static boolean topDocsSortedByScore(TopDocs topDocs) {
}
return true;
}

/** Returns a new {@link TopDocs} with the topN from the incoming one, or the same TopDocs if the number of hits is already &lt;=
* topN. */
private static TopDocs topN(TopDocs in, int topN) {
if (in.scoreDocs.length < topN) {
return in;
}

ScoreDoc[] subset = new ScoreDoc[topN];
System.arraycopy(in.scoreDocs, 0, subset, 0, topN);

return new TopDocs(in.totalHits, subset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ public Integer windowSize() {
return windowSize;
}

/**
* In some situations (e.g., LTR rescorer), it is impossible to combine scores issued by the rescoring phase those from
* the first-pass query (or previous rescorers) because they are not comparable each others.
*
* In this case:
*
* - we need to ensure that the full topDocs is rescored
* - the topDocs is truncated to the window size before executing the rescorer
* - we prevent subsequent rescorers with a bigger window size
* - we check the window size for the rescorer is at least equals to from + size
* - window size is a required parameter for the rescorer
*
* @return whether it is possible to combine scores issued by the rescoring phase with original scores or not.
*/
public boolean canCombineScores() {
return true;
}

public static RescorerBuilder<?> parseFromXContent(XContentParser parser, Consumer<String> rescorerNameConsumer) throws IOException {
String fieldName = null;
RescorerBuilder<?> rescorer = null;
Expand Down Expand Up @@ -100,7 +118,7 @@ public static RescorerBuilder<?> parseFromXContent(XContentParser parser, Consum

if (windowSize != null) {
rescorer.windowSize(windowSize.intValue());
} else if (rescorer.isWindowSizeRequired()) {
} else if (rescorer.canCombineScores() == false) {
throw new ParsingException(parser.getTokenLocation(), "window_size is required for rescorer of type [" + rescorerType + "]");
}

Expand All @@ -120,24 +138,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

protected abstract void doXContent(XContentBuilder builder, Params params) throws IOException;

/**
* Indicate if the window_size is a required parameter for the rescorer.
*/
protected boolean isWindowSizeRequired() {
return false;
}

/**
* Build the {@linkplain RescoreContext} that will be used to actually
* execute the rescore against a particular shard.
*/
public final RescoreContext buildContext(SearchExecutionContext context) throws IOException {
if (isWindowSizeRequired()) {
if (canCombineScores() == false) {
assert windowSize != null;
}
int finalWindowSize = windowSize == null ? DEFAULT_WINDOW_SIZE : windowSize;
RescoreContext rescoreContext = innerBuildContext(finalWindowSize, context);
return rescoreContext;

return innerBuildContext(finalWindowSize, context);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.rank.TestRankBuilder;
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.search.slice.SliceBuilder;
import org.elasticsearch.search.suggest.SuggestBuilder;
import org.elasticsearch.search.suggest.term.TermSuggestionBuilder;
Expand All @@ -47,6 +48,8 @@
import static java.util.Collections.emptyMap;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;

public class SearchRequestTests extends AbstractSearchTestCase {

Expand Down Expand Up @@ -659,4 +662,61 @@ public void testForceSyntheticUnsupported() {
Exception e = expectThrows(IllegalArgumentException.class, () -> request.writeTo(out));
assertEquals(e.getMessage(), "force_synthetic_source is not supported before 8.4.0");
}

public void testRescoreChainValidation() {
{
SearchSourceBuilder source = new SearchSourceBuilder().from(10).size(10)
.addRescorer(createRescorerMock(true, randomIntBetween(2, 10000)))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 10000)))
.addRescorer(createRescorerMock(false, 50))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 50)))
.addRescorer(createRescorerMock(false, 50))
.addRescorer(createRescorerMock(false, 20))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 20)))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 20)));

SearchRequest searchRequest = new SearchRequest().source(source);
ActionRequestValidationException validationErrors = searchRequest.validate();
assertNull(validationErrors);
}

{
RescorerBuilder<?> rescorer = createRescorerMock(false, randomIntBetween(2, 19));
SearchSourceBuilder source = new SearchSourceBuilder().from(10).size(10).addRescorer(rescorer);

SearchRequest searchRequest = new SearchRequest().source(source);
ActionRequestValidationException validationErrors = searchRequest.validate();
assertThat(
validationErrors.validationErrors().get(0),
equalTo(
"rescorer [window_size] is too small and should be at least the value of [from + size: 20] but was [" + rescorer.windowSize() + "]"
)
);
}

{
SearchSourceBuilder source = new SearchSourceBuilder().from(10).size(10)
.addRescorer(createRescorerMock(true, randomIntBetween(2, 10000)))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 10000)))
.addRescorer(createRescorerMock(false, 50))
.addRescorer(createRescorerMock(randomBoolean(), 60));

SearchRequest searchRequest = new SearchRequest().source(source);
ActionRequestValidationException validationErrors = searchRequest.validate();
assertThat(
validationErrors.validationErrors().get(0),
equalTo(
"unable to add a rescorer with [window_size: 60] because a rescorer of type [not_combinable] with a smaller [window_size: 50] has been added before"
)
);
}
}

private RescorerBuilder<?> createRescorerMock(boolean canCombineScore, int windowSize) {
RescorerBuilder<?> rescorer = mock(RescorerBuilder.class);
doReturn(canCombineScore).when(rescorer).canCombineScores();
doReturn(windowSize).when(rescorer).windowSize();
doReturn("not_combinable").when(rescorer).getWriteableName();
return rescorer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ public MappedFieldType getFieldType(String name) {
assertEquals(rescoreBuilder.getQueryWeight(), rescoreContext.queryWeight(), Float.MIN_VALUE);
assertEquals(rescoreBuilder.getRescoreQueryWeight(), rescoreContext.rescoreQueryWeight(), Float.MIN_VALUE);
assertEquals(rescoreBuilder.getScoreMode(), rescoreContext.scoreMode());
assertTrue(rescoreContext.canCombineScores());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.io.IOException;
import java.util.List;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

public class LearningToRankRescorerIT extends InferenceTestCase {
Expand Down Expand Up @@ -241,33 +242,59 @@ public void testLearningToRankRescoreSmallWindow() throws Exception {
"learning_to_rank": { "model_id": "ltr-model" }
}
}""");
assertThrows(
"Rescore window is too small and should be at least the value of from + size but was [2]",
ResponseException.class,
() -> client().performRequest(request)

Exception e = assertThrows(ResponseException.class, () -> client().performRequest(request));
assertThat(
e.getMessage(),
containsString( "rescorer [window_size] is too small and should be at least the value of [from + size: 4] but was [2]")
);
}



public void testLearningToRankRescorerWithChainedRescorers() throws IOException {
Request request = new Request("GET", "store/_search?size=5");
request.setJsonEntity("""

String queryTemplate = """
{
"rescore": [
{
"window_size": 15,
"query": { "rescore_query" : { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 4" } } } }
},
{
"window_size": 25,
"learning_to_rank": { "model_id": "ltr-model" }
},
{
"window_size": 35,
"query": { "rescore_query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 20"} } } }
}
]
}""");
assertHitScores(client().performRequest(request), List.of(40.0, 40.0, 37.0, 29.0, 29.0));
"rescore": [
{
"window_size": %d,
"query": { "rescore_query" : { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 4" } } } }
},
{
"window_size": 4,
"learning_to_rank": { "model_id": "ltr-model" }
},
{
"window_size": %d,
"query": { "rescore_query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 20"} } } }
}
]
}""";


{
Request request = new Request("GET", "store/_search?size=4");
request.setJsonEntity(String.format(queryTemplate, randomIntBetween(2, 10000), randomIntBetween(2, 4)));
assertHitScores(client().performRequest(request), List.of(40.0, 40.0, 37.0, 29.0));
}

{
int lastRescorerWindowSize = randomIntBetween(5, 10000);
Request request = new Request("GET", "store/_search?size=4");
request.setJsonEntity(String.format(queryTemplate, randomIntBetween(2, 10000), lastRescorerWindowSize));

Exception e = assertThrows(ResponseException.class, () -> client().performRequest(request));
assertThat(
e.getMessage(),
containsString(
"unable to add a rescorer with [window_size: "
+ lastRescorerWindowSize
+ "] because a rescorer of type [learning_to_rank]"
+" with a smaller [window_size: 4] has been added before"
)
);
}
}

private void indexData(String data) throws IOException {
Expand Down
Loading

0 comments on commit 85c0d04

Please sign in to comment.