From 01e400e9509c3b0fb39bff684c05601acf253eaa Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Wed, 15 Nov 2023 19:59:38 +0100 Subject: [PATCH] Add parameters to the learn_to_rank rescorer. --- .../inference/MlLTRNamedXContentProvider.java | 3 +- .../trainedmodel/LearnToRankConfig.java | 1 + .../xpack/core/ml/utils/QueryProvider.java | 9 +- .../xpack/ml/integration/MlRescorerIT.java | 59 +-- ...orerIT.java => LearnToRankRescorerIT.java} | 96 ++-- .../xpack/ml/MachineLearning.java | 14 +- .../rescorer/InferenceRescorerBuilder.java | 392 ---------------- .../rescorer/InferenceRescorerFeature.java | 2 +- ...Rescorer.java => LearnToRankRescorer.java} | 10 +- .../rescorer/LearnToRankRescorerBuilder.java | 418 ++++++++++++++++++ ...t.java => LearnToRankRescorerContext.java} | 4 +- ...arnToRankRescorerBuilderRewriteTests.java} | 140 +++--- ...ankRescorerBuilderSerializationTests.java} | 99 +++-- 13 files changed, 653 insertions(+), 594 deletions(-) rename x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/{InferenceRescorerIT.java => LearnToRankRescorerIT.java} (80%) delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/{InferenceRescorer.java => LearnToRankRescorer.java} (94%) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilder.java rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/{InferenceRescorerContext.java => LearnToRankRescorerContext.java} (97%) rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/{InferenceRescorerBuilderRewriteTests.java => LearnToRankRescorerBuilderRewriteTests.java} (67%) rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/{InferenceRescorerBuilderSerializationTests.java => LearnToRankRescorerBuilderSerializationTests.java} (58%) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java index b878aad27a42e..6f78077e16288 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xcontent.NamedXContentRegistry; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; @@ -59,7 +58,7 @@ public List getNamedWriteables() { List namedWriteables = new ArrayList<>(); // Inference config namedWriteables.add( - new NamedWriteableRegistry.Entry(InferenceConfig.class, LearnToRankConfig.NAME.getPreferredName(), LearnToRankConfig::new) + new NamedWriteableRegistry.Entry(LearnToRankConfig.class, LearnToRankConfig.NAME.getPreferredName(), LearnToRankConfig::new) ); // LTR Extractors namedWriteables.add( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java index 25a2055e00f68..b6a1807151354 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java @@ -32,6 +32,7 @@ public class LearnToRankConfig extends RegressionConfig implements Rewriteable LENIENT_PARSER = createParser(true); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java index da50b1eb64b50..05e47fbc09e2a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java @@ -31,9 +31,9 @@ public class QueryProvider implements Writeable, ToXContentObject, Rewriteable query; + private final Exception parsingException; + private final QueryBuilder parsedQuery; + private final Map query; public static QueryProvider defaultQuery() { return new QueryProvider( @@ -77,7 +77,7 @@ public static QueryProvider fromStream(StreamInput in) throws IOException { return new QueryProvider(in.readMap(), in.readOptionalNamedWriteable(QueryBuilder.class), in.readException()); } - QueryProvider(Map query, QueryBuilder parsedQuery, Exception parsingException) { + public QueryProvider(Map query, QueryBuilder parsedQuery, Exception parsingException) { this.query = Collections.unmodifiableMap(new LinkedHashMap<>(Objects.requireNonNull(query, "[query] must not be null"))); this.parsedQuery = parsedQuery; this.parsingException = parsingException; @@ -136,7 +136,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public QueryProvider rewrite(QueryRewriteContext ctx) throws IOException { - assert parsedQuery != null; if (parsedQuery == null) { return this; } diff --git a/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java b/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java index f3ac67f338eee..7aa45c0ab534a 100644 --- a/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java +++ b/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java @@ -34,12 +34,20 @@ public void setupModelAndData() throws IOException { "input": { "field_names": ["cost", "product"] }, "inference_config": { "learn_to_rank": { - "feature_extractors": [{ - "query_extractor": { - "feature_name": "two", - "query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 2.0;" } } } + "feature_extractors": [ + { + "query_extractor": { + "feature_name": "two", + "query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 2.0;" } } } + } + }, + { + "query_extractor": { + "feature_name": "product_bm25", + "query": { "term": { "product": "{{keyword}}" } } + } } - }] + ] } }, "definition": { @@ -174,7 +182,8 @@ public void setupModelAndData() throws IOException { } } } - }"""); + } + """); createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", randomIntBetween(1, 3)).build(), """ "properties":{ "product":{ "type": "keyword" }, @@ -197,7 +206,7 @@ public void testLtrSimple() throws Exception { }, "rescore": { "window_size": 10, - "inference": { + "learn_to_rank": { "model_id": "basic-ltr-model" } } @@ -209,8 +218,6 @@ public void testLtrSimple() throws Exception { @SuppressWarnings("unchecked") public void testLtrSimpleDFS() throws Exception { - skipTestUntilParametersAreImplemented(); - Response searchResponse = searchDfs(""" { "query": { @@ -218,18 +225,11 @@ public void testLtrSimpleDFS() throws Exception { }, "rescore": { "window_size": 10, - "inference": { + "learn_to_rank": { "model_id": "basic-ltr-model", - "inference_config": { - "learn_to_rank": { - "feature_extractors":[ - { "query_extractor": { "feature_name": "product_bm25", "query": { "term": { "product": "TV" } } } } - ] - } - } + "params": { "keyword": "TV" } } } - }"""); Map response = responseAsMap(searchResponse); @@ -239,16 +239,9 @@ public void testLtrSimpleDFS() throws Exception { { "rescore": { "window_size": 10, - "inference": { + "learn_to_rank": { "model_id": "basic-ltr-model", - "inference_config": { - "learn_to_rank": { - "feature_extractors":[ - { "query_extractor": { "feature_name": "product_bm25", "query": { "term": { "product": "TV" } } } } - ] - } - } - } + "params": { "keyword": "TV" } } }"""); @@ -269,7 +262,7 @@ public void testLtrSimpleEmpty() throws Exception { }, "rescore": { "window_size": 10, - "inference": { + "learn_to_rank": { "model_id": "basic-ltr-model" } } @@ -288,7 +281,7 @@ public void testLtrEmptyDFS() throws Exception { }, "rescore": { "window_size": 10, - "inference": { + "learn_to_rank": { "model_id": "basic-ltr-model" } } @@ -307,7 +300,7 @@ public void testLtrCanMatch() throws Exception { }, "rescore": { "window_size": 10, - "inference": { + "learn_to_rank": { "model_id": "basic-ltr-model" } } @@ -323,7 +316,7 @@ public void testLtrCanMatch() throws Exception { }, "rescore": { "window_size": 10, - "inference": { + "learn_to_rank": { "model_id": "basic-ltr-model" } } @@ -365,8 +358,4 @@ private void putRegressionModel(String modelId, String body) throws IOException model.setJsonEntity(body); assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); } - - private void skipTestUntilParametersAreImplemented() { - // throw new AssumptionViolatedException("Skip the test until parameters are implemented"); - } } diff --git a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearnToRankRescorerIT.java similarity index 80% rename from x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java rename to x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearnToRankRescorerIT.java index 600e656f6cf48..d246f070f0b8d 100644 --- a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java +++ b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearnToRankRescorerIT.java @@ -18,7 +18,7 @@ import static org.hamcrest.Matchers.equalTo; -public class InferenceRescorerIT extends InferenceTestCase { +public class LearnToRankRescorerIT extends InferenceTestCase { private static final String MODEL_ID = "ltr-model"; private static final String INDEX_NAME = "store"; @@ -26,32 +26,38 @@ public class InferenceRescorerIT extends InferenceTestCase { @Before public void setupModelAndData() throws IOException { putRegressionModel(MODEL_ID, """ - { - "description": "super complex model for tests", - "input": {"field_names": ["cost", "product"]}, - "inference_config": { - "learn_to_rank": { - "feature_extractors": [ - { - "query_extractor": { - "feature_name": "two", - "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}} - } - } - ] - } - }, - "definition": { - "preprocessors" : [{ - "one_hot_encoding": { - "field": "product", - "hot_map": { - "TV": "type_tv", - "VCR": "type_vcr", - "Laptop": "type_laptop" - } - } - }], + { + "description": "super complex model for tests", + "input": {"field_names": ["cost", "product"]}, + "inference_config": { + "learn_to_rank": { + "feature_extractors": [ + { + "query_extractor": { + "feature_name": "two", + "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}} + } + }, + { + "query_extractor": { + "feature_name": "product_bm25", + "query": {"term": {"product": "{{keyword}}"}} + } + } + ] + } + }, + "definition": { + "preprocessors" : [{ + "one_hot_encoding": { + "field": "product", + "hot_map": { + "TV": "type_tv", + "VCR": "type_vcr", + "Laptop": "type_laptop" + } + } + }], "trained_model": { "ensemble": { "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop", "two", "product_bm25"], @@ -171,7 +177,8 @@ public void setupModelAndData() throws IOException { } } } - }"""); + } + """); createIndex(INDEX_NAME, Settings.EMPTY, """ "properties":{ "product":{"type": "keyword"}, @@ -189,16 +196,13 @@ public void setupModelAndData() throws IOException { adminClient().performRequest(new Request("POST", INDEX_NAME + "/_refresh")); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98372") - public void testInferenceRescore() throws Exception { - skipTestUntilParametersAreImplemented(); - + public void testLearnToRankRescore() throws Exception { Request request = new Request("GET", "store/_search?size=3&error_trace"); request.setJsonEntity(""" { "rescore": { "window_size": 10, - "inference": { "model_id": "ltr-model" } + "learn_to_rank": { "model_id": "ltr-model" } } }"""); assertHitScores(client().performRequest(request), List.of(20.0, 20.0, 17.0)); @@ -207,14 +211,10 @@ public void testInferenceRescore() throws Exception { "query": { "term": { "product": "Laptop" } }, "rescore": { "window_size": 10, - "inference": { + "learn_to_rank": { "model_id": "ltr-model", - "inference_config": { - "learn_to_rank": { - "feature_extractors":[ - { "query_extractor": { "feature_name": "product_bm25", "query": { "term": { "product": "Laptop"} } } } - ] - } + "params": { + "keyword": "Laptop" } } } @@ -225,27 +225,25 @@ public void testInferenceRescore() throws Exception { "query": {"term": { "product": "Laptop" } }, "rescore": { "window_size": 10, - "inference": { "model_id": "ltr-model"} + "learn_to_rank": { "model_id": "ltr-model"} } }"""); assertHitScores(client().performRequest(request), List.of(9.0, 9.0, 6.0)); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98372") - public void testInferenceRescoreSmallWindow() throws Exception { + public void testLearnToRankRescoreSmallWindow() throws Exception { Request request = new Request("GET", "store/_search?size=5"); request.setJsonEntity(""" { "rescore": { "window_size": 2, - "inference": { "model_id": "ltr-model" } + "learn_to_rank": { "model_id": "ltr-model" } } }"""); assertHitScores(client().performRequest(request), List.of(20.0, 20.0, 1.0, 1.0, 1.0)); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98372") - public void testInferenceRescorerWithChainedRescorers() throws IOException { + public void testLearnToRankRescorerWithChainedRescorers() throws IOException { Request request = new Request("GET", "store/_search?size=5"); request.setJsonEntity(""" { @@ -256,7 +254,7 @@ public void testInferenceRescorerWithChainedRescorers() throws IOException { }, { "window_size": 3, - "inference": { "model_id": "ltr-model" } + "learn_to_rank": { "model_id": "ltr-model" } }, { "window_size": 2, @@ -277,8 +275,4 @@ private void indexData(String data) throws IOException { private static void assertHitScores(Response response, List expectedScores) throws IOException { assertThat((List) XContentMapValues.extractValue("hits.hits._score", responseAsMap(response)), equalTo(expectedScores)); } - - private void skipTestUntilParametersAreImplemented() { - // throw new AssumptionViolatedException("Skip the test until parameters are implemented"); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index f4bce4906c0b0..487460f3710c6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -72,6 +72,7 @@ import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.script.ScriptService; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; @@ -327,8 +328,8 @@ import org.elasticsearch.xpack.ml.inference.pytorch.process.BlackHolePyTorchProcess; import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory; import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory; -import org.elasticsearch.xpack.ml.inference.rescorer.InferenceRescorerBuilder; import org.elasticsearch.xpack.ml.inference.rescorer.InferenceRescorerFeature; +import org.elasticsearch.xpack.ml.inference.rescorer.LearnToRankRescorerBuilder; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.NodeLoadDetector; @@ -731,7 +732,6 @@ public void loadExtensions(ExtensionLoader loader) { private final Settings settings; private final boolean enabled; - private final SetOnce autodetectProcessManager = new SetOnce<>(); private final SetOnce datafeedConfigProvider = new SetOnce<>(); private final SetOnce datafeedRunner = new SetOnce<>(); @@ -745,7 +745,7 @@ public void loadExtensions(ExtensionLoader loader) { private final SetOnce mlAutoscalingDeciderService = new SetOnce<>(); private final SetOnce deploymentManager = new SetOnce<>(); private final SetOnce trainedModelAllocationClusterServiceSetOnce = new SetOnce<>(); - + private final SetOnce scriptService = new SetOnce<>(); private final SetOnce machineLearningExtension = new SetOnce<>(); public MachineLearning(Settings settings) { @@ -864,12 +864,13 @@ private static void reportClashingNodeAttribute(String attrName) { @Override public List> getRescorers() { if (enabled && InferenceRescorerFeature.isEnabled()) { + // Inference rescorer requires access to the model loading service return List.of( new RescorerSpec<>( - InferenceRescorerBuilder.NAME, - in -> new InferenceRescorerBuilder(in, modelLoadingService::get), - parser -> InferenceRescorerBuilder.fromXContent(parser, modelLoadingService::get) + LearnToRankRescorerBuilder.NAME, + in -> new LearnToRankRescorerBuilder(in, modelLoadingService::get, scriptService::get), + parser -> LearnToRankRescorerBuilder.fromXContent(parser, modelLoadingService::get, scriptService::get) ) ); } @@ -893,6 +894,7 @@ public Collection createComponents(PluginServices services) { machineLearningExtension.get().configure(environment.settings()); + this.scriptService.set(services.scriptService()); this.mlUpgradeModeActionFilter.set(new MlUpgradeModeActionFilter(clusterService)); MlIndexTemplateRegistry registry = new MlIndexTemplateRegistry( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java deleted file mode 100644 index 5b110e2b42da7..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java +++ /dev/null @@ -1,392 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.inference.rescorer; - -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.Rewriteable; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.search.rescore.RescorerBuilder; -import org.elasticsearch.xcontent.ObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ClientHelper; -import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; -import org.elasticsearch.xpack.core.ml.job.messages.Messages; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; -import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; -import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; - -import java.io.IOException; -import java.util.Objects; -import java.util.Optional; -import java.util.function.Supplier; - -public class InferenceRescorerBuilder extends RescorerBuilder { - - public static final String NAME = "inference"; - private static final ParseField MODEL = new ParseField("model_id"); - private static final ParseField INTERNAL_INFERENCE_CONFIG = new ParseField("_internal_inference_config"); - private static final ObjectParser PARSER = new ObjectParser<>(NAME, false, Builder::new); - static { - PARSER.declareString(Builder::setModelId, MODEL); - PARSER.declareNamedObject( - Builder::setInferenceConfig, - (p, c, name) -> p.namedObject(StrictlyParsedInferenceConfig.class, name, false), - INTERNAL_INFERENCE_CONFIG - ); - } - - public static InferenceRescorerBuilder fromXContent(XContentParser parser, Supplier modelLoadingServiceSupplier) { - return PARSER.apply(parser, null).build(modelLoadingServiceSupplier); - } - - private final String modelId; - private final LearnToRankConfig inferenceConfig; - private final LocalModel inferenceDefinition; - private final Supplier inferenceDefinitionSupplier; - private final Supplier modelLoadingServiceSupplier; - private final Supplier inferenceConfigSupplier; - private boolean rescoreOccurred; - - public InferenceRescorerBuilder(String modelId, Supplier modelLoadingServiceSupplier) { - this.modelId = Objects.requireNonNull(modelId); - this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; - this.inferenceDefinition = null; - this.inferenceDefinitionSupplier = null; - this.inferenceConfigSupplier = null; - this.inferenceConfig = null; - } - - InferenceRescorerBuilder(String modelId, LearnToRankConfig inferenceConfig, Supplier modelLoadingServiceSupplier) { - this.modelId = Objects.requireNonNull(modelId); - this.inferenceDefinition = null; - this.inferenceDefinitionSupplier = null; - this.inferenceConfigSupplier = null; - this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; - this.inferenceConfig = Objects.requireNonNull(inferenceConfig); - } - - private InferenceRescorerBuilder( - String modelId, - Supplier modelLoadingServiceSupplier, - Supplier inferenceConfigSupplier - ) { - this.modelId = Objects.requireNonNull(modelId); - this.inferenceDefinition = null; - this.inferenceDefinitionSupplier = null; - this.inferenceConfigSupplier = inferenceConfigSupplier; - this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; - this.inferenceConfig = null; - } - - private InferenceRescorerBuilder( - String modelId, - LearnToRankConfig inferenceConfig, - Supplier modelLoadingServiceSupplier, - Supplier inferenceDefinitionSupplier - ) { - this.modelId = modelId; - this.inferenceDefinition = null; - this.inferenceDefinitionSupplier = inferenceDefinitionSupplier; - this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; - this.inferenceConfigSupplier = null; - this.inferenceConfig = inferenceConfig; - } - - InferenceRescorerBuilder(String modelId, LearnToRankConfig inferenceConfig, LocalModel inferenceDefinition) { - this.modelId = modelId; - this.inferenceDefinition = inferenceDefinition; - this.inferenceDefinitionSupplier = null; - this.modelLoadingServiceSupplier = null; - this.inferenceConfigSupplier = null; - this.inferenceConfig = inferenceConfig; - } - - public InferenceRescorerBuilder(StreamInput input, Supplier modelLoadingServiceSupplier) throws IOException { - super(input); - this.modelId = input.readString(); - this.inferenceDefinitionSupplier = null; - this.inferenceConfigSupplier = null; - this.inferenceDefinition = null; - this.inferenceConfig = (LearnToRankConfig) input.readOptionalNamedWriteable(InferenceConfig.class); - this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; - } - - @Override - public String getWriteableName() { - return NAME; - } - - /** - * should be updated once {@link InferenceRescorerFeature} is removed - */ - @Override - public TransportVersion getMinimalSupportedVersion() { - // TODO: update transport version when released! - return TransportVersion.current(); - } - - /** - * Here we fetch the stored model inference context, apply the given update, and rewrite. - * - * This can and be done on the coordinator as it not only validates if the stored model is of the appropriate type, it allows - * any stored logic to rewrite on the coordinator level if possible. - * @param ctx QueryRewriteContext - * @return rewritten InferenceRescorerBuilder or self if no changes - * @throws IOException when rewrite fails - */ - private RescorerBuilder doRewrite(QueryRewriteContext ctx) throws IOException { - // Awaiting fetch - if (inferenceConfigSupplier != null && inferenceConfigSupplier.get() == null) { - return this; - } - if (inferenceConfig != null) { - LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(inferenceConfig, ctx); - if (rewrittenConfig == inferenceConfig) { - return this; - } - InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, rewrittenConfig, modelLoadingServiceSupplier); - if (windowSize != null) { - builder.windowSize(windowSize); - } - return builder; - } - // We have requested for the stored config and fetch is completed, get the config and rewrite further if required - if (inferenceConfigSupplier != null) { - LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(inferenceConfigSupplier.get(), ctx); - InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, rewrittenConfig, modelLoadingServiceSupplier); - if (windowSize != null) { - builder.windowSize(windowSize); - } - return builder; - } - SetOnce configSetOnce = new SetOnce<>(); - GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId); - request.setAllowNoResources(false); - ctx.registerAsyncAction( - (c, l) -> ClientHelper.executeAsyncWithOrigin( - c, - ClientHelper.ML_ORIGIN, - GetTrainedModelsAction.INSTANCE, - request, - ActionListener.wrap(trainedModels -> { - TrainedModelConfig config = trainedModels.getResources().results().get(0); - if (config.getInferenceConfig() instanceof LearnToRankConfig retrievedInferenceConfig) { - // TODO: apply params instead of an override. - for (LearnToRankFeatureExtractorBuilder builder : retrievedInferenceConfig.getFeatureExtractorBuilders()) { - builder.validate(); - } - configSetOnce.set(retrievedInferenceConfig); - l.onResponse(null); - return; - } - l.onFailure( - ExceptionsHelper.badRequestException( - Messages.getMessage( - Messages.INFERENCE_CONFIG_INCORRECT_TYPE, - Optional.ofNullable(config.getInferenceConfig()).map(InferenceConfig::getName).orElse("null"), - LearnToRankConfig.NAME.getPreferredName() - ) - ) - ); - }, l::onFailure) - ) - ); - InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, modelLoadingServiceSupplier, configSetOnce::get); - if (windowSize() != null) { - builder.windowSize(windowSize); - } - return builder; - } - - /** - * This rewrite phase occurs on the data node when we know we will want to use the model for inference - * @param ctx Rewrite context - * @return A rewritten rescorer with a model definition or a model definition supplier populated - */ - private RescorerBuilder doDataNodeRewrite(QueryRewriteContext ctx) { - assert inferenceConfig != null; - // We already have an inference definition, no need to do any rewriting - if (inferenceDefinition != null) { - return this; - } - // Awaiting fetch - if (inferenceDefinitionSupplier != null && inferenceDefinitionSupplier.get() == null) { - return this; - } - if (inferenceDefinitionSupplier != null) { - LocalModel inferenceDefinition = inferenceDefinitionSupplier.get(); - InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, inferenceConfig, inferenceDefinition); - if (windowSize() != null) { - builder.windowSize(windowSize()); - } - return builder; - } - if (modelLoadingServiceSupplier == null || modelLoadingServiceSupplier.get() == null) { - throw new IllegalStateException("Model loading service must be available"); - } - SetOnce inferenceDefinitionSetOnce = new SetOnce<>(); - ctx.registerAsyncAction((c, l) -> modelLoadingServiceSupplier.get().getModelForLearnToRank(modelId, ActionListener.wrap(lm -> { - inferenceDefinitionSetOnce.set(lm); - l.onResponse(null); - }, l::onFailure))); - InferenceRescorerBuilder builder = new InferenceRescorerBuilder( - modelId, - inferenceConfig, - modelLoadingServiceSupplier, - inferenceDefinitionSetOnce::get - ); - if (windowSize() != null) { - builder.windowSize(windowSize()); - } - return builder; - } - - /** - * This rewrite phase occurs on the data node when we know we will want to use the model for inference - * @param ctx Rewrite context - * @return A rewritten rescorer with a model definition or a model definition supplier populated - * @throws IOException If fetching, parsing, or overall rewrite failures occur - */ - private RescorerBuilder doSearchRewrite(QueryRewriteContext ctx) throws IOException { - if (inferenceConfig == null) { - return this; - } - LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(inferenceConfig, ctx); - if (rewrittenConfig == inferenceConfig) { - return this; - } - InferenceRescorerBuilder builder = inferenceDefinition == null - ? new InferenceRescorerBuilder(modelId, rewrittenConfig, modelLoadingServiceSupplier) - : new InferenceRescorerBuilder(modelId, rewrittenConfig, inferenceDefinition); - if (windowSize != null) { - builder.windowSize(windowSize); - } - return builder; - } - - @Override - public RescorerBuilder rewrite(QueryRewriteContext ctx) throws IOException { - if (ctx.convertToDataRewriteContext() != null) { - return doDataNodeRewrite(ctx); - } - if (ctx.convertToSearchExecutionContext() != null) { - return doSearchRewrite(ctx); - } - return doRewrite(ctx); - } - - public String getModelId() { - return modelId; - } - - LearnToRankConfig getInferenceConfig() { - return inferenceConfig; - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - if (inferenceDefinitionSupplier != null || inferenceConfigSupplier != null) { - throw new IllegalStateException("suppliers must be null, missing a rewriteAndFetch?"); - } - assert inferenceDefinition == null || rescoreOccurred : "Unnecessarily populated local model object"; - out.writeString(modelId); - out.writeOptionalNamedWriteable(inferenceConfig); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(NAME); - builder.field(MODEL.getPreferredName(), modelId); - if (inferenceConfig != null) { - NamedXContentObjectHelper.writeNamedObject(builder, params, INTERNAL_INFERENCE_CONFIG.getPreferredName(), inferenceConfig); - } - builder.endObject(); - } - - @Override - protected InferenceRescorerContext innerBuildContext(int windowSize, SearchExecutionContext context) { - rescoreOccurred = true; - return new InferenceRescorerContext(windowSize, InferenceRescorer.INSTANCE, inferenceConfig, inferenceDefinition, context); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - if (super.equals(o) == false) return false; - InferenceRescorerBuilder that = (InferenceRescorerBuilder) o; - return Objects.equals(modelId, that.modelId) - && Objects.equals(inferenceDefinition, that.inferenceDefinition) - && Objects.equals(inferenceConfig, that.inferenceConfig) - && Objects.equals(inferenceDefinitionSupplier, that.inferenceDefinitionSupplier) - && Objects.equals(modelLoadingServiceSupplier, that.modelLoadingServiceSupplier); - } - - @Override - public int hashCode() { - return Objects.hash( - super.hashCode(), - modelId, - inferenceConfig, - inferenceDefinition, - inferenceDefinitionSupplier, - modelLoadingServiceSupplier - ); - } - - // Used in tests - Supplier modelLoadingServiceSupplier() { - return modelLoadingServiceSupplier; - } - - // Used in tests - LocalModel getInferenceDefinition() { - return inferenceDefinition; - } - - static class Builder { - private String modelId; - private LearnToRankConfig inferenceConfig; - - public void setModelId(String modelId) { - this.modelId = modelId; - } - - void setInferenceConfig(InferenceConfig inferenceConfig) { - if (inferenceConfig instanceof LearnToRankConfig learnToRankConfig) { - this.inferenceConfig = learnToRankConfig; - return; - } - throw new IllegalArgumentException( - Strings.format( - "[%s] only allows a [%s] object to be configured", - INTERNAL_INFERENCE_CONFIG.getPreferredName(), - LearnToRankConfig.NAME.getPreferredName() - ) - ); - } - - InferenceRescorerBuilder build(Supplier modelLoadingServiceSupplier) { - return new InferenceRescorerBuilder(modelId, inferenceConfig, modelLoadingServiceSupplier); - } - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerFeature.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerFeature.java index 2b88faa3e4c14..4d2d507b799b6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerFeature.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerFeature.java @@ -14,7 +14,7 @@ * * Upon removal, ensure transport serialization is all corrected for future BWC. * - * See {@link InferenceRescorerBuilder} + * See {@link LearnToRankRescorerBuilder} */ public class InferenceRescorerFeature { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorer.java similarity index 94% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorer.java index df3e0756ea39a..ab1dd5e8873ec 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorer.java @@ -32,17 +32,17 @@ import static java.util.stream.Collectors.toUnmodifiableSet; -public class InferenceRescorer implements Rescorer { +public class LearnToRankRescorer implements Rescorer { - public static final InferenceRescorer INSTANCE = new InferenceRescorer(); - private static final Logger logger = LogManager.getLogger(InferenceRescorer.class); + public static final LearnToRankRescorer INSTANCE = new LearnToRankRescorer(); + private static final Logger logger = LogManager.getLogger(LearnToRankRescorer.class); private static final Comparator SCORE_DOC_COMPARATOR = (o1, o2) -> { int cmp = Float.compare(o2.score, o1.score); return cmp == 0 ? Integer.compare(o1.doc, o2.doc) : cmp; }; - private InferenceRescorer() { + private LearnToRankRescorer() { } @@ -51,7 +51,7 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r if (topDocs.scoreDocs.length == 0) { return topDocs; } - InferenceRescorerContext ltrRescoreContext = (InferenceRescorerContext) rescoreContext; + LearnToRankRescorerContext ltrRescoreContext = (LearnToRankRescorerContext) rescoreContext; if (ltrRescoreContext.inferenceDefinition == null) { throw new IllegalStateException("local model reference is null, missing rewriteAndFetch before rescore phase?"); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilder.java new file mode 100644 index 0000000000000..494af0310953a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilder.java @@ -0,0 +1,418 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.rescorer; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.Rewriteable; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptService; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.script.TemplateScript; +import org.elasticsearch.search.rescore.RescorerBuilder; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.function.Supplier; + +import static org.elasticsearch.script.Script.DEFAULT_TEMPLATE_LANG; + +public class LearnToRankRescorerBuilder extends RescorerBuilder { + + public static final String NAME = "learn_to_rank"; + private static final ParseField MODEL_FIELD = new ParseField("model_id"); + private static final ParseField PARAMS_FIELD = new ParseField("params"); + private static final ObjectParser PARSER = new ObjectParser<>(NAME, false, Builder::new); + + static { + PARSER.declareString(Builder::setModelId, MODEL_FIELD); + PARSER.declareObject(Builder::setParams, (p, c) -> p.map(), PARAMS_FIELD); + } + + public static LearnToRankRescorerBuilder fromXContent( + XContentParser parser, + Supplier modelLoadingServiceSupplier, + Supplier scriptServiceSupplier + ) { + return PARSER.apply(parser, null).build(modelLoadingServiceSupplier, scriptServiceSupplier); + } + + private final String modelId; + private final Map params; + private final Supplier scriptServiceSupplier; + private final Supplier modelLoadingServiceSupplier; + private final Supplier localModelSupplier; + private final Supplier learnToRankConfigSupplier; + private boolean rescoreOccurred = false; + + LearnToRankRescorerBuilder( + String modelId, + Map params, + Supplier modelLoadingServiceSupplier, + Supplier scriptServiceSupplier + ) { + this.modelId = modelId; + this.params = params; + this.scriptServiceSupplier = scriptServiceSupplier; + this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; + + // Config and model will be set during successive rewrite phases. + this.learnToRankConfigSupplier = null; + this.localModelSupplier = null; + } + + LearnToRankRescorerBuilder( + String modelId, + Supplier modelLoadingServiceSupplier, + Supplier learnToRankConfigSupplier + ) { + this.modelId = modelId; + this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; + this.learnToRankConfigSupplier = learnToRankConfigSupplier; + + // Local inference model is not loaded yet. Will be done in a later rewrite. + this.localModelSupplier = null; + + // Templates has been applied already, so we do not need params and script service anymore. + this.params = null; + this.scriptServiceSupplier = null; + } + + LearnToRankRescorerBuilder(Supplier learnToRankConfigSupplier, Supplier localModelSupplier) { + this.modelId = localModelSupplier.get() != null ? localModelSupplier.get().getModelId() : null; + this.learnToRankConfigSupplier = learnToRankConfigSupplier; + this.localModelSupplier = localModelSupplier; + + // Model is loaded already, so we do not need the model loading service anymore. + this.modelLoadingServiceSupplier = null; + + // Template has been applied already, so we do not need params and script service anymore. + this.params = null; + this.scriptServiceSupplier = null; + } + + public LearnToRankRescorerBuilder( + StreamInput input, + Supplier modelLoadingServiceSupplier, + Supplier scriptServiceSupplier + ) throws IOException { + super(input); + this.modelId = input.readString(); + this.params = input.readMap(); + + LearnToRankConfig learnToRankConfig = input.readOptionalNamedWriteable(LearnToRankConfig.class); + this.learnToRankConfigSupplier = learnToRankConfig != null ? () -> learnToRankConfig : null; + + this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; + this.scriptServiceSupplier = scriptServiceSupplier; + + this.localModelSupplier = null; + } + + public String modelId() { + return modelId; + } + + public Map params() { + return params; + } + + public Supplier learnToRankConfigSupplier() { + return learnToRankConfigSupplier; + } + + public Supplier modelLoadingServiceSupplier() { + return modelLoadingServiceSupplier; + } + + public Supplier localModelSupplier() { + return localModelSupplier; + } + + @Override + public RescorerBuilder rewrite(QueryRewriteContext ctx) throws IOException { + if (ctx.convertToDataRewriteContext() != null) { + return doDataNodeRewrite(ctx); + } + if (ctx.convertToSearchExecutionContext() != null) { + return doSearchRewrite(ctx); + } + return doCoordinatorNodeRewrite(ctx); + } + + /** + * Here we fetch the stored model inference context, apply the given update, and rewrite. + * + * This can and be done on the coordinator as it not only validates if the stored model is of the appropriate type, it allows + * any stored logic to rewrite on the coordinator level if possible. + * @param ctx QueryRewriteContext + * @return rewritten LearnToRankRescorerBuilder or self if no changes + * @throws IOException when rewrite fails + */ + private RescorerBuilder doCoordinatorNodeRewrite(QueryRewriteContext ctx) throws IOException { + if (learnToRankConfigSupplier != null && learnToRankConfigSupplier.get() == null) { + // Awaiting to fetch the model. + return this; + } + + // We have requested for the stored config and fetch is completed, get the config and rewrite further if required + if (learnToRankConfigSupplier != null) { + LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfigSupplier.get(), ctx); + if (rewrittenConfig == learnToRankConfigSupplier.get()) { + return this; + } + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( + modelId, + modelLoadingServiceSupplier, + () -> rewrittenConfig + ); + if (windowSize != null) { + builder.windowSize(windowSize); + } + return builder; + } + + SetOnce configSetOnce = new SetOnce<>(); + GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId); + request.setAllowNoResources(false); + ctx.registerAsyncAction( + (c, l) -> ClientHelper.executeAsyncWithOrigin( + c, + ClientHelper.ML_ORIGIN, + GetTrainedModelsAction.INSTANCE, + request, + ActionListener.wrap(trainedModels -> { + TrainedModelConfig config = trainedModels.getResources().results().get(0); + if (config.getInferenceConfig() instanceof LearnToRankConfig retrievedInferenceConfig) { + for (LearnToRankFeatureExtractorBuilder builder : retrievedInferenceConfig.getFeatureExtractorBuilders()) { + builder.validate(); + } + configSetOnce.set(applyParams(retrievedInferenceConfig, ctx)); + l.onResponse(null); + return; + } + l.onFailure( + ExceptionsHelper.badRequestException( + Messages.getMessage( + Messages.INFERENCE_CONFIG_INCORRECT_TYPE, + Optional.ofNullable(config.getInferenceConfig()).map(InferenceConfig::getName).orElse("null"), + LearnToRankConfig.NAME.getPreferredName() + ) + ) + ); + }, l::onFailure) + ) + ); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingServiceSupplier, configSetOnce::get); + if (windowSize() != null) { + builder.windowSize(windowSize); + } + return builder; + } + + private LearnToRankConfig applyParams(LearnToRankConfig config, QueryRewriteContext ctx) throws IOException { + if (scriptServiceSupplier.get().isLangSupported(DEFAULT_TEMPLATE_LANG) == false) { + return config; + } + + if (params == null || params.isEmpty()) { + return config; + } + + try (XContentBuilder configSourceBuilder = XContentBuilder.builder(XContentType.JSON.xContent())) { + String templateSource = BytesReference.bytes(config.toXContent(configSourceBuilder, EMPTY_PARAMS)).utf8ToString(); + if (templateSource.contains("{{") == false) { + return config; + } + Script script = new Script(ScriptType.INLINE, DEFAULT_TEMPLATE_LANG, templateSource, Collections.emptyMap()); + String parsedTemplate = scriptServiceSupplier.get().compile(script, TemplateScript.CONTEXT).newInstance(params).execute(); + + XContentParser parser = XContentType.JSON.xContent().createParser(ctx.getParserConfig(), parsedTemplate); + + return LearnToRankConfig.fromXContentStrict(parser); + } + } + + /** + * This rewrite phase occurs on the data node when we know we will want to use the model for inference + * @param ctx Rewrite context + * @return A rewritten rescorer with a model definition or a model definition supplier populated + */ + private RescorerBuilder doDataNodeRewrite(QueryRewriteContext ctx) throws IOException { + assert learnToRankConfigSupplier.get() != null; + + // The model supplier is already created, no need to rewrite further. + if (localModelSupplier != null) { + return this; + } + + if (modelLoadingServiceSupplier == null || modelLoadingServiceSupplier.get() == null) { + throw new IllegalStateException("Model loading service must be available"); + } + LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfigSupplier.get(), ctx); + SetOnce inferenceDefinitionSetOnce = new SetOnce<>(); + ctx.registerAsyncAction((c, l) -> modelLoadingServiceSupplier.get().getModelForLearnToRank(modelId, ActionListener.wrap(lm -> { + inferenceDefinitionSetOnce.set(lm); + l.onResponse(null); + }, l::onFailure))); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(() -> rewrittenConfig, inferenceDefinitionSetOnce::get); + if (windowSize() != null) { + builder.windowSize(windowSize()); + } + return builder; + } + + /** + * This rewrite phase occurs on the data node when we know we will want to use the model for inference + * @param ctx Rewrite context + * @return A rewritten rescorer with a model definition or a model definition supplier populated + * @throws IOException If fetching, parsing, or overall rewrite failures occur + */ + private RescorerBuilder doSearchRewrite(QueryRewriteContext ctx) throws IOException { + if (learnToRankConfigSupplier == null || learnToRankConfigSupplier.get() == null) { + return this; + } + LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfigSupplier.get(), ctx); + if (rewrittenConfig == learnToRankConfigSupplier.get()) { + return this; + } + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(() -> rewrittenConfig, localModelSupplier); + if (windowSize != null) { + builder.windowSize(windowSize); + } + return builder; + } + + @Override + protected LearnToRankRescorerContext innerBuildContext(int windowSize, SearchExecutionContext context) { + rescoreOccurred = true; + LearnToRankConfig learnToRankConfig = learnToRankConfigSupplier != null ? learnToRankConfigSupplier.get() : null; + LocalModel inferenceDefinition = localModelSupplier != null ? localModelSupplier.get() : null; + return new LearnToRankRescorerContext(windowSize, LearnToRankRescorer.INSTANCE, learnToRankConfig, inferenceDefinition, context); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + // TODO: update transport version when released! + return TransportVersion.current(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + assert localModelSupplier == null || rescoreOccurred : "Unnecessarily populated local model object"; + out.writeString(modelId); + out.writeGenericMap(params); + out.writeOptionalNamedWriteable(learnToRankConfigSupplier != null ? learnToRankConfigSupplier.get() : null); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.field(MODEL_FIELD.getPreferredName(), modelId); + if (params != null) { + builder.field(PARAMS_FIELD.getPreferredName(), params); + } + builder.endObject(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + LearnToRankRescorerBuilder that = (LearnToRankRescorerBuilder) o; + + if (learnToRankConfigSupplier != null + && (that.learnToRankConfigSupplier == null + || Objects.equals(learnToRankConfigSupplier.get(), that.learnToRankConfigSupplier.get()) == false)) { + return false; + } + + if (localModelSupplier != null + && (that.localModelSupplier == null || Objects.equals(localModelSupplier.get(), that.localModelSupplier.get()) == false)) { + return false; + } + + return Objects.equals(modelId, that.modelId) + && Objects.equals(params, that.params) + && Objects.equals(modelLoadingServiceSupplier, that.modelLoadingServiceSupplier) + && Objects.equals(scriptServiceSupplier, that.scriptServiceSupplier) + && rescoreOccurred == that.rescoreOccurred; + } + + boolean areSuppliersEquals(Supplier a, Supplier b) { + if (a != null && b != null) { + return Objects.equals(a.get(), b.get()); + } + + return a == b; + } + + @Override + public int hashCode() { + return Objects.hash( + super.hashCode(), + modelId, + params, + modelLoadingServiceSupplier != null ? modelLoadingServiceSupplier.get() : null, + scriptServiceSupplier != null ? scriptServiceSupplier.get() : null, + learnToRankConfigSupplier != null ? learnToRankConfigSupplier.get() : null, + localModelSupplier != null ? localModelSupplier.get() : null, + rescoreOccurred + ); + } + + static class Builder { + private String modelId; + private Map params = Collections.emptyMap(); + + public void setModelId(String modelId) { + this.modelId = modelId; + } + + public void setParams(Map params) { + this.params = params; + } + + LearnToRankRescorerBuilder build( + Supplier modelLoadingServiceSupplier, + Supplier scriptServiceSupplier + ) { + return new LearnToRankRescorerBuilder(modelId, params, modelLoadingServiceSupplier, scriptServiceSupplier); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerContext.java similarity index 97% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerContext.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerContext.java index fed3effdc06f6..3409cf1a44a86 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerContext.java @@ -24,7 +24,7 @@ import java.util.ArrayList; import java.util.List; -public class InferenceRescorerContext extends RescoreContext { +public class LearnToRankRescorerContext extends RescoreContext { final SearchExecutionContext executionContext; final LocalModel inferenceDefinition; @@ -37,7 +37,7 @@ public class InferenceRescorerContext extends RescoreContext { * @param inferenceDefinition The local model inference definition, may be null during certain search phases. * @param executionContext The local shard search context */ - public InferenceRescorerContext( + public LearnToRankRescorerContext( int windowSize, Rescorer rescorer, LearnToRankConfig inferenceConfig, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderRewriteTests.java similarity index 67% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderRewriteTests.java index aec79919c1a50..a8b75353ba1e1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderRewriteTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.test.AbstractBuilderTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -40,7 +41,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; @@ -49,6 +52,7 @@ import java.io.IOException; import java.lang.reflect.Method; +import java.util.Collections; import java.util.List; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -57,12 +61,12 @@ import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCase { +public class LearnToRankRescorerBuilderRewriteTests extends AbstractBuilderTestCase { private static final String GOOD_MODEL = "modelId"; private static final String BAD_MODEL = "badModel"; @@ -72,7 +76,15 @@ public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCas .setEstimatedOperations(1) .setModelSize(2) .setModelType(TrainedModelType.TREE_ENSEMBLE) - .setInferenceConfig(new LearnToRankConfig(null, null)) + .setInferenceConfig( + new LearnToRankConfig( + 2, + List.of( + new QueryExtractorBuilder("feature_1", new QueryProvider(Collections.emptyMap(), null, null)), + new QueryExtractorBuilder("feature_2", new QueryProvider(Collections.emptyMap(), null, null)) + ) + ) + ) .build(); private static final TrainedModelConfig BAD_MODEL_CONFIG = TrainedModelConfig.builder() .setModelId(BAD_MODEL) @@ -85,20 +97,20 @@ public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCas public void testMustRewrite() { TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder( + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( GOOD_MODEL, - LearnToRankConfigTests.randomLearnToRankConfig(), - () -> testModelLoader + () -> testModelLoader, + () -> LearnToRankConfigTests.randomLearnToRankConfig() ); SearchExecutionContext context = createSearchExecutionContext(); - InferenceRescorerContext inferenceRescorerContext = inferenceRescorerBuilder.innerBuildContext(randomIntBetween(1, 30), context); + LearnToRankRescorerContext rescorerContext = rescorerBuilder.innerBuildContext(randomIntBetween(1, 30), context); IllegalStateException e = expectThrows( IllegalStateException.class, - () -> inferenceRescorerContext.rescorer() + () -> rescorerContext.rescorer() .rescore( new TopDocs(new TotalHits(10, TotalHits.Relation.EQUAL_TO), new ScoreDoc[10]), mock(IndexSearcher.class), - inferenceRescorerContext + rescorerContext ) ); assertEquals("local model reference is null, missing rewriteAndFetch before rescore phase?", e.getMessage()); @@ -106,21 +118,28 @@ public void testMustRewrite() { public void testRewriteOnCoordinator() throws IOException { TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(GOOD_MODEL, () -> testModelLoader); - inferenceRescorerBuilder.windowSize(4); + ScriptService scriptService = mock(ScriptService.class); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( + GOOD_MODEL, + null, + () -> testModelLoader, + () -> scriptService + ); + rescorerBuilder.windowSize(4); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), randomIntBetween(0, 1_100_000), randomIntBetween(1_500_000, Integer.MAX_VALUE) ); - InferenceRescorerBuilder rewritten = rewriteAndFetch(inferenceRescorerBuilder, context); - assertThat(rewritten.getInferenceConfig(), not(nullValue())); - assertThat(rewritten.getInferenceConfig().getNumTopFeatureImportanceValues(), equalTo(2)); + LearnToRankRescorerBuilder rewritten = rewriteAndFetch(rescorerBuilder, context); + assertThat(rewritten.learnToRankConfigSupplier().get(), not(nullValue())); + assertThat(rewritten.learnToRankConfigSupplier().get().getNumTopFeatureImportanceValues(), equalTo(2)); assertThat( - "all", + "feature_1", is( in( - rewritten.getInferenceConfig() + rewritten.learnToRankConfigSupplier() + .get() .getFeatureExtractorBuilders() .stream() .map(LearnToRankFeatureExtractorBuilder::featureName) @@ -133,52 +152,59 @@ public void testRewriteOnCoordinator() throws IOException { public void testRewriteOnCoordinatorWithBadModel() throws IOException { TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(BAD_MODEL, () -> testModelLoader); + ScriptService scriptService = mock(ScriptService.class); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( + BAD_MODEL, + null, + () -> testModelLoader, + () -> scriptService + ); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), randomIntBetween(0, 1_100_000), randomIntBetween(1_500_000, Integer.MAX_VALUE) ); - ElasticsearchStatusException ex = expectThrows( - ElasticsearchStatusException.class, - () -> rewriteAndFetch(inferenceRescorerBuilder, context) - ); + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, () -> rewriteAndFetch(rescorerBuilder, context)); assertThat(ex.status(), equalTo(RestStatus.BAD_REQUEST)); } public void testRewriteOnCoordinatorWithMissingModel() { TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder("missing_model", () -> testModelLoader); + ScriptService scriptService = mock(ScriptService.class); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( + "missing_model", + null, + () -> testModelLoader, + () -> scriptService + ); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), randomIntBetween(0, 1_100_000), randomIntBetween(1_500_000, Integer.MAX_VALUE) ); - expectThrows(ResourceNotFoundException.class, () -> rewriteAndFetch(inferenceRescorerBuilder, context)); + expectThrows(ResourceNotFoundException.class, () -> rewriteAndFetch(rescorerBuilder, context)); } public void testSearchRewrite() throws IOException { - TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder( - GOOD_MODEL, - LearnToRankConfigTests.randomLearnToRankConfig(), - () -> testModelLoader + LocalModel localModel = localModel(); + when(localModel.getModelId()).thenReturn(GOOD_MODEL); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( + () -> LearnToRankConfigTests.randomLearnToRankConfig(), + () -> localModel ); - QueryRewriteContext context = createSearchExecutionContext(); - InferenceRescorerBuilder rewritten = (InferenceRescorerBuilder) Rewriteable.rewrite(inferenceRescorerBuilder, context, true); - assertThat(rewritten.modelLoadingServiceSupplier(), is(notNullValue())); - inferenceRescorerBuilder = new InferenceRescorerBuilder(GOOD_MODEL, LearnToRankConfigTests.randomLearnToRankConfig(), localModel()); + QueryRewriteContext context = createSearchExecutionContext(); + LearnToRankRescorerBuilder rewritten = (LearnToRankRescorerBuilder) Rewriteable.rewrite(rescorerBuilder, context, true); - rewritten = (InferenceRescorerBuilder) Rewriteable.rewrite(inferenceRescorerBuilder, context, true); - assertThat(rewritten.modelLoadingServiceSupplier(), is(nullValue())); - assertThat(rewritten.getInferenceDefinition(), is(notNullValue())); + LearnToRankConfig rewrittenLearnToRankConfig = Rewriteable.rewrite(rewritten.learnToRankConfigSupplier().get(), context); + assertThat(rewritten.localModelSupplier().get(), is(localModel)); + assertThat(rewritten.learnToRankConfigSupplier().get(), is(rewrittenLearnToRankConfig)); } - protected InferenceRescorerBuilder rewriteAndFetch(RescorerBuilder builder, QueryRewriteContext context) { - PlainActionFuture> future = new PlainActionFuture<>(); + protected LearnToRankRescorerBuilder rewriteAndFetch(RescorerBuilder builder, QueryRewriteContext context) { + PlainActionFuture> future = new PlainActionFuture<>(); Rewriteable.rewriteAndFetch(builder, context, future); - return (InferenceRescorerBuilder) future.actionGet(); + return (LearnToRankRescorerBuilder) future.actionGet(); } @Override @@ -206,31 +232,31 @@ protected Object simulateMethod(Method method, Object[] args) { public void testRewriteOnShard() throws IOException { TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder( + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( GOOD_MODEL, - (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(), - () -> testModelLoader + () -> testModelLoader, + () -> (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig() ); SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); - InferenceRescorerBuilder rewritten = (InferenceRescorerBuilder) inferenceRescorerBuilder.rewrite(createSearchExecutionContext()); - assertSame(inferenceRescorerBuilder, rewritten); + LearnToRankRescorerBuilder rewritten = (LearnToRankRescorerBuilder) rescorerBuilder.rewrite(createSearchExecutionContext()); + assertSame(rescorerBuilder, rewritten); assertFalse(searchExecutionContext.hasAsyncActions()); } public void testRewriteAndFetchOnDataNode() throws IOException { TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder( + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( GOOD_MODEL, - (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(), - () -> testModelLoader + () -> testModelLoader, + () -> LearnToRankConfigTests.randomLearnToRankConfig() ); boolean setWindowSize = randomBoolean(); if (setWindowSize) { - inferenceRescorerBuilder.windowSize(42); + rescorerBuilder.windowSize(42); } DataRewriteContext rewriteContext = dataRewriteContext(); - InferenceRescorerBuilder rewritten = (InferenceRescorerBuilder) inferenceRescorerBuilder.rewrite(rewriteContext); - assertNotSame(inferenceRescorerBuilder, rewritten); + LearnToRankRescorerBuilder rewritten = (LearnToRankRescorerBuilder) rescorerBuilder.rewrite(rewriteContext); + assertNotSame(rescorerBuilder, rewritten); assertTrue(rewriteContext.hasAsyncActions()); if (setWindowSize) { assertThat(rewritten.windowSize(), equalTo(42)); @@ -242,12 +268,12 @@ public void testBuildContext() throws Exception { List inputFields = List.of(DOUBLE_FIELD_NAME, INT_FIELD_NAME); when(localModel.inputFields()).thenReturn(inputFields); SearchExecutionContext context = createSearchExecutionContext(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder( - GOOD_MODEL, - (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(), - localModel + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( + () -> LearnToRankConfigTests.randomLearnToRankConfig(), + () -> localModel ); - InferenceRescorerContext rescoreContext = inferenceRescorerBuilder.innerBuildContext(20, context); + + LearnToRankRescorerContext rescoreContext = rescorerBuilder.innerBuildContext(20, context); assertNotNull(rescoreContext); assertThat(rescoreContext.getWindowSize(), equalTo(20)); List featureExtractors = rescoreContext.buildFeatureExtractors(context.searcher()); @@ -262,6 +288,12 @@ private static LocalModel localModel() { return mock(LocalModel.class); } + private static QueryExtractorBuilder queryExtractorBuilder(String featureName) throws IOException { + QueryProvider queryProvider = mock(QueryProvider.class); + when(queryProvider.rewrite(any())).thenReturn(queryProvider); + return new QueryExtractorBuilder(featureName, queryProvider); + } + private static class TestModelLoader extends ModelLoadingService { TestModelLoader() { super( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderSerializationTests.java similarity index 58% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderSerializationTests.java index 54bb4a07d6085..5a1cf3cb6e442 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderSerializationTests.java @@ -12,30 +12,32 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Tuple; import org.elasticsearch.search.SearchModule; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests; -import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.function.Supplier; import static org.elasticsearch.search.rank.RankBuilder.WINDOW_SIZE_FIELD; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests.randomLearnToRankConfig; +import static org.mockito.Mockito.mock; -public class InferenceRescorerBuilderSerializationTests extends AbstractBWCSerializationTestCase { +public class LearnToRankRescorerBuilderSerializationTests extends AbstractBWCSerializationTestCase { @Override - protected InferenceRescorerBuilder doParseInstance(XContentParser parser) throws IOException { + protected LearnToRankRescorerBuilder doParseInstance(XContentParser parser) throws IOException { String fieldName = null; - InferenceRescorerBuilder rescorer = null; + LearnToRankRescorerBuilder rescorer = null; Integer windowSize = null; XContentParser.Token token = parser.nextToken(); assert token == XContentParser.Token.START_OBJECT; @@ -49,7 +51,7 @@ protected InferenceRescorerBuilder doParseInstance(XContentParser parser) throws throw new ParsingException(parser.getTokenLocation(), "rescore doesn't support [" + fieldName + "]"); } } else if (token == XContentParser.Token.START_OBJECT) { - rescorer = InferenceRescorerBuilder.fromXContent(parser, null); + rescorer = LearnToRankRescorerBuilder.fromXContent(parser, null, null); } else { throw new ParsingException(parser.getTokenLocation(), "unexpected token [" + token + "] after [" + fieldName + "]"); } @@ -64,32 +66,44 @@ protected InferenceRescorerBuilder doParseInstance(XContentParser parser) throws } @Override - protected Writeable.Reader instanceReader() { - return in -> new InferenceRescorerBuilder(in, null); + protected Writeable.Reader instanceReader() { + return in -> new LearnToRankRescorerBuilder(in, null, null); } @Override - protected InferenceRescorerBuilder createTestInstance() { - InferenceRescorerBuilder builder = randomBoolean() - ? new InferenceRescorerBuilder(randomAlphaOfLength(10), null) - : new InferenceRescorerBuilder( - randomAlphaOfLength(10), - LearnToRankConfigTests.randomLearnToRankConfig(), - (Supplier) null - ); + protected LearnToRankRescorerBuilder createTestInstance() { + LearnToRankConfig learnToRankConfig = randomLearnToRankConfig(); + LearnToRankRescorerBuilder builder = randomBoolean() + ? createXContextTestInstance(null) + : new LearnToRankRescorerBuilder(randomAlphaOfLength(10), null, () -> learnToRankConfig); + if (randomBoolean()) { builder.windowSize(randomIntBetween(1, 10000)); } + return builder; } @Override - protected InferenceRescorerBuilder mutateInstance(InferenceRescorerBuilder instance) throws IOException { - int i = randomInt(3); + protected LearnToRankRescorerBuilder createXContextTestInstance(XContentType xContentType) { + return new LearnToRankRescorerBuilder( + randomAlphaOfLength(10), + randomMap(1, randomIntBetween(1, 10), () -> new Tuple<>(randomIdentifier(), randomIdentifier())), + null, + null + ); + } + + @Override + protected LearnToRankRescorerBuilder mutateInstance(LearnToRankRescorerBuilder instance) throws IOException { + + int i = randomInt(4); return switch (i) { case 0 -> { - InferenceRescorerBuilder builder = new InferenceRescorerBuilder( - randomValueOtherThan(instance.getModelId(), () -> randomAlphaOfLength(10)), + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( + randomValueOtherThan(instance.modelId(), () -> randomAlphaOfLength(10)), + instance.params(), + null, null ); if (instance.windowSize() != null) { @@ -97,22 +111,35 @@ protected InferenceRescorerBuilder mutateInstance(InferenceRescorerBuilder insta } yield builder; } - case 1 -> new InferenceRescorerBuilder(instance.getModelId(), null).windowSize( + case 1 -> new LearnToRankRescorerBuilder(instance.modelId(), instance.params(), null, null).windowSize( randomValueOtherThan(instance.windowSize(), () -> randomIntBetween(1, 10000)) ); case 2 -> { - InferenceRescorerBuilder builder = new InferenceRescorerBuilder(instance.getModelId(), null); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( + instance.modelId(), + randomMap(1, randomIntBetween(1, 10), () -> new Tuple<>(randomIdentifier(), randomIdentifier())), + null, + null + ); if (instance.windowSize() != null) { - builder.windowSize(instance.windowSize()); + builder.windowSize(instance.windowSize() + 1); } yield builder; } case 3 -> { - InferenceRescorerBuilder builder = new InferenceRescorerBuilder( - instance.getModelId(), - randomValueOtherThan(instance.getInferenceConfig(), LearnToRankConfigTests::randomLearnToRankConfig), - (Supplier) null - ); + LearnToRankConfig learnToRankConfig = instance.learnToRankConfigSupplier() != null + ? randomValueOtherThan(instance.learnToRankConfigSupplier().get(), () -> randomLearnToRankConfig()) + : randomLearnToRankConfig(); + + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(instance.modelId(), null, () -> learnToRankConfig); + if (instance.windowSize() != null) { + builder.windowSize(instance.windowSize()); + } + yield builder; + } + case 4 -> { + LocalModel localModel = mock(LocalModel.class); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(instance.learnToRankConfigSupplier(), () -> localModel); if (instance.windowSize() != null) { builder.windowSize(instance.windowSize()); } @@ -123,20 +150,10 @@ protected InferenceRescorerBuilder mutateInstance(InferenceRescorerBuilder insta } @Override - protected InferenceRescorerBuilder mutateInstanceForVersion(InferenceRescorerBuilder instance, TransportVersion version) { + protected LearnToRankRescorerBuilder mutateInstanceForVersion(LearnToRankRescorerBuilder instance, TransportVersion version) { return instance; } - public void testIncorrectInferenceConfigType() { - InferenceRescorerBuilder.Builder builder = new InferenceRescorerBuilder.Builder(); - expectThrows( - IllegalArgumentException.class, - () -> builder.setInferenceConfig(ClassificationConfigTests.randomClassificationConfig()) - ); - // Should not throw - builder.setInferenceConfig(LearnToRankConfigTests.randomLearnToRankConfig()); - } - @Override protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>();