From 0aa194ab5eedbceede26c346d547ae56f2c25631 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20FOUCRET?= Date: Wed, 17 Jan 2024 08:10:48 +0100 Subject: [PATCH] [LTR] Rescore window size improvements. (#104318) --- .../search/rescore/RescorerBuilder.java | 23 ++- .../inference/ltr/LearningToRankRescorer.java | 9 + .../ltr/LearningToRankRescorerBuilder.java | 17 +- ...RankRescorerBuilderSerializationTests.java | 155 ++++++++---------- 4 files changed, 105 insertions(+), 99 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java index 76ee7e09ad87..4c42daba22b7 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java @@ -73,6 +73,8 @@ public static RescorerBuilder parseFromXContent(XContentParser parser, Consum RescorerBuilder rescorer = null; Integer windowSize = null; XContentParser.Token token; + String rescorerType = null; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { fieldName = parser.currentName(); @@ -83,8 +85,11 @@ public static RescorerBuilder parseFromXContent(XContentParser parser, Consum throw new ParsingException(parser.getTokenLocation(), "rescore doesn't support [" + fieldName + "]"); } } else if (token == XContentParser.Token.START_OBJECT) { - rescorer = parser.namedObject(RescorerBuilder.class, fieldName, null); - rescorerNameConsumer.accept(fieldName); + if (fieldName != null) { + rescorer = parser.namedObject(RescorerBuilder.class, fieldName, null); + rescorerNameConsumer.accept(fieldName); + rescorerType = fieldName; + } } else { throw new ParsingException(parser.getTokenLocation(), "unexpected token [" + token + "] after [" + fieldName + "]"); } @@ -92,9 +97,13 @@ public static RescorerBuilder parseFromXContent(XContentParser parser, Consum if (rescorer == null) { throw new ParsingException(parser.getTokenLocation(), "missing rescore type"); } + if (windowSize != null) { rescorer.windowSize(windowSize.intValue()); + } else if (rescorer.isWindowSizeRequired()) { + throw new ParsingException(parser.getTokenLocation(), "window_size is required for rescorer of type [" + rescorerType + "]"); } + return rescorer; } @@ -111,11 +120,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws protected abstract void doXContent(XContentBuilder builder, Params params) throws IOException; + /** + * Indicate if the window_size is a required parameter for the rescorer. + */ + protected boolean isWindowSizeRequired() { + return false; + } + /** * Build the {@linkplain RescoreContext} that will be used to actually * execute the rescore against a particular shard. */ public final RescoreContext buildContext(SearchExecutionContext context) throws IOException { + if (isWindowSizeRequired()) { + assert windowSize != null; + } int finalWindowSize = windowSize == null ? DEFAULT_WINDOW_SIZE : windowSize; RescoreContext rescoreContext = innerBuildContext(finalWindowSize, context); return rescoreContext; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java index 068462bcdfca..4e3fa3addaf3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java @@ -55,6 +55,15 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r if (ltrRescoreContext.regressionModelDefinition == null) { throw new IllegalStateException("local model reference is null, missing rewriteAndFetch before rescore phase?"); } + + if (rescoreContext.getWindowSize() < topDocs.scoreDocs.length) { + throw new IllegalArgumentException( + "Rescore window is too small and should be at least the value of from + size but was [" + + rescoreContext.getWindowSize() + + "]" + ); + } + LocalModel definition = ltrRescoreContext.regressionModelDefinition; // First take top slice of incoming docs, to be rescored: diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java index 038f3fb08adb..058929e8379d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java @@ -32,10 +32,10 @@ public class LearningToRankRescorerBuilder extends RescorerBuilder { - public static final String NAME = "learning_to_rank"; - private static final ParseField MODEL_FIELD = new ParseField("model_id"); - private static final ParseField PARAMS_FIELD = new ParseField("params"); - private static final ObjectParser PARSER = new ObjectParser<>(NAME, false, Builder::new); + public static final ParseField NAME = new ParseField("learning_to_rank"); + public static final ParseField MODEL_FIELD = new ParseField("model_id"); + public static final ParseField PARAMS_FIELD = new ParseField("params"); + private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), false, Builder::new); static { PARSER.declareString(Builder::setModelId, MODEL_FIELD); @@ -251,7 +251,7 @@ protected LearningToRankRescorerContext innerBuildContext(int windowSize, Search @Override public String getWriteableName() { - return NAME; + return NAME.getPreferredName(); } @Override @@ -260,6 +260,11 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersion.current(); } + @Override + protected boolean isWindowSizeRequired() { + return true; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { assert localModel == null || rescoreOccurred : "Unnecessarily populated local model object"; @@ -270,7 +275,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { @Override protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(NAME); + builder.startObject(NAME.getPreferredName()); builder.field(MODEL_FIELD.getPreferredName(), modelId); if (this.params != null) { builder.field(PARAMS_FIELD.getPreferredName(), this.params); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderSerializationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderSerializationTests.java index 79044a465442..f52d05fc3220 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderSerializationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderSerializationTests.java @@ -9,14 +9,19 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Tuple; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig; @@ -25,48 +30,36 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; -import static org.elasticsearch.search.rank.RankBuilder.WINDOW_SIZE_FIELD; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfigTests.randomLearningToRankConfig; +import static org.hamcrest.Matchers.equalTo; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class LearningToRankRescorerBuilderSerializationTests extends AbstractBWCSerializationTestCase { private static LearningToRankService learningToRankService = mock(LearningToRankService.class); - @Override - protected LearningToRankRescorerBuilder doParseInstance(XContentParser parser) throws IOException { - String fieldName = null; - LearningToRankRescorerBuilder rescorer = null; - Integer windowSize = null; - XContentParser.Token token = parser.nextToken(); - assert token == XContentParser.Token.START_OBJECT; - while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { - if (token == XContentParser.Token.FIELD_NAME) { - fieldName = parser.currentName(); - } else if (token.isValue()) { - if (WINDOW_SIZE_FIELD.match(fieldName, parser.getDeprecationHandler())) { - windowSize = parser.intValue(); - } else { - throw new ParsingException(parser.getTokenLocation(), "rescore doesn't support [" + fieldName + "]"); + public void testRequiredWindowSize() throws IOException { + for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) { + LearningToRankRescorerBuilder testInstance = createTestInstance(); + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.startObject(); + testInstance.doXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + + try (XContentParser parser = JsonXContent.jsonXContent.createParser(parserConfig(), Strings.toString(builder))) { + ParsingException e = expectThrows(ParsingException.class, () -> RescorerBuilder.parseFromXContent(parser, (r) -> {})); + assertThat(e.getMessage(), equalTo("window_size is required for rescorer of type [learning_to_rank]")); } - } else if (token == XContentParser.Token.START_OBJECT) { - rescorer = LearningToRankRescorerBuilder.fromXContent(parser, learningToRankService); - } else { - throw new ParsingException(parser.getTokenLocation(), "unexpected token [" + token + "] after [" + fieldName + "]"); } } - if (rescorer == null) { - throw new ParsingException(parser.getTokenLocation(), "missing rescore type"); - } - if (windowSize != null) { - rescorer.windowSize(windowSize); - } - return rescorer; + } + + @Override + protected LearningToRankRescorerBuilder doParseInstance(XContentParser parser) throws IOException { + return (LearningToRankRescorerBuilder) RescorerBuilder.parseFromXContent(parser, (r) -> {}); } @Override @@ -85,76 +78,49 @@ protected LearningToRankRescorerBuilder createTestInstance() { learningToRankService ); - if (randomBoolean()) { - builder.windowSize(randomIntBetween(1, 10000)); - } + builder.windowSize(randomIntBetween(1, 10000)); return builder; } @Override protected LearningToRankRescorerBuilder createXContextTestInstance(XContentType xContentType) { - return new LearningToRankRescorerBuilder(randomAlphaOfLength(10), randomBoolean() ? randomParams() : null, learningToRankService); + return new LearningToRankRescorerBuilder(randomAlphaOfLength(10), randomBoolean() ? randomParams() : null, learningToRankService) + .windowSize(randomIntBetween(1, 10000)); } @Override protected LearningToRankRescorerBuilder mutateInstance(LearningToRankRescorerBuilder instance) throws IOException { - int i = randomInt(4); return switch (i) { - case 0 -> { - LearningToRankRescorerBuilder builder = new LearningToRankRescorerBuilder( - randomValueOtherThan(instance.modelId(), () -> randomAlphaOfLength(10)), - instance.params(), - learningToRankService - ); - if (instance.windowSize() != null) { - builder.windowSize(instance.windowSize()); - } - yield builder; - } + case 0 -> new LearningToRankRescorerBuilder( + randomValueOtherThan(instance.modelId(), () -> randomAlphaOfLength(10)), + instance.params(), + learningToRankService + ).windowSize(instance.windowSize()); case 1 -> new LearningToRankRescorerBuilder(instance.modelId(), instance.params(), learningToRankService).windowSize( randomValueOtherThan(instance.windowSize(), () -> randomIntBetween(1, 10000)) ); - case 2 -> { - LearningToRankRescorerBuilder builder = new LearningToRankRescorerBuilder( - instance.modelId(), - randomValueOtherThan(instance.params(), () -> (randomBoolean() ? randomParams() : null)), - learningToRankService - ); - if (instance.windowSize() != null) { - builder.windowSize(instance.windowSize() + 1); - } - yield builder; - } + case 2 -> new LearningToRankRescorerBuilder( + instance.modelId(), + randomValueOtherThan(instance.params(), () -> (randomBoolean() ? randomParams() : null)), + learningToRankService + ).windowSize(instance.windowSize()); case 3 -> { LearningToRankConfig learningToRankConfig = randomValueOtherThan( instance.learningToRankConfig(), () -> randomLearningToRankConfig() ); - LearningToRankRescorerBuilder builder = new LearningToRankRescorerBuilder( - instance.modelId(), - learningToRankConfig, - null, - learningToRankService + yield new LearningToRankRescorerBuilder(instance.modelId(), learningToRankConfig, null, learningToRankService).windowSize( + instance.windowSize() ); - if (instance.windowSize() != null) { - builder.windowSize(instance.windowSize()); - } - yield builder; - } - case 4 -> { - LearningToRankRescorerBuilder builder = new LearningToRankRescorerBuilder( - mock(LocalModel.class), - instance.learningToRankConfig(), - instance.params(), - learningToRankService - ); - if (instance.windowSize() != null) { - builder.windowSize(instance.windowSize()); - } - yield builder; } + case 4 -> new LearningToRankRescorerBuilder( + mock(LocalModel.class), + instance.learningToRankConfig(), + instance.params(), + learningToRankService + ).windowSize(instance.windowSize()); default -> throw new AssertionError("Unexpected random test case"); }; } @@ -169,31 +135,38 @@ protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers()); - namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedXContents()); + namedXContent.add( + new NamedXContentRegistry.Entry( + RescorerBuilder.class, + LearningToRankRescorerBuilder.NAME, + (p, c) -> LearningToRankRescorerBuilder.fromXContent(p, learningToRankService) + ) + ); return new NamedXContentRegistry(namedXContent); } + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return writableRegistry(); + } + @Override protected NamedWriteableRegistry writableRegistry() { List namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables()); namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables()); - namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedWriteables()); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + RescorerBuilder.class, + LearningToRankRescorerBuilder.NAME.getPreferredName(), + in -> new LearningToRankRescorerBuilder(in, learningToRankService) + ) + ); return new NamedWriteableRegistry(namedWriteables); } - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return writableRegistry(); - } - private static Map randomParams() { return randomMap(1, randomIntBetween(1, 10), () -> new Tuple<>(randomIdentifier(), randomIdentifier())); } - - private static LocalModel localModelMock() { - LocalModel model = mock(LocalModel.class); - String modelId = randomIdentifier(); - when(model.getModelId()).thenReturn(modelId); - return model; - } }