Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed May 30, 2024
1 parent 9250063 commit 3851504
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 188 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,6 @@ public ActionRequestValidationException validate() {
if (source.aggregations() != null) {
validationException = source.aggregations().validate(validationException);
}
if (source.rescores() != null) {
validationException = validateRescores(validationException);
}
if (source.rankBuilder() != null) {
int size = source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size();
if (size == 0) {
Expand Down Expand Up @@ -444,6 +441,11 @@ public ActionRequestValidationException validate() {
validationException = addValidationError("[rank] requires [explain] is [false]", validationException);
}
}
if (source.rescores() != null) {
for (@SuppressWarnings("rawtypes") RescorerBuilder rescoreBuilder : source.rescores()) {
validationException = rescoreBuilder.validate(this, validationException);
}
}
}
if (pointInTimeBuilder() != null) {
if (scroll) {
Expand Down Expand Up @@ -490,48 +492,6 @@ public ActionRequestValidationException validate() {
return validationException;
}

/**
 * Validates the chain of rescorers attached to this request's source.
 * <p>
 * Two rules are enforced; every violation is appended to {@code validationException}:
 * <ul>
 * <li>once a rescorer reporting {@code canCombineScores() == false} has been seen, no later rescorer may use a larger window
 * size than that rescorer;</li>
 * <li>a rescorer reporting {@code canCombineScores() == false} must use a window size of at least {@code from + size}.</li>
 * </ul>
 *
 * @param validationException the validation errors accumulated so far
 * @return {@code validationException}, extended with any error found here
 */
public ActionRequestValidationException validateRescores(ActionRequestValidationException validationException) {
    if (source.rescores() == null) {
        return validationException;
    }

    // size() is -1 when the request does not set it; resolve it to the effective default so the pagination window is not
    // computed from the sentinel (which would silently disable the "window too small" check).
    int effectiveSize = source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size();
    // NOTE(review): from() is assumed to use a negative sentinel when unset as well; the clamp is a no-op otherwise.
    int paginationWindowSize = Math.max(source.from(), 0) + effectiveSize;

    RescorerBuilder<?> nonCombinableRescorer = null;
    for (RescorerBuilder<?> currentRescorer : source.rescores()) {
        int currentWindowSize = effectiveWindowSize(currentRescorer);

        if (nonCombinableRescorer != null) {
            int nonCombinableWindowSize = effectiveWindowSize(nonCombinableRescorer);
            if (nonCombinableWindowSize < currentWindowSize) {
                validationException = addValidationError(
                    "unable to add a rescorer with [window_size: "
                        + currentWindowSize
                        + "] because a rescorer of type ["
                        + nonCombinableRescorer.getWriteableName()
                        + "] with a smaller [window_size: "
                        + nonCombinableWindowSize
                        + "] has been added before",
                    validationException
                );
            }
        }

        if (currentRescorer.canCombineScores() == false) {
            if (currentWindowSize < paginationWindowSize) {
                validationException = addValidationError(
                    "rescorer [window_size] is too small and should be at least the value of [from + size: "
                        + paginationWindowSize
                        + "] but was ["
                        + currentWindowSize
                        + "]",
                    validationException
                );
            }

            // Later rescorers are compared against the most recent non-combinable one.
            nonCombinableRescorer = currentRescorer;
        }
    }

    return validationException;
}

/**
 * Resolves the window size of a rescorer, falling back to the default when none was set explicitly.
 * {@code windowSize()} returns a nullable {@code Integer}; unboxing it directly would throw a NullPointerException.
 * NOTE(review): assumes {@code RescorerBuilder.DEFAULT_WINDOW_SIZE} is accessible from this class - confirm its visibility.
 */
private static int effectiveWindowSize(RescorerBuilder<?> rescorer) {
    Integer configuredWindowSize = rescorer.windowSize();
    return configuredWindowSize == null ? RescorerBuilder.DEFAULT_WINDOW_SIZE : configuredWindowSize;
}

/**
* Returns the alias of the cluster that this search request is being executed on. A non-null value indicates that this search request
* is being executed as part of a locally reduced cross-cluster search request. The cluster alias is used to prefix index names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ public static class QueryRescoreContext extends RescoreContext {
private float rescoreQueryWeight = 1.0f;
private QueryRescoreMode scoreMode;

public QueryRescoreContext(int windowSize, boolean canCombineScores) {
super(windowSize, QueryRescorer.INSTANCE, canCombineScores);
public QueryRescoreContext(int windowSize) {
super(windowSize, QueryRescorer.INSTANCE);
this.scoreMode = QueryRescoreMode.Total;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ public static QueryRescorerBuilder fromXContent(XContentParser parser) throws IO

@Override
public QueryRescoreContext innerBuildContext(int windowSize, SearchExecutionContext context) throws IOException {
QueryRescoreContext queryRescoreContext = new QueryRescoreContext(windowSize, canCombineScores());
QueryRescoreContext queryRescoreContext = new QueryRescoreContext(windowSize);
// query is rewritten at this point already
queryRescoreContext.setQuery(context.toQuery(queryBuilder));
queryRescoreContext.setQueryWeight(this.queryWeight);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,15 @@
public class RescoreContext {
private final int windowSize;
private final Rescorer rescorer;
private final boolean canCombineScores;
private Set<Integer> rescoredDocs; // doc Ids for which rescoring was applied

/**
* Build the context.
* @param rescorer the rescorer actually performing the rescore.
* @param canCombineScores Indicates if the rescorer score can be combined with other scores.
*/
public RescoreContext(int windowSize, Rescorer rescorer, boolean canCombineScores) {
public RescoreContext(int windowSize, Rescorer rescorer) {
this.windowSize = windowSize;
this.rescorer = rescorer;
this.canCombineScores = canCombineScores;
}

/**
Expand Down Expand Up @@ -68,8 +65,4 @@ public Set<Integer> getRescoredDocs() {
public List<ParsedQuery> getParsedQueries() {
return Collections.emptyList();
}

public boolean canCombineScores() {
return canCombineScores;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,6 @@ public static void execute(SearchContext context) {
}
try {
for (RescoreContext ctx : context.rescore()) {
if (ctx.canCombineScores() == false) {
/**
* When it is impossible to combine scores from the first-pass query and the rescorer, we truncate the top docs to
* the window size before executing the rescorer.
*
* @see RescorerBuilder#canCombineScores() for more details.
*/
topDocs = topN(topDocs, ctx.getWindowSize());
}
topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
// It is the responsibility of the rescorer to sort the resulted top docs,
// here we only assert that this condition is met.
Expand Down Expand Up @@ -114,17 +105,4 @@ private static boolean topDocsSortedByScore(TopDocs topDocs) {
}
return true;
}

/**
 * Returns a new {@link TopDocs} holding the first {@code topN} hits of the incoming one, or the incoming {@link TopDocs}
 * itself if its number of hits is already &lt;= {@code topN}.
 *
 * @param in the top docs to truncate
 * @param topN the maximum number of hits to keep
 * @return {@code in} when nothing has to be dropped, otherwise a new instance with the same total hits and the first
 *         {@code topN} score docs
 */
private static TopDocs topN(TopDocs in, int topN) {
    // Use <= so that an input holding exactly topN hits is returned as is instead of being copied into an identical array.
    if (in.scoreDocs.length <= topN) {
        return in;
    }

    ScoreDoc[] subset = new ScoreDoc[topN];
    System.arraycopy(in.scoreDocs, 0, subset, 0, topN);

    return new TopDocs(in.totalHits, subset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package org.elasticsearch.search.rescore;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -68,24 +70,6 @@ public Integer windowSize() {
return windowSize;
}

/**
 * Indicates whether the scores produced by this rescorer can be combined with the scores of the first-pass query
 * (or of previous rescorers).
 *
 * In some situations (e.g., LTR rescorer), combining them is impossible because the scores are not comparable with each other.
 *
 * In this case:
 *
 * - we need to ensure that the full topDocs is rescored
 * - the topDocs is truncated to the window size before executing the rescorer
 * - we prevent subsequent rescorers with a bigger window size
 * - we check that the window size for the rescorer is at least equal to from + size
 * - window size is a required parameter for the rescorer
 *
 * This base implementation always returns {@code true}.
 *
 * @return whether it is possible to combine scores issued by the rescoring phase with original scores or not.
 */
public boolean canCombineScores() {
return true;
}

public static RescorerBuilder<?> parseFromXContent(XContentParser parser, Consumer<String> rescorerNameConsumer) throws IOException {
String fieldName = null;
RescorerBuilder<?> rescorer = null;
Expand Down Expand Up @@ -118,7 +102,7 @@ public static RescorerBuilder<?> parseFromXContent(XContentParser parser, Consum

if (windowSize != null) {
rescorer.windowSize(windowSize.intValue());
} else if (rescorer.canCombineScores() == false) {
} else if (rescorer.isWindowSizeRequired()) {
throw new ParsingException(parser.getTokenLocation(), "window_size is required for rescorer of type [" + rescorerType + "]");
}

Expand All @@ -136,16 +120,24 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

/**
 * Validates this rescorer in the context of the search request it belongs to.
 *
 * This base implementation performs no check and returns {@code validationException} unchanged.
 *
 * @param searchRequest the search request containing this rescorer (not used by this implementation)
 * @param validationException the validation errors accumulated so far
 * @return the validation errors after this rescorer has been checked
 */
public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) {
return validationException;
}

protected abstract void doXContent(XContentBuilder builder, Params params) throws IOException;

/**
 * Indicates if the window_size is a required parameter for the rescorer.
 *
 * When this returns {@code true} and no window_size was parsed, {@code parseFromXContent} throws a {@link ParsingException}.
 * This base implementation returns {@code false}.
 *
 * @return {@code true} if window_size must be provided explicitly, {@code false} otherwise
 */
protected boolean isWindowSizeRequired() {
return false;
}

/**
* Build the {@linkplain RescoreContext} that will be used to actually
* execute the rescore against a particular shard.
*/
public final RescoreContext buildContext(SearchExecutionContext context) throws IOException {
if (canCombineScores() == false) {
assert windowSize != null;
}
int finalWindowSize = windowSize == null ? DEFAULT_WINDOW_SIZE : windowSize;

return innerBuildContext(finalWindowSize, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.rank.TestRankBuilder;
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.search.slice.SliceBuilder;
import org.elasticsearch.search.suggest.SuggestBuilder;
import org.elasticsearch.search.suggest.term.TermSuggestionBuilder;
Expand All @@ -48,8 +47,6 @@
import static java.util.Collections.emptyMap;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;

public class SearchRequestTests extends AbstractSearchTestCase {

Expand Down Expand Up @@ -662,66 +659,4 @@ public void testForceSyntheticUnsupported() {
Exception e = expectThrows(IllegalArgumentException.class, () -> request.writeTo(out));
assertEquals(e.getMessage(), "force_synthetic_source is not supported before 8.4.0");
}

/**
 * Checks the validation of rescorer chains that mix rescorers whose scores can be combined (mock flag {@code true})
 * with rescorers whose scores cannot (mock flag {@code false}), for a request paginated with from=10 and size=10.
 */
public void testRescoreChainValidation() {
// Valid chain: every non-combinable rescorer has a window size >= from + size (20), and no rescorer placed after a
// non-combinable one uses a bigger window size than it.
{
SearchSourceBuilder source = new SearchSourceBuilder().from(10)
.size(10)
.addRescorer(createRescorerMock(true, randomIntBetween(2, 10000)))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 10000)))
.addRescorer(createRescorerMock(false, 50))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 50)))
.addRescorer(createRescorerMock(false, 50))
.addRescorer(createRescorerMock(false, 20))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 20)))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 20)));

SearchRequest searchRequest = new SearchRequest().source(source);
ActionRequestValidationException validationErrors = searchRequest.validate();
assertNull(validationErrors);
}

// Invalid: a single non-combinable rescorer whose window size (2..19) is smaller than from + size (20).
{
RescorerBuilder<?> rescorer = createRescorerMock(false, randomIntBetween(2, 19));
SearchSourceBuilder source = new SearchSourceBuilder().from(10).size(10).addRescorer(rescorer);

SearchRequest searchRequest = new SearchRequest().source(source);
ActionRequestValidationException validationErrors = searchRequest.validate();
assertThat(
validationErrors.validationErrors().get(0),
equalTo(
"rescorer [window_size] is too small and should be at least the value of [from + size: 20] but was ["
+ rescorer.windowSize()
+ "]"
)
);
}

// Invalid: a rescorer with window size 60 (combinable or not) is added after a non-combinable rescorer with window size 50.
{
SearchSourceBuilder source = new SearchSourceBuilder().from(10)
.size(10)
.addRescorer(createRescorerMock(true, randomIntBetween(2, 10000)))
.addRescorer(createRescorerMock(true, randomIntBetween(2, 10000)))
.addRescorer(createRescorerMock(false, 50))
.addRescorer(createRescorerMock(randomBoolean(), 60));

SearchRequest searchRequest = new SearchRequest().source(source);
ActionRequestValidationException validationErrors = searchRequest.validate();
assertThat(
validationErrors.validationErrors().get(0),
equalTo(
"unable to add a rescorer with [window_size: 60] because a rescorer of type [not_combinable] "
+ "with a smaller [window_size: 50] has been added before"
)
);
}
}

/**
 * Builds a Mockito stub of {@link RescorerBuilder} for the rescore chain validation tests.
 *
 * @param canCombineScore the value the stub returns from {@code canCombineScores()}
 * @param windowSize the value the stub returns from {@code windowSize()}
 * @return a stub whose writeable name is always {@code "not_combinable"}
 */
private RescorerBuilder<?> createRescorerMock(boolean canCombineScore, int windowSize) {
    final RescorerBuilder<?> stub = mock(RescorerBuilder.class);

    // The three stubbings are independent of each other.
    doReturn("not_combinable").when(stub).getWriteableName();
    doReturn(windowSize).when(stub).windowSize();
    doReturn(canCombineScore).when(stub).canCombineScores();

    return stub;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ public MappedFieldType getFieldType(String name) {
assertEquals(rescoreBuilder.getQueryWeight(), rescoreContext.queryWeight(), Float.MIN_VALUE);
assertEquals(rescoreBuilder.getRescoreQueryWeight(), rescoreContext.rescoreQueryWeight(), Float.MIN_VALUE);
assertEquals(rescoreBuilder.getScoreMode(), rescoreContext.scoreMode());
assertTrue(rescoreContext.canCombineScores());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r

LocalModel definition = ltrRescoreContext.regressionModelDefinition;

// Because the scores of the first-pass query and of the LTR model are not comparable, there is no way to combine the results.
// We truncate the {@link TopDocs} to the window size so that every remaining document is rescored.
topDocs = topN(topDocs, rescoreContext.getWindowSize());


// Save doc IDs for which rescoring was applied to be used in score explanation
Set<Integer> topDocIDs = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(toUnmodifiableSet());
rescoreContext.setRescoredDocs(topDocIDs);
Expand Down Expand Up @@ -128,4 +133,18 @@ public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreCon
// TODO: Call infer again but with individual feature importance values and explaining the model (which features are used, etc.)
return null;
}


/**
 * Returns a new {@link TopDocs} holding the first {@code topN} hits of the incoming one, or the incoming {@link TopDocs}
 * itself if its number of hits is already &lt;= {@code topN}.
 *
 * @param in the top docs to truncate
 * @param topN the maximum number of hits to keep
 * @return {@code in} when nothing has to be dropped, otherwise a new instance with the same total hits and the first
 *         {@code topN} score docs
 */
private static TopDocs topN(TopDocs in, int topN) {
    // Use <= so that an input holding exactly topN hits is returned as is instead of being copied into an identical array.
    if (in.scoreDocs.length <= topN) {
        return in;
    }

    ScoreDoc[] subset = new ScoreDoc[topN];
    System.arraycopy(in.scoreDocs, 0, subset, 0, topN);

    return new TopDocs(in.totalHits, subset);
}
}
Loading

0 comments on commit 3851504

Please sign in to comment.