diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java index 6f78077e16288..b878aad27a42e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; @@ -58,7 +59,7 @@ public List getNamedWriteables() { List namedWriteables = new ArrayList<>(); // Inference config namedWriteables.add( - new NamedWriteableRegistry.Entry(LearnToRankConfig.class, LearnToRankConfig.NAME.getPreferredName(), LearnToRankConfig::new) + new NamedWriteableRegistry.Entry(InferenceConfig.class, LearnToRankConfig.NAME.getPreferredName(), LearnToRankConfig::new) ); // LTR Extractors namedWriteables.add( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilder.java index 4f50751abb972..dfd96dbb27bbc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilder.java @@ -89,7 +89,11 @@ public LearnToRankRescorerBuilder(StreamInput input, LearnToRankService learnToR super(input); this.modelId = input.readString(); this.params = input.readMap(); - this.learnToRankConfig = input.readOptionalNamedWriteable(LearnToRankConfig.class); + if (input.readBoolean()) { + this.learnToRankConfig = new LearnToRankConfig(input); + } else { + this.learnToRankConfig = null; + } this.learnToRankService = learnToRankService; this.localModel = null; @@ -254,7 +258,13 @@ protected void doWriteTo(StreamOutput out) throws IOException { assert localModel == null || rescoreOccurred : "Unnecessarily populated local model object"; out.writeString(modelId); out.writeGenericMap(params); - out.writeOptionalNamedWriteable(learnToRankConfig); + + if (learnToRankConfig != null) { + out.writeBoolean(true); + learnToRankConfig.writeTo(out); + } else { + out.writeBoolean(false); + } } @Override