diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java index dbfae12413632..b878aad27a42e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java @@ -10,9 +10,7 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; @@ -46,14 +44,6 @@ public List getNamedXContentParsers() { LearnToRankConfig::fromXContentStrict ) ); - // Inference Config Update - namedXContent.add( - new NamedXContentRegistry.Entry( - InferenceConfigUpdate.class, - LearnToRankConfigUpdate.NAME, - LearnToRankConfigUpdate::fromXContentStrict - ) - ); // LTR extractors namedXContent.add( new NamedXContentRegistry.Entry( @@ -71,14 +61,6 @@ public List getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(InferenceConfig.class, LearnToRankConfig.NAME.getPreferredName(), LearnToRankConfig::new) ); - // Inference config update - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceConfigUpdate.class, - LearnToRankConfigUpdate.NAME.getPreferredName(), - LearnToRankConfigUpdate::new - ) - ); // LTR Extractors namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdate.java deleted file mode 100644 index b4241f1704520..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdate.java +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.core.ml.inference.trainedmodel; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.Rewriteable; -import org.elasticsearch.xcontent.ObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; -import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig.FEATURE_EXTRACTORS; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES; - -public class LearnToRankConfigUpdate implements InferenceConfigUpdate, NamedXContentObject, Rewriteable { - - public static final ParseField NAME = LearnToRankConfig.NAME; - - public static LearnToRankConfigUpdate EMPTY_PARAMS = new LearnToRankConfigUpdate(null, null); - - public static LearnToRankConfigUpdate fromConfig(LearnToRankConfig config) { - return new LearnToRankConfigUpdate(config.getNumTopFeatureImportanceValues(), config.getFeatureExtractorBuilders()); - } - - private static final ObjectParser STRICT_PARSER = createParser(false); - - private static ObjectParser createParser(boolean lenient) { - ObjectParser parser = new ObjectParser<>( - NAME.getPreferredName(), - lenient, - LearnToRankConfigUpdate.Builder::new - ); - parser.declareInt(LearnToRankConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); - parser.declareNamedObjects( - LearnToRankConfigUpdate.Builder::setFeatureExtractorBuilders, - (p, c, n) -> p.namedObject(LearnToRankFeatureExtractorBuilder.class, n, false), - b -> {}, - FEATURE_EXTRACTORS - ); - return parser; - } - - public static LearnToRankConfigUpdate fromXContentStrict(XContentParser parser) { - return STRICT_PARSER.apply(parser, null).build(); - } - - private final Integer numTopFeatureImportanceValues; - private final List featureExtractorBuilderList; - - public LearnToRankConfigUpdate( - Integer numTopFeatureImportanceValues, - List featureExtractorBuilders - ) { - if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) { - throw new IllegalArgumentException( - "[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() + "] must be greater than or equal to 0" - ); - } - if (featureExtractorBuilders != null) { - Set featureNames = featureExtractorBuilders.stream() - .map(LearnToRankFeatureExtractorBuilder::featureName) - .collect(Collectors.toSet()); - if (featureNames.size() < featureExtractorBuilders.size()) { - throw new IllegalArgumentException( - "[" + FEATURE_EXTRACTORS.getPreferredName() + "] contains duplicate [feature_name] values" - ); - } - } - this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; - this.featureExtractorBuilderList = featureExtractorBuilders == null ? List.of() : featureExtractorBuilders; - } - - public LearnToRankConfigUpdate(StreamInput in) throws IOException { - this.numTopFeatureImportanceValues = in.readOptionalVInt(); - this.featureExtractorBuilderList = in.readNamedWriteableCollectionAsList(LearnToRankFeatureExtractorBuilder.class); - } - - public Integer getNumTopFeatureImportanceValues() { - return numTopFeatureImportanceValues; - } - - @Override - public String getResultsField() { - return DEFAULT_RESULTS_FIELD; - } - - @Override - public InferenceConfigUpdate.Builder, ? extends InferenceConfigUpdate> newBuilder() { - return new Builder().setNumTopFeatureImportanceValues(numTopFeatureImportanceValues); - } - - @Override - public String getWriteableName() { - return NAME.getPreferredName(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalVInt(numTopFeatureImportanceValues); - out.writeNamedWriteableCollection(featureExtractorBuilderList); - } - - @Override - public String getName() { - return NAME.getPreferredName(); - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return LearnToRankConfig.MIN_SUPPORTED_TRANSPORT_VERSION; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (numTopFeatureImportanceValues != null) { - builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); - } - if (featureExtractorBuilderList.isEmpty() == false) { - NamedXContentObjectHelper.writeNamedObjects( - builder, - params, - true, - FEATURE_EXTRACTORS.getPreferredName(), - featureExtractorBuilderList - ); - } - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - LearnToRankConfigUpdate that = (LearnToRankConfigUpdate) o; - return Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) - && Objects.equals(this.featureExtractorBuilderList, that.featureExtractorBuilderList); - } - - @Override - public int hashCode() { - return Objects.hash(numTopFeatureImportanceValues, featureExtractorBuilderList); - } - - @Override - public LearnToRankConfig apply(InferenceConfig originalConfig) { - if (originalConfig instanceof LearnToRankConfig == false) { - throw ExceptionsHelper.badRequestException( - "Inference config of type [{}] can not be updated with a inference request of type [{}]", - originalConfig.getName(), - getName() - ); - } - - LearnToRankConfig ltrConfig = (LearnToRankConfig) originalConfig; - if (isNoop(ltrConfig)) { - return ltrConfig; - } - LearnToRankConfig.Builder builder = new LearnToRankConfig.Builder(ltrConfig); - if (numTopFeatureImportanceValues != null) { - builder.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues); - } - if (featureExtractorBuilderList.isEmpty() == false) { - Map existingExtractors = ltrConfig.getFeatureExtractorBuilders() - .stream() - .collect(Collectors.toMap(LearnToRankFeatureExtractorBuilder::featureName, f -> f)); - featureExtractorBuilderList.forEach(f -> existingExtractors.put(f.featureName(), f)); - builder.setLearnToRankFeatureExtractorBuilders(new ArrayList<>(existingExtractors.values())); - } - return builder.build(); - } - - @Override - public boolean isSupported(InferenceConfig inferenceConfig) { - return inferenceConfig instanceof LearnToRankConfig; - } - - boolean isNoop(LearnToRankConfig originalConfig) { - return (numTopFeatureImportanceValues == null || originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues) - && (featureExtractorBuilderList.isEmpty() - || Objects.equals(originalConfig.getFeatureExtractorBuilders(), featureExtractorBuilderList)); - } - - @Override - public LearnToRankConfigUpdate rewrite(QueryRewriteContext ctx) throws IOException { - if (featureExtractorBuilderList.isEmpty()) { - return this; - } - List rewrittenBuilders = new ArrayList<>(featureExtractorBuilderList.size()); - boolean rewritten = false; - for (LearnToRankFeatureExtractorBuilder extractorBuilder : featureExtractorBuilderList) { - LearnToRankFeatureExtractorBuilder rewrittenExtractor = Rewriteable.rewrite(extractorBuilder, ctx); - rewritten |= (rewrittenExtractor != extractorBuilder); - rewrittenBuilders.add(rewrittenExtractor); - } - if (rewritten) { - return new LearnToRankConfigUpdate(getNumTopFeatureImportanceValues(), rewrittenBuilders); - } - return this; - } - - public static class Builder implements InferenceConfigUpdate.Builder { - private Integer numTopFeatureImportanceValues; - private List featureExtractorBuilderList; - - @Override - public Builder setResultsField(String resultsField) { - assert false : "results field should never be set in ltr config"; - return this; - } - - public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { - this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; - return this; - } - - public Builder setFeatureExtractorBuilders(List featureExtractorBuilderList) { - this.featureExtractorBuilderList = featureExtractorBuilderList; - return this; - } - - @Override - public LearnToRankConfigUpdate build() { - return new LearnToRankConfigUpdate(numTopFeatureImportanceValues, featureExtractorBuilderList); - } - } -} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdateTests.java deleted file mode 100644 index 30befc767300b..0000000000000 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdateTests.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.core.ml.inference.trainedmodel; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.search.SearchModule; -import org.elasticsearch.xcontent.NamedXContentRegistry; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests; -import org.elasticsearch.xpack.core.ml.utils.QueryProvider; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests.randomLearnToRankConfig; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.in; -import static org.hamcrest.Matchers.is; - -public class LearnToRankConfigUpdateTests extends AbstractBWCSerializationTestCase { - - public static LearnToRankConfigUpdate randomLearnToRankConfigUpdate() { - return new LearnToRankConfigUpdate( - randomBoolean() ? null : randomIntBetween(0, 10), - randomBoolean() - ? null - : Stream.generate(QueryExtractorBuilderTests::randomInstance).limit(randomInt(5)).collect(Collectors.toList()) - ); - } - - public void testApply() throws IOException { - LearnToRankConfig originalConfig = randomLearnToRankConfig(); - assertThat(originalConfig, equalTo(LearnToRankConfigUpdate.EMPTY_PARAMS.apply(originalConfig))); - assertThat( - new LearnToRankConfig.Builder(originalConfig).setNumTopFeatureImportanceValues(5).build(), - equalTo(new LearnToRankConfigUpdate.Builder().setNumTopFeatureImportanceValues(5).build().apply(originalConfig)) - ); - assertThat( - new LearnToRankConfig.Builder(originalConfig).setNumTopFeatureImportanceValues(1).build(), - equalTo(new LearnToRankConfigUpdate.Builder().setNumTopFeatureImportanceValues(1).build().apply(originalConfig)) - ); - - LearnToRankFeatureExtractorBuilder extractorBuilder = new QueryExtractorBuilder( - "foo", - QueryProvider.fromParsedQuery(QueryBuilders.termQuery("foo", "bar")) - ); - LearnToRankFeatureExtractorBuilder extractorBuilder2 = new QueryExtractorBuilder( - "bar", - QueryProvider.fromParsedQuery(QueryBuilders.termQuery("foo", "bar")) - ); - - LearnToRankConfig config = new LearnToRankConfigUpdate.Builder().setNumTopFeatureImportanceValues(1) - .setFeatureExtractorBuilders(List.of(extractorBuilder2, extractorBuilder)) - .build() - .apply(originalConfig); - assertThat(config.getNumTopFeatureImportanceValues(), equalTo(1)); - assertThat(extractorBuilder2, is(in(config.getFeatureExtractorBuilders()))); - assertThat(extractorBuilder, is(in(config.getFeatureExtractorBuilders()))); - } - - @Override - protected LearnToRankConfigUpdate createTestInstance() { - return randomLearnToRankConfigUpdate(); - } - - @Override - protected LearnToRankConfigUpdate mutateInstance(LearnToRankConfigUpdate instance) { - return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929 - } - - @Override - protected Writeable.Reader instanceReader() { - return LearnToRankConfigUpdate::new; - } - - @Override - protected LearnToRankConfigUpdate doParseInstance(XContentParser parser) throws IOException { - return LearnToRankConfigUpdate.fromXContentStrict(parser); - } - - @Override - protected LearnToRankConfigUpdate mutateInstanceForVersion(LearnToRankConfigUpdate instance, TransportVersion version) { - return instance; - } - - @Override - protected NamedXContentRegistry xContentRegistry() { - List namedXContent = new ArrayList<>(); - namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers()); - namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); - return new NamedXContentRegistry(namedXContent); - } - - @Override - protected NamedWriteableRegistry writableRegistry() { - List namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables()); - namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables()); - namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); - return new NamedWriteableRegistry(namedWriteables); - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return writableRegistry(); - } -} diff --git a/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java b/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java index 8dec8dcdb020e..f3ac67f338eee 100644 --- a/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java +++ b/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java @@ -29,15 +29,15 @@ public class MlRescorerIT extends ESRestTestCase { @Before public void setupModelAndData() throws IOException { putRegressionModel(MODEL_ID, """ - { + { "description": "super complex model for tests", - "input": {"field_names": ["cost", "product"]}, + "input": { "field_names": ["cost", "product"] }, "inference_config": { "learn_to_rank": { "feature_extractors": [{ "query_extractor": { "feature_name": "two", - "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}} + "query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 2.0;" } } } } }] } @@ -177,8 +177,9 @@ public void setupModelAndData() throws IOException { }"""); createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", randomIntBetween(1, 3)).build(), """ "properties":{ - "product":{"type": "keyword"}, - "cost":{"type": "integer"}}"""); + "product":{ "type": "keyword" }, + "cost":{ "type": "integer" } + }"""); indexData("{ \"product\": \"TV\", \"cost\": 300 }"); indexData("{ \"product\": \"TV\", \"cost\": 400 }"); indexData("{ \"product\": \"VCR\", \"cost\": 150 }"); @@ -191,16 +192,15 @@ public void setupModelAndData() throws IOException { public void testLtrSimple() throws Exception { Response searchResponse = search(""" { - "query": { - "match": { "product": { "query": "TV"}} - }, - "rescore": { + "query": { + "match": { "product": { "query": "TV" } } + }, + "rescore": { "window_size": 10, "inference": { "model_id": "basic-ltr-model" - } + } } - }"""); Map response = responseAsMap(searchResponse); @@ -209,23 +209,25 @@ public void testLtrSimple() throws Exception { @SuppressWarnings("unchecked") public void testLtrSimpleDFS() throws Exception { + skipTestUntilParametersAreImplemented(); + Response searchResponse = searchDfs(""" { - "query": { - "match": { "product": { "query": "TV"}} - }, - "rescore": { + "query": { + "match": { "product": { "query": "TV" } } + }, + "rescore": { "window_size": 10, "inference": { "model_id": "basic-ltr-model", "inference_config": { - "learn_to_rank": { - "feature_extractors":[ - {"query_extractor": {"feature_name": "product_bm25", "query": {"term": {"product": "TV"}}}} - ] - } + "learn_to_rank": { + "feature_extractors":[ + { "query_extractor": { "feature_name": "product_bm25", "query": { "term": { "product": "TV" } } } } + ] + } } - } + } } }"""); @@ -236,19 +238,18 @@ public void testLtrSimpleDFS() throws Exception { searchResponse = searchDfs(""" { "rescore": { - "window_size": 10, - "inference": { - "model_id": "basic-ltr-model", + "window_size": 10, + "inference": { + "model_id": "basic-ltr-model", "inference_config": { - "learn_to_rank": { - "feature_extractors":[ - {"query_extractor": {"feature_name": "product_bm25", "query": {"term": {"product": "TV"}}}} - ] + "learn_to_rank": { + "feature_extractors":[ + { "query_extractor": { "feature_name": "product_bm25", "query": { "term": { "product": "TV" } } } } + ] } - } } + } } - }"""); response = responseAsMap(searchResponse); @@ -262,16 +263,16 @@ public void testLtrSimpleDFS() throws Exception { @SuppressWarnings("unchecked") public void testLtrSimpleEmpty() throws Exception { Response searchResponse = search(""" - { "query": { - "term": { "product": "computer"} - }, - "rescore": { + { + "query": { + "term": { "product": "computer" } + }, + "rescore": { "window_size": 10, "inference": { "model_id": "basic-ltr-model" - } + } } - }"""); Map response = responseAsMap(searchResponse); @@ -281,16 +282,16 @@ public void testLtrSimpleEmpty() throws Exception { @SuppressWarnings("unchecked") public void testLtrEmptyDFS() throws Exception { Response searchResponse = searchDfs(""" - { "query": { - "match": { "product": { "query": "computer"}} - }, - "rescore": { + { + "query": { + "match": { "product": { "query": "computer"} } + }, + "rescore": { "window_size": 10, "inference": { "model_id": "basic-ltr-model" - } + } } - }"""); Map response = responseAsMap(searchResponse); @@ -300,30 +301,31 @@ public void testLtrEmptyDFS() throws Exception { @SuppressWarnings("unchecked") public void testLtrCanMatch() throws Exception { Response searchResponse = searchCanMatch(""" - { "query": { - "match": { "product": { "query": "TV"}} - }, - "rescore": { + { + "query": { + "match": { "product": { "query": "TV" } } + }, + "rescore": { "window_size": 10, "inference": { "model_id": "basic-ltr-model" - } + } } - }""", false); Map response = responseAsMap(searchResponse); assertThat(response.toString(), (List) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0)); searchResponse = searchCanMatch(""" - { "query": { - "match": { "product": { "query": "TV"}} - }, - "rescore": { + { + "query": { + "match": { "product": { "query": "TV" } } + }, + "rescore": { "window_size": 10, "inference": { "model_id": "basic-ltr-model" - } + } } }""", true); @@ -364,4 +366,7 @@ private void putRegressionModel(String modelId, String body) throws IOException assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); } + private void skipTestUntilParametersAreImplemented() { + // throw new AssumptionViolatedException("Skip the test until parameters are implemented"); + } } diff --git a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java index f51d2915d903e..600e656f6cf48 100644 --- a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java +++ b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java @@ -26,17 +26,19 @@ public class InferenceRescorerIT extends InferenceTestCase { @Before public void setupModelAndData() throws IOException { putRegressionModel(MODEL_ID, """ - { + { "description": "super complex model for tests", "input": {"field_names": ["cost", "product"]}, "inference_config": { "learn_to_rank": { - "feature_extractors": [{ - "query_extractor": { - "feature_name": "two", - "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}} + "feature_extractors": [ + { + "query_extractor": { + "feature_name": "two", + "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}} + } } - }] + ] } }, "definition": { @@ -50,132 +52,131 @@ public void setupModelAndData() throws IOException { } } }], - "trained_model": { - "ensemble": { - "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop", "two", "product_bm25"], - "target_type": "regression", - "trained_models": [ - { - "tree": { - "feature_names": [ - "cost" - ], - "tree_structure": [ - { - "node_index": 0, - "split_feature": 0, - "split_gain": 12, - "threshold": 400, - "decision_type": "lte", - "default_left": true, - "left_child": 1, - "right_child": 2 - }, - { - "node_index": 1, - "leaf_value": 5.0 - }, - { - "node_index": 2, - "leaf_value": 2.0 - } - ], - "target_type": "regression" - } - }, - { - "tree": { - "feature_names": [ - "type_tv" - ], - "tree_structure": [ - { - "node_index": 0, - "split_feature": 0, - "split_gain": 12, - "threshold": 1, - "decision_type": "lt", - "default_left": true, - "left_child": 1, - "right_child": 2 - }, - { - "node_index": 1, - "leaf_value": 1.0 - }, - { - "node_index": 2, - "leaf_value": 12.0 - } - ], - "target_type": "regression" - } - }, - { - "tree": { - "feature_names": [ - "two" - ], - "tree_structure": [ - { - "node_index": 0, - "split_feature": 0, - "split_gain": 12, - "threshold": 1, - "decision_type": "lt", - "default_left": true, - "left_child": 1, - "right_child": 2 - }, - { - "node_index": 1, - "leaf_value": 1.0 - }, - { - "node_index": 2, - "leaf_value": 2.0 - } - ], - "target_type": "regression" - } - }, - { - "tree": { - "feature_names": [ - "product_bm25" - ], - "tree_structure": [ - { - "node_index": 0, - "split_feature": 0, - "split_gain": 12, - "threshold": 1, - "decision_type": "lt", - "default_left": true, - "left_child": 1, - "right_child": 2 - }, - { - "node_index": 1, - "leaf_value": 1.0 - }, - { - "node_index": 2, - "leaf_value": 4.0 - } - ], - "target_type": "regression" - } - } - ] - } - } - } - }"""); + "trained_model": { + "ensemble": { + "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop", "two", "product_bm25"], + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["cost"], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12, + "threshold": 400, + "decision_type": "lte", + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "leaf_value": 5.0 + }, + { + "node_index": 2, + "leaf_value": 2.0 + } + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": [ + "type_tv" + ], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12, + "threshold": 1, + "decision_type": "lt", + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "leaf_value": 1.0 + }, + { + "node_index": 2, + "leaf_value": 12.0 + } + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": [ + "two" + ], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12, + "threshold": 1, + "decision_type": "lt", + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "leaf_value": 1.0 + }, + { + "node_index": 2, + "leaf_value": 2.0 + } + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": [ + "product_bm25" + ], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12, + "threshold": 1, + "decision_type": "lt", + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "leaf_value": 1.0 + }, + { + "node_index": 2, + "leaf_value": 4.0 + } + ], + "target_type": "regression" + } + } + ] + } + } + } + }"""); createIndex(INDEX_NAME, Settings.EMPTY, """ "properties":{ - "product":{"type": "keyword"}, - "cost":{"type": "integer"}}"""); + "product":{"type": "keyword"}, + "cost":{"type": "integer"} + }"""); indexData("{ \"product\": \"TV\", \"cost\": 300}"); indexData("{ \"product\": \"TV\", \"cost\": 400}"); indexData("{ \"product\": \"TV\", \"cost\": 600}"); @@ -190,6 +191,8 @@ public void setupModelAndData() throws IOException { @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98372") public void testInferenceRescore() throws Exception { + skipTestUntilParametersAreImplemented(); + Request request = new Request("GET", "store/_search?size=3&error_trace"); request.setJsonEntity(""" { @@ -201,16 +204,16 @@ public void testInferenceRescore() throws Exception { assertHitScores(client().performRequest(request), List.of(20.0, 20.0, 17.0)); request.setJsonEntity(""" { - "query": {"term": {"product": "Laptop"}}, + "query": { "term": { "product": "Laptop" } }, "rescore": { "window_size": 10, "inference": { "model_id": "ltr-model", "inference_config": { "learn_to_rank": { - "feature_extractors":[{ - "query_extractor": {"feature_name": "product_bm25", "query": {"term": {"product": "Laptop"}}} - }] + "feature_extractors":[ + { "query_extractor": { "feature_name": "product_bm25", "query": { "term": { "product": "Laptop"} } } } + ] } } } @@ -219,7 +222,7 @@ public void testInferenceRescore() throws Exception { assertHitScores(client().performRequest(request), List.of(12.0, 12.0, 9.0)); request.setJsonEntity(""" { - "query": {"term": {"product": "Laptop"}}, + "query": {"term": { "product": "Laptop" } }, "rescore": { "window_size": 10, "inference": { "model_id": "ltr-model"} @@ -247,20 +250,20 @@ public void testInferenceRescorerWithChainedRescorers() throws IOException { request.setJsonEntity(""" { "rescore": [ - { - "window_size": 4, - "query": { "rescore_query":{ "script_score": {"query": {"match_all": {}}, "script": {"source": "return 4"}}}} - }, - { - "window_size": 3, - "inference": { "model_id": "ltr-model" } - }, - { - "window_size": 2, - "query": { "rescore_query": { "script_score": {"query": {"match_all": {}}, "script": {"source": "return 20"}}}} - } + { + "window_size": 4, + "query": { "rescore_query" : { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 4" } } } } + }, + { + "window_size": 3, + "inference": { "model_id": "ltr-model" } + }, + { + "window_size": 2, + "query": { "rescore_query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 20"} } } } + } ] - }"""); + }"""); assertHitScores(client().performRequest(request), List.of(40.0, 40.0, 17.0, 5.0, 1.0)); } @@ -274,4 +277,8 @@ private void indexData(String data) throws IOException { private static void assertHitScores(Response response, List expectedScores) throws IOException { assertThat((List) XContentMapValues.extractValue("hits.hits._score", responseAsMap(response)), equalTo(expectedScores)); } + + private void skipTestUntilParametersAreImplemented() { + // throw new AssumptionViolatedException("Skip the test until parameters are implemented"); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java index 8d450a43722ed..5b110e2b42da7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java @@ -25,9 +25,7 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -45,16 +43,10 @@ public class InferenceRescorerBuilder extends RescorerBuilder PARSER = new ObjectParser<>(NAME, false, Builder::new); static { PARSER.declareString(Builder::setModelId, MODEL); - PARSER.declareNamedObject( - Builder::setInferenceConfigUpdate, - (p, c, name) -> p.namedObject(InferenceConfigUpdate.class, name, false), - INFERENCE_CONFIG - ); PARSER.declareNamedObject( Builder::setInferenceConfig, (p, c, name) -> p.namedObject(StrictlyParsedInferenceConfig.class, name, false), @@ -67,7 +59,6 @@ public static InferenceRescorerBuilder fromXContent(XContentParser parser, Suppl } private final String modelId; - private final LearnToRankConfigUpdate inferenceConfigUpdate; private final LearnToRankConfig inferenceConfig; private final LocalModel inferenceDefinition; private final Supplier inferenceDefinitionSupplier; @@ -75,13 +66,8 @@ public static InferenceRescorerBuilder fromXContent(XContentParser parser, Suppl private final Supplier inferenceConfigSupplier; private boolean rescoreOccurred; - public InferenceRescorerBuilder( - String modelId, - LearnToRankConfigUpdate inferenceConfigUpdate, - Supplier modelLoadingServiceSupplier - ) { + public InferenceRescorerBuilder(String modelId, Supplier modelLoadingServiceSupplier) { this.modelId = Objects.requireNonNull(modelId); - this.inferenceConfigUpdate = inferenceConfigUpdate; this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; this.inferenceDefinition = null; this.inferenceDefinitionSupplier = null; @@ -91,7 +77,6 @@ public InferenceRescorerBuilder( InferenceRescorerBuilder(String modelId, LearnToRankConfig inferenceConfig, Supplier modelLoadingServiceSupplier) { this.modelId = Objects.requireNonNull(modelId); - this.inferenceConfigUpdate = null; this.inferenceDefinition = null; this.inferenceDefinitionSupplier = null; this.inferenceConfigSupplier = null; @@ -101,12 +86,10 @@ public InferenceRescorerBuilder( private InferenceRescorerBuilder( String modelId, - LearnToRankConfigUpdate update, Supplier modelLoadingServiceSupplier, Supplier inferenceConfigSupplier ) { this.modelId = Objects.requireNonNull(modelId); - this.inferenceConfigUpdate = update; this.inferenceDefinition = null; this.inferenceDefinitionSupplier = null; this.inferenceConfigSupplier = inferenceConfigSupplier; @@ -121,7 +104,6 @@ private InferenceRescorerBuilder( Supplier inferenceDefinitionSupplier ) { this.modelId = modelId; - this.inferenceConfigUpdate = null; this.inferenceDefinition = null; this.inferenceDefinitionSupplier = inferenceDefinitionSupplier; this.modelLoadingServiceSupplier = modelLoadingServiceSupplier; @@ -131,7 +113,6 @@ private InferenceRescorerBuilder( InferenceRescorerBuilder(String modelId, LearnToRankConfig inferenceConfig, LocalModel inferenceDefinition) { this.modelId = modelId; - this.inferenceConfigUpdate = null; this.inferenceDefinition = inferenceDefinition; this.inferenceDefinitionSupplier = null; this.modelLoadingServiceSupplier = null; @@ -142,7 +123,6 @@ private InferenceRescorerBuilder( public InferenceRescorerBuilder(StreamInput input, Supplier modelLoadingServiceSupplier) throws IOException { super(input); this.modelId = input.readString(); - this.inferenceConfigUpdate = (LearnToRankConfigUpdate) input.readOptionalNamedWriteable(InferenceConfigUpdate.class); this.inferenceDefinitionSupplier = null; this.inferenceConfigSupplier = null; this.inferenceDefinition = null; @@ -210,9 +190,7 @@ private RescorerBuilder doRewrite(QueryRewriteContext ActionListener.wrap(trainedModels -> { TrainedModelConfig config = trainedModels.getResources().results().get(0); if (config.getInferenceConfig() instanceof LearnToRankConfig retrievedInferenceConfig) { - retrievedInferenceConfig = inferenceConfigUpdate == null - ? retrievedInferenceConfig - : inferenceConfigUpdate.apply(retrievedInferenceConfig); + // TODO: apply params instead of an override. for (LearnToRankFeatureExtractorBuilder builder : retrievedInferenceConfig.getFeatureExtractorBuilders()) { builder.validate(); } @@ -232,12 +210,7 @@ private RescorerBuilder doRewrite(QueryRewriteContext }, l::onFailure) ) ); - InferenceRescorerBuilder builder = new InferenceRescorerBuilder( - modelId, - inferenceConfigUpdate, - modelLoadingServiceSupplier, - configSetOnce::get - ); + InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, modelLoadingServiceSupplier, configSetOnce::get); if (windowSize() != null) { builder.windowSize(windowSize); } @@ -336,7 +309,6 @@ protected void doWriteTo(StreamOutput out) throws IOException { } assert inferenceDefinition == null || rescoreOccurred : "Unnecessarily populated local model object"; out.writeString(modelId); - out.writeOptionalNamedWriteable(inferenceConfigUpdate); out.writeOptionalNamedWriteable(inferenceConfig); } @@ -344,9 +316,6 @@ protected void doWriteTo(StreamOutput out) throws IOException { protected void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); builder.field(MODEL.getPreferredName(), modelId); - if (inferenceConfigUpdate != null) { - NamedXContentObjectHelper.writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfigUpdate); - } if (inferenceConfig != null) { NamedXContentObjectHelper.writeNamedObject(builder, params, INTERNAL_INFERENCE_CONFIG.getPreferredName(), inferenceConfig); } @@ -367,7 +336,6 @@ public boolean equals(Object o) { InferenceRescorerBuilder that = (InferenceRescorerBuilder) o; return Objects.equals(modelId, that.modelId) && Objects.equals(inferenceDefinition, that.inferenceDefinition) - && Objects.equals(inferenceConfigUpdate, that.inferenceConfigUpdate) && Objects.equals(inferenceConfig, that.inferenceConfig) && Objects.equals(inferenceDefinitionSupplier, that.inferenceDefinitionSupplier) && Objects.equals(modelLoadingServiceSupplier, that.modelLoadingServiceSupplier); @@ -378,7 +346,6 @@ public int hashCode() { return Objects.hash( super.hashCode(), modelId, - inferenceConfigUpdate, inferenceConfig, inferenceDefinition, inferenceDefinitionSupplier, @@ -386,10 +353,6 @@ public int hashCode() { ); } - LearnToRankConfigUpdate getInferenceConfigUpdate() { - return inferenceConfigUpdate; - } - // Used in tests Supplier modelLoadingServiceSupplier() { return modelLoadingServiceSupplier; @@ -402,27 +365,12 @@ LocalModel getInferenceDefinition() { static class Builder { private String modelId; - private LearnToRankConfigUpdate inferenceConfigUpdate; private LearnToRankConfig inferenceConfig; public void setModelId(String modelId) { this.modelId = modelId; } - public void setInferenceConfigUpdate(InferenceConfigUpdate inferenceConfigUpdate) { - if (inferenceConfigUpdate instanceof LearnToRankConfigUpdate learnToRankConfigUpdate) { - this.inferenceConfigUpdate = learnToRankConfigUpdate; - return; - } - throw new IllegalArgumentException( - Strings.format( - "[%s] only allows a [%s] object to be configured", - INFERENCE_CONFIG.getPreferredName(), - LearnToRankConfigUpdate.NAME.getPreferredName() - ) - ); - } - void setInferenceConfig(InferenceConfig inferenceConfig) { if (inferenceConfig instanceof LearnToRankConfig learnToRankConfig) { this.inferenceConfig = learnToRankConfig; @@ -431,19 +379,14 @@ void setInferenceConfig(InferenceConfig inferenceConfig) { throw new IllegalArgumentException( Strings.format( "[%s] only allows a [%s] object to be configured", - INFERENCE_CONFIG.getPreferredName(), - LearnToRankConfigUpdate.NAME.getPreferredName() + INTERNAL_INFERENCE_CONFIG.getPreferredName(), + LearnToRankConfig.NAME.getPreferredName() ) ); } InferenceRescorerBuilder build(Supplier modelLoadingServiceSupplier) { - assert inferenceConfig == null || inferenceConfigUpdate == null; - if (inferenceConfig != null) { - return new InferenceRescorerBuilder(modelId, inferenceConfig, modelLoadingServiceSupplier); - } else { - return new InferenceRescorerBuilder(modelId, inferenceConfigUpdate, modelLoadingServiceSupplier); - } + return new InferenceRescorerBuilder(modelId, inferenceConfig, modelLoadingServiceSupplier); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java index b31c425ea1eb4..aec79919c1a50 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.index.mapper.DateFieldMapper; import org.elasticsearch.index.query.CoordinatorRewriteContext; import org.elasticsearch.index.query.DataRewriteContext; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; @@ -39,13 +38,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; @@ -111,11 +106,7 @@ public void testMustRewrite() { public void testRewriteOnCoordinator() throws IOException { TestModelLoader testModelLoader = new TestModelLoader(); - LearnToRankConfigUpdate ltru = new LearnToRankConfigUpdate( - 2, - List.of(new QueryExtractorBuilder("all", QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()))) - ); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(GOOD_MODEL, ltru, () -> testModelLoader); + InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(GOOD_MODEL, () -> testModelLoader); inferenceRescorerBuilder.windowSize(4); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), @@ -137,17 +128,12 @@ public void testRewriteOnCoordinator() throws IOException { ) ) ); - assertThat(rewritten.getInferenceConfigUpdate(), is(nullValue())); assertThat(rewritten.windowSize(), equalTo(4)); } public void testRewriteOnCoordinatorWithBadModel() throws IOException { TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder( - BAD_MODEL, - randomBoolean() ? null : LearnToRankConfigUpdateTests.randomLearnToRankConfigUpdate(), - () -> testModelLoader - ); + InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(BAD_MODEL, () -> testModelLoader); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), randomIntBetween(0, 1_100_000), @@ -162,11 +148,7 @@ public void testRewriteOnCoordinatorWithBadModel() throws IOException { public void testRewriteOnCoordinatorWithMissingModel() { TestModelLoader testModelLoader = new TestModelLoader(); - InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder( - "missing_model", - randomBoolean() ? null : LearnToRankConfigUpdateTests.randomLearnToRankConfigUpdate(), - () -> testModelLoader - ); + InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder("missing_model", () -> testModelLoader); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), randomIntBetween(0, 1_100_000), diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java index f85d24770f70e..54bb4a07d6085 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java @@ -19,9 +19,7 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdateTests; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import java.io.IOException; @@ -73,11 +71,7 @@ protected Writeable.Reader instanceReader() { @Override protected InferenceRescorerBuilder createTestInstance() { InferenceRescorerBuilder builder = randomBoolean() - ? new InferenceRescorerBuilder( - randomAlphaOfLength(10), - randomBoolean() ? null : LearnToRankConfigUpdateTests.randomLearnToRankConfigUpdate(), - null - ) + ? new InferenceRescorerBuilder(randomAlphaOfLength(10), null) : new InferenceRescorerBuilder( randomAlphaOfLength(10), LearnToRankConfigTests.randomLearnToRankConfig(), @@ -96,7 +90,6 @@ protected InferenceRescorerBuilder mutateInstance(InferenceRescorerBuilder insta case 0 -> { InferenceRescorerBuilder builder = new InferenceRescorerBuilder( randomValueOtherThan(instance.getModelId(), () -> randomAlphaOfLength(10)), - instance.getInferenceConfigUpdate(), null ); if (instance.windowSize() != null) { @@ -104,15 +97,11 @@ protected InferenceRescorerBuilder mutateInstance(InferenceRescorerBuilder insta } yield builder; } - case 1 -> new InferenceRescorerBuilder(instance.getModelId(), instance.getInferenceConfigUpdate(), null).windowSize( + case 1 -> new InferenceRescorerBuilder(instance.getModelId(), null).windowSize( randomValueOtherThan(instance.windowSize(), () -> randomIntBetween(1, 10000)) ); case 2 -> { - InferenceRescorerBuilder builder = new InferenceRescorerBuilder( - instance.getModelId(), - randomValueOtherThan(instance.getInferenceConfigUpdate(), LearnToRankConfigUpdateTests::randomLearnToRankConfigUpdate), - null - ); + InferenceRescorerBuilder builder = new InferenceRescorerBuilder(instance.getModelId(), null); if (instance.windowSize() != null) { builder.windowSize(instance.windowSize()); } @@ -138,16 +127,6 @@ protected InferenceRescorerBuilder mutateInstanceForVersion(InferenceRescorerBui return instance; } - public void testIncorrectInferenceConfigUpdateType() { - InferenceRescorerBuilder.Builder builder = new InferenceRescorerBuilder.Builder(); - expectThrows( - IllegalArgumentException.class, - () -> builder.setInferenceConfigUpdate(ClassificationConfigUpdateTests.randomClassificationConfigUpdate()) - ); - // Should not throw - builder.setInferenceConfigUpdate(LearnToRankConfigUpdateTests.randomLearnToRankConfigUpdate()); - } - public void testIncorrectInferenceConfigType() { InferenceRescorerBuilder.Builder builder = new InferenceRescorerBuilder.Builder(); expectThrows(