diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index ec6729094052a..ee0053ecf0d2f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -73,7 +73,6 @@ import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; -import org.elasticsearch.script.ScriptService; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; @@ -358,6 +357,7 @@ import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor; import org.elasticsearch.xpack.ml.ltr.InferenceRescorerFeature; import org.elasticsearch.xpack.ml.ltr.LearnToRankRescorerBuilder; +import org.elasticsearch.xpack.ml.ltr.LearnToRankService; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -748,7 +748,7 @@ public void loadExtensions(ExtensionLoader loader) { private final SetOnce mlAutoscalingDeciderService = new SetOnce<>(); private final SetOnce deploymentManager = new SetOnce<>(); private final SetOnce trainedModelAllocationClusterServiceSetOnce = new SetOnce<>(); - private final SetOnce scriptService = new SetOnce<>(); + private final SetOnce learnToRankService = new SetOnce<>(); private final SetOnce machineLearningExtension = new SetOnce<>(); public MachineLearning(Settings settings) { @@ -872,8 +872,8 @@ public List> getRescorers() { return List.of( new RescorerSpec<>( LearnToRankRescorerBuilder.NAME, - in -> new LearnToRankRescorerBuilder(in, modelLoadingService.get(), scriptService.get()), - parser -> LearnToRankRescorerBuilder.fromXContent(parser, modelLoadingService.get(), scriptService.get()) + in -> new LearnToRankRescorerBuilder(in, learnToRankService.get()), + parser -> LearnToRankRescorerBuilder.fromXContent(parser, learnToRankService.get()) ) ); } @@ -897,7 +897,6 @@ public Collection createComponents(PluginServices services) { machineLearningExtension.get().configure(environment.settings()); - this.scriptService.set(services.scriptService()); this.mlUpgradeModeActionFilter.set(new MlUpgradeModeActionFilter(clusterService)); MlIndexTemplateRegistry registry = new MlIndexTemplateRegistry( @@ -1106,6 +1105,8 @@ public Collection createComponents(PluginServices services) { new DeploymentManager(client, xContentRegistry, threadPool, pyTorchProcessFactory, getMaxModelDeploymentsPerNode()) ); + this.learnToRankService.set(new LearnToRankService(modelLoadingService, services.scriptService(), services.xContentRegistry())); + // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager( settings, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilder.java index 30882d8434eab..f0469420b6af1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilder.java @@ -10,40 +10,23 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.script.Script; -import org.elasticsearch.script.ScriptService; -import org.elasticsearch.script.ScriptType; -import org.elasticsearch.script.TemplateScript; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; -import org.elasticsearch.xpack.core.ml.job.messages.Messages; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; -import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import java.io.IOException; -import java.util.Collections; import java.util.Map; import java.util.Objects; -import java.util.Optional; - -import static org.elasticsearch.script.Script.DEFAULT_TEMPLATE_LANG; public class LearnToRankRescorerBuilder extends RescorerBuilder { @@ -57,73 +40,60 @@ public class LearnToRankRescorerBuilder extends RescorerBuilder p.map(), PARAMS_FIELD); } - public static LearnToRankRescorerBuilder fromXContent( - XContentParser parser, - ModelLoadingService modelLoadingService, - ScriptService scriptService - ) { - return PARSER.apply(parser, null).build(modelLoadingService, scriptService); + public static LearnToRankRescorerBuilder fromXContent(XContentParser parser, LearnToRankService learnToRankService) { + return PARSER.apply(parser, null).build(learnToRankService); } private final String modelId; private final Map params; - private final ScriptService scriptService; - private final ModelLoadingService modelLoadingService; + private final LearnToRankService learnToRankService; private final LocalModel localModel; private final LearnToRankConfig learnToRankConfig; + private boolean rescoreOccurred = false; + private LearnToRankRescorerBuilder() { + this(null, null, null); + } + + LearnToRankRescorerBuilder(String modelId, Map params, LearnToRankService learnToRankService) { + this(modelId, null, params, learnToRankService); + } + LearnToRankRescorerBuilder( String modelId, + LearnToRankConfig learnToRankConfig, Map params, - ModelLoadingService modelLoadingService, - ScriptService scriptService + LearnToRankService learnToRankService ) { this.modelId = modelId; this.params = params; - this.scriptService = scriptService; - this.modelLoadingService = modelLoadingService; - - // Config and model will be set during successive rewrite phases. - this.learnToRankConfig = null; - this.localModel = null; - } - - LearnToRankRescorerBuilder(String modelId, ModelLoadingService modelLoadingService, LearnToRankConfig learnToRankConfig) { - this.modelId = modelId; - this.modelLoadingService = modelLoadingService; this.learnToRankConfig = learnToRankConfig; + this.learnToRankService = learnToRankService; // Local inference model is not loaded yet. Will be done in a later rewrite. this.localModel = null; - - // Templates has been applied already, so we do not need params and script service anymore. - this.params = null; - this.scriptService = null; } - LearnToRankRescorerBuilder(LearnToRankConfig learnToRankConfig, LocalModel localModel) { + LearnToRankRescorerBuilder( + LocalModel localModel, + LearnToRankConfig learnToRankConfig, + Map params, + LearnToRankService learnToRankService + ) { this.modelId = localModel.getModelId(); + this.params = params; this.learnToRankConfig = learnToRankConfig; this.localModel = localModel; - - // Model is loaded already, so we do not need the model loading service anymore. - this.modelLoadingService = null; - - // Template has been applied already, so we do not need params and script service anymore. - this.params = null; - this.scriptService = null; + this.learnToRankService = learnToRankService; } - public LearnToRankRescorerBuilder(StreamInput input, ModelLoadingService modelLoadingService, ScriptService scriptService) - throws IOException { + public LearnToRankRescorerBuilder(StreamInput input, LearnToRankService learnToRankService) throws IOException { super(input); this.modelId = input.readString(); this.params = input.readMap(); this.learnToRankConfig = input.readOptionalNamedWriteable(LearnToRankConfig.class); - - this.modelLoadingService = modelLoadingService; - this.scriptService = scriptService; + this.learnToRankService = learnToRankService; this.localModel = null; } @@ -140,8 +110,8 @@ public LearnToRankConfig learnToRankConfig() { return learnToRankConfig; } - public ModelLoadingService modelLoadingService() { - return modelLoadingService; + public LearnToRankService learnToRankService() { + return learnToRankService; } public LocalModel localModel() { @@ -175,45 +145,28 @@ private RescorerBuilder doCoordinatorNodeRewrite(Que if (rewrittenConfig == learnToRankConfig) { return this; } - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, rewrittenConfig); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, rewrittenConfig, params, learnToRankService); if (windowSize != null) { builder.windowSize(windowSize); } return builder; } + if (learnToRankService == null) { + throw new IllegalStateException("Learn to rank service must be available"); + } + SetOnce configSetOnce = new SetOnce<>(); GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId); request.setAllowNoResources(false); ctx.registerAsyncAction( - (c, l) -> ClientHelper.executeAsyncWithOrigin( - c, - ClientHelper.ML_ORIGIN, - GetTrainedModelsAction.INSTANCE, - request, - ActionListener.wrap(trainedModels -> { - TrainedModelConfig config = trainedModels.getResources().results().get(0); - if (config.getInferenceConfig() instanceof LearnToRankConfig retrievedInferenceConfig) { - for (LearnToRankFeatureExtractorBuilder builder : retrievedInferenceConfig.getFeatureExtractorBuilders()) { - builder.validate(); - } - configSetOnce.set(applyParams(retrievedInferenceConfig, ctx)); - l.onResponse(null); - return; - } - l.onFailure( - ExceptionsHelper.badRequestException( - Messages.getMessage( - Messages.INFERENCE_CONFIG_INCORRECT_TYPE, - Optional.ofNullable(config.getInferenceConfig()).map(InferenceConfig::getName).orElse("null"), - LearnToRankConfig.NAME.getPreferredName() - ) - ) - ); - }, l::onFailure) - ) + (c, l) -> learnToRankService.loadLearnToRankConfig(c, modelId, params, ActionListener.wrap(learnToRankConfig -> { + configSetOnce.set(learnToRankConfig); + l.onResponse(null); + }, l::onFailure)) ); - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, null) { + + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder() { @Override public RescorerBuilder rewrite(QueryRewriteContext ctx) throws IOException { if (configSetOnce.get() == null) { @@ -221,7 +174,12 @@ public RescorerBuilder rewrite(QueryRewriteContext c return this; } - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, configSetOnce.get()); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( + modelId, + configSetOnce.get(), + params, + learnToRankService + ); if (windowSize() != null) { builder.windowSize(windowSize()); @@ -237,29 +195,6 @@ public RescorerBuilder rewrite(QueryRewriteContext c return builder; } - private LearnToRankConfig applyParams(LearnToRankConfig config, QueryRewriteContext ctx) throws IOException { - if (scriptService.isLangSupported(DEFAULT_TEMPLATE_LANG) == false) { - return config; - } - - if (params == null || params.isEmpty()) { - return config; - } - - try (XContentBuilder configSourceBuilder = XContentBuilder.builder(XContentType.JSON.xContent())) { - String templateSource = BytesReference.bytes(config.toXContent(configSourceBuilder, EMPTY_PARAMS)).utf8ToString(); - if (templateSource.contains("{{") == false) { - return config; - } - Script script = new Script(ScriptType.INLINE, DEFAULT_TEMPLATE_LANG, templateSource, Collections.emptyMap()); - String parsedTemplate = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params).execute(); - - XContentParser parser = XContentType.JSON.xContent().createParser(ctx.getParserConfig(), parsedTemplate); - - return LearnToRankConfig.fromXContentStrict(parser); - } - } - /** * This rewrite phase occurs on the data node when we know we will want to use the model for inference * @param ctx Rewrite context @@ -273,17 +208,18 @@ private RescorerBuilder doDataNodeRewrite(QueryRewri return this; } - if (modelLoadingService == null) { - throw new IllegalStateException("Model loading service must be available"); + if (learnToRankService == null) { + throw new IllegalStateException("Learn to rank service must be available"); } + LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfig, ctx); SetOnce localModelSetOnce = new SetOnce<>(); - ctx.registerAsyncAction((c, l) -> modelLoadingService.getModelForLearnToRank(modelId, ActionListener.wrap(lm -> { + ctx.registerAsyncAction((c, l) -> learnToRankService.loadLocalModel(modelId, ActionListener.wrap(lm -> { localModelSetOnce.set(lm); l.onResponse(null); }, l::onFailure))); - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, learnToRankConfig) { + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder() { @Override public RescorerBuilder rewrite(QueryRewriteContext ctx) throws IOException { if (localModelSetOnce.get() == null) { @@ -291,7 +227,12 @@ public RescorerBuilder rewrite(QueryRewriteContext c return this; } - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(learnToRankConfig, localModelSetOnce.get()); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( + localModelSetOnce.get(), + rewrittenConfig, + params, + learnToRankService + ); if (windowSize() != null) { builder.windowSize(windowSize()); @@ -321,7 +262,7 @@ private RescorerBuilder doSearchRewrite(QueryRewrite if (rewrittenConfig == learnToRankConfig) { return this; } - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(rewrittenConfig, localModel); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(localModel, rewrittenConfig, params, learnToRankService); if (windowSize != null) { builder.windowSize(windowSize); } @@ -374,28 +315,18 @@ public boolean equals(Object o) { && Objects.equals(params, that.params) && Objects.equals(learnToRankConfig, that.learnToRankConfig) && Objects.equals(localModel, that.localModel) - && Objects.equals(modelLoadingService, that.modelLoadingService) - && Objects.equals(scriptService, that.scriptService) + && Objects.equals(learnToRankService, that.learnToRankService) && rescoreOccurred == that.rescoreOccurred; } @Override public int hashCode() { - return Objects.hash( - super.hashCode(), - modelId, - params, - learnToRankConfig, - localModel, - modelLoadingService, - scriptService, - rescoreOccurred - ); + return Objects.hash(super.hashCode(), modelId, params, learnToRankConfig, localModel, learnToRankService, rescoreOccurred); } static class Builder { private String modelId; - private Map params = Collections.emptyMap(); + private Map params = null; public void setModelId(String modelId) { this.modelId = modelId; @@ -405,8 +336,8 @@ public void setParams(Map params) { this.params = params; } - LearnToRankRescorerBuilder build(ModelLoadingService modelLoadingService, ScriptService scriptService) { - return new LearnToRankRescorerBuilder(modelId, params, modelLoadingService, scriptService); + LearnToRankRescorerBuilder build(LearnToRankService learnToRankService) { + return new LearnToRankRescorerBuilder(modelId, params, learnToRankService); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankService.java new file mode 100644 index 0000000000000..95342f113a6fa --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/ltr/LearnToRankService.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.ltr; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptService; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.script.TemplateScript; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; + +import static org.elasticsearch.script.Script.DEFAULT_TEMPLATE_LANG; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; + +public class LearnToRankService { + private final ModelLoadingService modelLoadingService; + private final ScriptService scriptService; + private final XContentParserConfiguration parserConfiguration; + + public LearnToRankService( + ModelLoadingService modelLoadingService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry + ) { + this(modelLoadingService, scriptService, XContentParserConfiguration.EMPTY.withRegistry(xContentRegistry)); + } + + LearnToRankService( + ModelLoadingService modelLoadingService, + ScriptService scriptService, + XContentParserConfiguration parserConfiguration + ) { + this.modelLoadingService = modelLoadingService; + this.scriptService = scriptService; + this.parserConfiguration = parserConfiguration; + } + + public void loadLocalModel(String modelId, ActionListener listener) { + modelLoadingService.getModelForLearnToRank(modelId, listener); + } + + public void loadLearnToRankConfig( + Client client, + String modelId, + Map params, + ActionListener listener + ) { + GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId); + ClientHelper.executeAsyncWithOrigin( + client, + ClientHelper.ML_ORIGIN, + GetTrainedModelsAction.INSTANCE, + request, + ActionListener.wrap(trainedModels -> { + TrainedModelConfig config = trainedModels.getResources().results().get(0); + if (config.getInferenceConfig() instanceof LearnToRankConfig retrievedInferenceConfig) { + for (LearnToRankFeatureExtractorBuilder builder : retrievedInferenceConfig.getFeatureExtractorBuilders()) { + builder.validate(); + } + listener.onResponse(applyParams(retrievedInferenceConfig, params)); + return; + } + listener.onFailure( + ExceptionsHelper.badRequestException( + Messages.getMessage( + Messages.INFERENCE_CONFIG_INCORRECT_TYPE, + Optional.ofNullable(config.getInferenceConfig()).map(InferenceConfig::getName).orElse("null"), + LearnToRankConfig.NAME.getPreferredName() + ) + ) + ); + listener.onResponse(null); + }, listener::onFailure) + ); + } + + private LearnToRankConfig applyParams(LearnToRankConfig config, Map params) throws IOException { + if (scriptService.isLangSupported(DEFAULT_TEMPLATE_LANG) == false) { + return config; + } + + if (params == null || params.isEmpty()) { + return config; + } + + try (XContentBuilder configSourceBuilder = XContentBuilder.builder(XContentType.JSON.xContent())) { + String templateSource = BytesReference.bytes(config.toXContent(configSourceBuilder, EMPTY_PARAMS)).utf8ToString(); + if (templateSource.contains("{{") == false) { + return config; + } + Script script = new Script(ScriptType.INLINE, DEFAULT_TEMPLATE_LANG, templateSource, Collections.emptyMap()); + String parsedTemplate = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params).execute(); + + XContentParser parser = XContentType.JSON.xContent().createParser(parserConfiguration, parsedTemplate); + + return LearnToRankConfig.fromXContentStrict(parser); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilderRewriteTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilderRewriteTests.java index f137d719f5ce6..1fae713300092 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilderRewriteTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilderRewriteTests.java @@ -14,47 +14,29 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.breaker.CircuitBreaker; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.DateFieldMapper; import org.elasticsearch.index.query.CoordinatorRewriteContext; import org.elasticsearch.index.query.DataRewriteContext; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.test.AbstractBuilderTestCase; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.QueryProvider; -import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; -import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; -import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; -import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.io.IOException; -import java.lang.reflect.Method; -import java.util.Collections; import java.util.List; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests.randomLearnToRankConfig; +import static org.elasticsearch.xpack.ml.ltr.LearnToRankServiceTests.BAD_MODEL; +import static org.elasticsearch.xpack.ml.ltr.LearnToRankServiceTests.GOOD_MODEL; +import static org.elasticsearch.xpack.ml.ltr.LearnToRankServiceTests.GOOD_MODEL_CONFIG; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -63,41 +45,22 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class LearnToRankRescorerBuilderRewriteTests extends AbstractBuilderTestCase { - private static final String GOOD_MODEL = "modelId"; - private static final String BAD_MODEL = "badModel"; - private static final TrainedModelConfig GOOD_MODEL_CONFIG = TrainedModelConfig.builder() - .setModelId(GOOD_MODEL) - .setInput(new TrainedModelInput(List.of("field1", "field2"))) - .setEstimatedOperations(1) - .setModelSize(2) - .setModelType(TrainedModelType.TREE_ENSEMBLE) - .setInferenceConfig( - new LearnToRankConfig( - 2, - List.of( - new QueryExtractorBuilder("feature_1", new QueryProvider(Collections.emptyMap(), null, null)), - new QueryExtractorBuilder("feature_2", new QueryProvider(Collections.emptyMap(), null, null)) - ) - ) - ) - .build(); - private static final TrainedModelConfig BAD_MODEL_CONFIG = TrainedModelConfig.builder() - .setModelId(BAD_MODEL) - .setInput(new TrainedModelInput(List.of("field1", "field2"))) - .setEstimatedOperations(1) - .setModelSize(2) - .setModelType(TrainedModelType.TREE_ENSEMBLE) - .setInferenceConfig(new RegressionConfig(null, null)) - .build(); - public void testMustRewrite() { - TestModelLoader testModelLoader = new TestModelLoader(); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(GOOD_MODEL, testModelLoader, randomLearnToRankConfig()); + LearnToRankService learnToRankService = learnToRankServiceMock(); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( + GOOD_MODEL, + randomLearnToRankConfig(), + null, + learnToRankService + ); SearchExecutionContext context = createSearchExecutionContext(); LearnToRankRescorerContext rescorerContext = rescorerBuilder.innerBuildContext(randomIntBetween(1, 30), context); @@ -114,9 +77,8 @@ public void testMustRewrite() { } public void testRewriteOnCoordinator() throws IOException { - TestModelLoader testModelLoader = new TestModelLoader(); - ScriptService scriptService = mock(ScriptService.class); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(GOOD_MODEL, null, testModelLoader, scriptService); + LearnToRankService learnToRankService = learnToRankServiceMock(); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(GOOD_MODEL, null, learnToRankService); rescorerBuilder.windowSize(4); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), @@ -142,9 +104,8 @@ public void testRewriteOnCoordinator() throws IOException { } public void testRewriteOnCoordinatorWithBadModel() throws IOException { - TestModelLoader testModelLoader = new TestModelLoader(); - ScriptService scriptService = mock(ScriptService.class); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(BAD_MODEL, null, testModelLoader, scriptService); + LearnToRankService learnToRankService = learnToRankServiceMock(); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(BAD_MODEL, null, learnToRankService); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), randomIntBetween(0, 1_100_000), @@ -155,9 +116,8 @@ public void testRewriteOnCoordinatorWithBadModel() throws IOException { } public void testRewriteOnCoordinatorWithMissingModel() { - TestModelLoader testModelLoader = new TestModelLoader(); - ScriptService scriptService = mock(ScriptService.class); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder("missing_model", null, testModelLoader, scriptService); + LearnToRankService learnToRankService = learnToRankServiceMock(); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder("missing_model", null, learnToRankService); CoordinatorRewriteContext context = createCoordinatorRewriteContext( new DateFieldMapper.DateFieldType("@timestamp"), randomIntBetween(0, 1_100_000), @@ -166,55 +126,13 @@ public void testRewriteOnCoordinatorWithMissingModel() { expectThrows(ResourceNotFoundException.class, () -> rewriteAndFetch(rescorerBuilder, context)); } - public void testSearchRewrite() throws IOException { - LocalModel localModel = localModel(); - when(localModel.getModelId()).thenReturn(GOOD_MODEL); - - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(randomLearnToRankConfig(), localModel); - - QueryRewriteContext context = createSearchExecutionContext(); - LearnToRankRescorerBuilder rewritten = (LearnToRankRescorerBuilder) Rewriteable.rewrite(rescorerBuilder, context, true); - - LearnToRankConfig rewrittenLearnToRankConfig = Rewriteable.rewrite(rewritten.learnToRankConfig(), context); - assertThat(rewritten.localModel(), is(localModel)); - assertThat(rewritten.learnToRankConfig(), is(rewrittenLearnToRankConfig)); - } - - protected LearnToRankRescorerBuilder rewriteAndFetch(RescorerBuilder builder, QueryRewriteContext context) { - PlainActionFuture> future = new PlainActionFuture<>(); - Rewriteable.rewriteAndFetch(builder, context, future); - return (LearnToRankRescorerBuilder) future.actionGet(); - } - - @Override - protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchMethodException { - return method.equals(Client.class.getMethod("execute", ActionType.class, ActionRequest.class, ActionListener.class)) - && (args[0] instanceof GetTrainedModelsAction); - } - - @Override - protected Object simulateMethod(Method method, Object[] args) { - GetTrainedModelsAction.Request request = (GetTrainedModelsAction.Request) args[1]; - @SuppressWarnings("unchecked") // We matched the method above. - ActionListener listener = (ActionListener) args[2]; - if (request.getResourceId().equals(GOOD_MODEL)) { - listener.onResponse(GetTrainedModelsAction.Response.builder().setModels(List.of(GOOD_MODEL_CONFIG)).build()); - return null; - } - if (request.getResourceId().equals(BAD_MODEL)) { - listener.onResponse(GetTrainedModelsAction.Response.builder().setModels(List.of(BAD_MODEL_CONFIG)).build()); - return null; - } - listener.onFailure(ExceptionsHelper.missingTrainedModel(request.getResourceId())); - return null; - } - public void testRewriteOnShard() throws IOException { - TestModelLoader testModelLoader = new TestModelLoader(); + LearnToRankService learnToRankService = learnToRankServiceMock(); LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( GOOD_MODEL, - testModelLoader, - (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig() + (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(), + null, + learnToRankService ); SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); LearnToRankRescorerBuilder rewritten = (LearnToRankRescorerBuilder) rescorerBuilder.rewrite(createSearchExecutionContext()); @@ -223,9 +141,13 @@ public void testRewriteOnShard() throws IOException { } public void testRewriteAndFetchOnDataNode() throws IOException { - TestModelLoader testModelLoader = new TestModelLoader(); - - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(GOOD_MODEL, testModelLoader, randomLearnToRankConfig()); + LearnToRankService learnToRankService = learnToRankServiceMock(); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( + GOOD_MODEL, + randomLearnToRankConfig(), + null, + learnToRankService + ); boolean setWindowSize = randomBoolean(); if (setWindowSize) { @@ -240,13 +162,44 @@ public void testRewriteAndFetchOnDataNode() throws IOException { } } + @SuppressWarnings("unchecked") + private static LearnToRankService learnToRankServiceMock() { + LearnToRankService learnToRankService = mock(LearnToRankService.class); + + doAnswer(invocation -> { + String modelId = invocation.getArgument(1); + ActionListener l = invocation.getArgument(3, ActionListener.class); + if (modelId.equals(GOOD_MODEL)) { + l.onResponse(GOOD_MODEL_CONFIG.getInferenceConfig()); + } else if (modelId.equals(BAD_MODEL)) { + l.onFailure(new ElasticsearchStatusException("bad model", RestStatus.BAD_REQUEST)); + } else { + l.onFailure(new ResourceNotFoundException("missing model")); + } + return null; + }).when(learnToRankService).loadLearnToRankConfig(isA(Client.class), anyString(), any(), any()); + + doAnswer(invocation -> { + ActionListener l = invocation.getArgument(1, ActionListener.class); + l.onResponse(mock(LocalModel.class)); + return null; + }).when(learnToRankService).loadLocalModel(anyString(), any()); + + return learnToRankService; + } + public void testBuildContext() throws Exception { - LocalModel localModel = localModel(); + LocalModel localModel = mock(LocalModel.class); List inputFields = List.of(DOUBLE_FIELD_NAME, INT_FIELD_NAME); when(localModel.inputFields()).thenReturn(inputFields); SearchExecutionContext context = createSearchExecutionContext(); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(randomLearnToRankConfig(), localModel); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( + localModel, + randomLearnToRankConfig(), + null, + mock(LearnToRankService.class) + ); LearnToRankRescorerContext rescoreContext = rescorerBuilder.innerBuildContext(20, context); assertNotNull(rescoreContext); @@ -259,34 +212,9 @@ public void testBuildContext() throws Exception { ); } - private static LocalModel localModel() { - return mock(LocalModel.class); - } - - private static QueryExtractorBuilder queryExtractorBuilder(String featureName) throws IOException { - QueryProvider queryProvider = mock(QueryProvider.class); - when(queryProvider.rewrite(any())).thenReturn(queryProvider); - return new QueryExtractorBuilder(featureName, queryProvider); - } - - private static class TestModelLoader extends ModelLoadingService { - TestModelLoader() { - super( - mock(TrainedModelProvider.class), - mock(InferenceAuditor.class), - mock(ThreadPool.class), - mock(ClusterService.class), - mock(TrainedModelStatsService.class), - Settings.EMPTY, - "test", - mock(CircuitBreaker.class), - new XPackLicenseState(System::currentTimeMillis) - ); - } - - @Override - public void getModelForLearnToRank(String modelId, ActionListener modelActionListener) { - modelActionListener.onResponse(localModel()); - } + private LearnToRankRescorerBuilder rewriteAndFetch(RescorerBuilder builder, QueryRewriteContext context) { + PlainActionFuture> future = new PlainActionFuture<>(); + Rewriteable.rewriteAndFetch(builder, context, future); + return (LearnToRankRescorerBuilder) future.actionGet(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilderSerializationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilderSerializationTests.java index 046d075607cb5..77c8c9f4132b8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilderSerializationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankRescorerBuilderSerializationTests.java @@ -27,13 +27,17 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import static org.elasticsearch.search.rank.RankBuilder.WINDOW_SIZE_FIELD; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests.randomLearnToRankConfig; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class LearnToRankRescorerBuilderSerializationTests extends AbstractBWCSerializationTestCase { + private static LearnToRankService learnToRankService = mock(LearnToRankService.class); + @Override protected LearnToRankRescorerBuilder doParseInstance(XContentParser parser) throws IOException { String fieldName = null; @@ -51,7 +55,7 @@ protected LearnToRankRescorerBuilder doParseInstance(XContentParser parser) thro throw new ParsingException(parser.getTokenLocation(), "rescore doesn't support [" + fieldName + "]"); } } else if (token == XContentParser.Token.START_OBJECT) { - rescorer = LearnToRankRescorerBuilder.fromXContent(parser, null, null); + rescorer = LearnToRankRescorerBuilder.fromXContent(parser, learnToRankService); } else { throw new ParsingException(parser.getTokenLocation(), "unexpected token [" + token + "] after [" + fieldName + "]"); } @@ -67,14 +71,19 @@ protected LearnToRankRescorerBuilder doParseInstance(XContentParser parser) thro @Override protected Writeable.Reader instanceReader() { - return in -> new LearnToRankRescorerBuilder(in, null, null); + return in -> new LearnToRankRescorerBuilder(in, learnToRankService); } @Override protected LearnToRankRescorerBuilder createTestInstance() { LearnToRankRescorerBuilder builder = randomBoolean() ? createXContextTestInstance(null) - : new LearnToRankRescorerBuilder(randomAlphaOfLength(10), null, randomLearnToRankConfig()); + : new LearnToRankRescorerBuilder( + randomAlphaOfLength(10), + randomLearnToRankConfig(), + randomBoolean() ? randomParams() : null, + learnToRankService + ); if (randomBoolean()) { builder.windowSize(randomIntBetween(1, 10000)); @@ -85,12 +94,7 @@ protected LearnToRankRescorerBuilder createTestInstance() { @Override protected LearnToRankRescorerBuilder createXContextTestInstance(XContentType xContentType) { - return new LearnToRankRescorerBuilder( - randomAlphaOfLength(10), - randomMap(1, randomIntBetween(1, 10), () -> new Tuple<>(randomIdentifier(), randomIdentifier())), - null, - null - ); + return new LearnToRankRescorerBuilder(randomAlphaOfLength(10), randomBoolean() ? randomParams() : null, learnToRankService); } @Override @@ -102,23 +106,21 @@ protected LearnToRankRescorerBuilder mutateInstance(LearnToRankRescorerBuilder i LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( randomValueOtherThan(instance.modelId(), () -> randomAlphaOfLength(10)), instance.params(), - null, - null + learnToRankService ); if (instance.windowSize() != null) { builder.windowSize(instance.windowSize()); } yield builder; } - case 1 -> new LearnToRankRescorerBuilder(instance.modelId(), instance.params(), null, null).windowSize( + case 1 -> new LearnToRankRescorerBuilder(instance.modelId(), instance.params(), learnToRankService).windowSize( randomValueOtherThan(instance.windowSize(), () -> randomIntBetween(1, 10000)) ); case 2 -> { LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( instance.modelId(), - randomMap(1, randomIntBetween(1, 10), () -> new Tuple<>(randomIdentifier(), randomIdentifier())), - null, - null + randomValueOtherThan(instance.params(), () -> (randomBoolean() ? randomParams() : null)), + learnToRankService ); if (instance.windowSize() != null) { builder.windowSize(instance.windowSize() + 1); @@ -127,16 +129,24 @@ protected LearnToRankRescorerBuilder mutateInstance(LearnToRankRescorerBuilder i } case 3 -> { LearnToRankConfig learnToRankConfig = randomValueOtherThan(instance.learnToRankConfig(), () -> randomLearnToRankConfig()); - - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(instance.modelId(), null, learnToRankConfig); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( + instance.modelId(), + learnToRankConfig, + null, + learnToRankService + ); if (instance.windowSize() != null) { builder.windowSize(instance.windowSize()); } yield builder; } case 4 -> { - LocalModel localModel = mock(LocalModel.class); - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(instance.learnToRankConfig(), localModel); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder( + mock(LocalModel.class), + instance.learnToRankConfig(), + instance.params(), + learnToRankService + ); if (instance.windowSize() != null) { builder.windowSize(instance.windowSize()); } @@ -172,4 +182,15 @@ protected NamedWriteableRegistry writableRegistry() { protected NamedWriteableRegistry getNamedWriteableRegistry() { return writableRegistry(); } + + private static Map randomParams() { + return randomMap(1, randomIntBetween(1, 10), () -> new Tuple<>(randomIdentifier(), randomIdentifier())); + } + + private static LocalModel localModelMock() { + LocalModel model = mock(LocalModel.class); + String modelId = randomIdentifier(); + when(model.getModelId()).thenReturn(modelId); + return model; + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankServiceTests.java new file mode 100644 index 0000000000000..f9a7acf9673be --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/ltr/LearnToRankServiceTests.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.ltr; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.script.ScriptService; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.junit.After; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class LearnToRankServiceTests extends ESTestCase { + public static final String GOOD_MODEL = "modelId"; + public static final String BAD_MODEL = "badModel"; + public static final TrainedModelConfig GOOD_MODEL_CONFIG = TrainedModelConfig.builder() + .setModelId(GOOD_MODEL) + .setInput(new TrainedModelInput(List.of("field1", "field2"))) + .setEstimatedOperations(1) + .setModelSize(2) + .setModelType(TrainedModelType.TREE_ENSEMBLE) + .setInferenceConfig( + new LearnToRankConfig( + 2, + List.of( + new QueryExtractorBuilder("feature_1", new QueryProvider(Collections.emptyMap(), null, null)), + new QueryExtractorBuilder("feature_2", new QueryProvider(Collections.emptyMap(), null, null)) + ) + ) + ) + .build(); + public static final TrainedModelConfig BAD_MODEL_CONFIG = TrainedModelConfig.builder() + .setModelId(BAD_MODEL) + .setInput(new TrainedModelInput(List.of("field1", "field2"))) + .setEstimatedOperations(1) + .setModelSize(2) + .setModelType(TrainedModelType.TREE_ENSEMBLE) + .setInferenceConfig(new RegressionConfig(null, null)) + .build(); + + private ThreadPool threadPool; + private Client client; + + @Before + public void createRegistryAndClient() { + threadPool = new TestThreadPool(this.getClass().getName()); + client = mockClient(threadPool); + } + + @After + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdownNow(); + } + + @SuppressWarnings("unchecked") + public void testLoadLearnToRankConfig() throws Exception { + LearnToRankService learnToRankService = new LearnToRankService(mockModelLoadingService(), mockScriptService(), xContentRegistry()); + ActionListener listener = mock(ActionListener.class); + learnToRankService.loadLearnToRankConfig(client, GOOD_MODEL, Collections.emptyMap(), listener); + assertBusy(() -> { verify(listener).onResponse((LearnToRankConfig) eq(GOOD_MODEL_CONFIG.getInferenceConfig())); }); + } + + public void testLoadMissingLearnToRankConfig() { + // TODO + } + + public void testLoadBadLearnToRankConfig() { + // TODO + } + + public void testLoadLearnToRankConfigWithTemplate() { + // TODO + } + + private ModelLoadingService mockModelLoadingService() { + return mock(ModelLoadingService.class); + } + + private ScriptService mockScriptService() { + return mock(ScriptService.class); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + private Client mockClient(ThreadPool threadPool) { + return new NoOpClient(threadPool) { + @Override + @SuppressWarnings("unchecked") + protected void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action instanceof GetTrainedModelsAction) { + // Ignore this, it's verified in another test + GetTrainedModelsAction.Request getModelsRequest = (GetTrainedModelsAction.Request) request; + if (getModelsRequest.getResourceId().equals(GOOD_MODEL)) { + listener.onResponse( + (Response) GetTrainedModelsAction.Response.builder().setModels(List.of(GOOD_MODEL_CONFIG)).build() + ); + } else if (getModelsRequest.getResourceId().equals(BAD_MODEL)) { + listener.onResponse( + (Response) GetTrainedModelsAction.Response.builder().setModels(List.of(BAD_MODEL_CONFIG)).build() + ); + } else { + listener.onFailure(ExceptionsHelper.missingTrainedModel(getModelsRequest.getResourceId())); + } + } else { + fail("client called with unexpected request:" + request.toString()); + } + } + }; + } +}