From 85872e0389b49449b4e9e3f1323cc625bbfbdfc7 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 25 Oct 2021 13:32:21 -0400 Subject: [PATCH 1/3] [ML] track inference model deployment on nodes --- ...ortCreateTrainedModelAllocationAction.java | 7 +++++-- .../TrainedModelAllocationNodeService.java | 21 ++++++++++++++++--- .../TrainedModelDeploymentTask.java | 11 +++++++++- ...rainedModelAllocationNodeServiceTests.java | 4 +++- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java index 45861de90a87..794d8de8229b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java @@ -16,6 +16,7 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -40,7 +41,8 @@ public TransportCreateTrainedModelAllocationAction( ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver + IndexNameExpressionResolver indexNameExpressionResolver, + XPackLicenseState licenseState ) { super( CreateTrainedModelAllocationAction.NAME, @@ -62,7 +64,8 @@ public TransportCreateTrainedModelAllocationAction( clusterService, deploymentManager, transportService.getTaskManager(), - threadPool + threadPool, + licenseState ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java index a40aad7a7f07..b712fea2c26e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskAwareRequest; import org.elasticsearch.tasks.TaskId; @@ -65,6 +66,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener { private final Map modelIdToTask; private final ThreadPool threadPool; private final Deque loadingModels; + private final XPackLicenseState licenseState; private volatile Scheduler.Cancellable scheduledFuture; private volatile boolean stopped; private volatile String nodeId; @@ -74,7 +76,8 @@ public TrainedModelAllocationNodeService( ClusterService clusterService, DeploymentManager deploymentManager, TaskManager taskManager, - ThreadPool threadPool + ThreadPool threadPool, + XPackLicenseState licenseState ) { this.trainedModelAllocationService = trainedModelAllocationService; this.deploymentManager = deploymentManager; @@ -82,6 +85,7 @@ public TrainedModelAllocationNodeService( this.modelIdToTask = new ConcurrentHashMap<>(); this.loadingModels = new ConcurrentLinkedDeque<>(); this.threadPool = threadPool; + this.licenseState = licenseState; clusterService.addLifecycleListener(new LifecycleListener() { @Override public void afterStart() { @@ -102,7 +106,8 @@ public void beforeStop() { DeploymentManager deploymentManager, TaskManager taskManager, ThreadPool threadPool, - String nodeId + String nodeId, + XPackLicenseState licenseState ) { this.trainedModelAllocationService = trainedModelAllocationService; this.deploymentManager = deploymentManager; @@ -111,6 +116,7 @@ public void beforeStop() { this.loadingModels = new ConcurrentLinkedDeque<>(); this.threadPool = threadPool; this.nodeId = nodeId; + this.licenseState = licenseState; clusterService.addLifecycleListener(new LifecycleListener() { @Override public void afterStart() { @@ -265,7 +271,16 @@ public TaskId getParentTask() { @Override public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, params, trainedModelAllocationNodeService); + return new TrainedModelDeploymentTask( + id, + type, + action, + parentTaskId, + headers, + params, + trainedModelAllocationNodeService, + licenseState + ); } }; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index f7c1c674669a..cad678d46403 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -12,6 +12,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ml.MlTasks; @@ -26,6 +27,8 @@ import java.util.Map; import java.util.Optional; +import static org.elasticsearch.xpack.ml.MachineLearning.ML_MODEL_INFERENCE_FEATURE; + public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher { private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class); @@ -35,6 +38,7 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start private volatile boolean stopped; private final SetOnce stoppedReason = new SetOnce<>(); private final SetOnce inferenceConfig = new SetOnce<>(); + private final XPackLicenseState licenseState; public TrainedModelDeploymentTask( long id, @@ -43,7 +47,8 @@ public TrainedModelDeploymentTask( TaskId parentTask, Map headers, TaskParams taskParams, - TrainedModelAllocationNodeService trainedModelAllocationNodeService + TrainedModelAllocationNodeService trainedModelAllocationNodeService, + XPackLicenseState licenseState ) { super(id, type, action, MlTasks.trainedModelDeploymentTaskId(taskParams.getModelId()), parentTask, headers); this.params = taskParams; @@ -51,6 +56,8 @@ public TrainedModelDeploymentTask( trainedModelAllocationNodeService, "trainedModelAllocationNodeService" ); + this.licenseState = licenseState; + ML_MODEL_INFERENCE_FEATURE.startTracking(licenseState, "model-" + taskParams.getModelId()); } void init(InferenceConfig inferenceConfig) { @@ -71,12 +78,14 @@ public TaskParams getParams() { public void stop(String reason) { logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); + ML_MODEL_INFERENCE_FEATURE.stopTracking(licenseState, "model-" + params.getModelId()); stopped = true; stoppedReason.trySet(reason); trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason); } public void stopWithoutNotification(String reason) { + ML_MODEL_INFERENCE_FEATURE.stopTracking(licenseState, "model-" + params.getModelId()); logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); stoppedReason.trySet(reason); stopped = true; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java index 0b8bffdf302a..2a2d07581e2e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ScalingExecutorBuilder; @@ -507,7 +508,8 @@ private TrainedModelAllocationNodeService createService() { deploymentManager, taskManager, threadPool, - NODE_ID + NODE_ID, + mock(XPackLicenseState.class) ); } From 13ef48f14ff43a7741ae757899c9716b7dedeeaa Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 25 Oct 2021 15:14:15 -0400 Subject: [PATCH 2/3] [ML] track inference model feature usage per node --- .../license/XPackLicenseState.java | 4 +- .../TrainedModelDeploymentTask.java | 2 +- .../TrainedModelDeploymentTaskTests.java | 65 +++++++++++++++++++ 3 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java index 0af7e8dd1d35..c191fbd4ce5a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java @@ -406,13 +406,13 @@ void featureUsed(LicensedFeature feature) { usage.put(new FeatureUsage(feature, null), epochMillisProvider.getAsLong()); } - void enableUsageTracking(LicensedFeature feature, String contextName) { + public void enableUsageTracking(LicensedFeature feature, String contextName) { checkExpiry(); Objects.requireNonNull(contextName, "Context name cannot be null"); usage.put(new FeatureUsage(feature, contextName), -1L); } - void disableUsageTracking(LicensedFeature feature, String contextName) { + public void disableUsageTracking(LicensedFeature feature, String contextName) { Objects.requireNonNull(contextName, "Context name cannot be null"); usage.replace(new FeatureUsage(feature, contextName), -1L, epochMillisProvider.getAsLong()); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index cad678d46403..5dfad5c05dea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -57,11 +57,11 @@ public TrainedModelDeploymentTask( "trainedModelAllocationNodeService" ); this.licenseState = licenseState; - ML_MODEL_INFERENCE_FEATURE.startTracking(licenseState, "model-" + taskParams.getModelId()); } void init(InferenceConfig inferenceConfig) { this.inferenceConfig.set(inferenceConfig); + ML_MODEL_INFERENCE_FEATURE.startTracking(licenseState, "model-" + params.getModelId()); } public String getModelId() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java new file mode 100644 index 000000000000..e3303f2e8374 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.deployment; + +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService; + +import java.util.Map; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX; +import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE; +import static org.elasticsearch.xpack.ml.MachineLearning.ML_MODEL_INFERENCE_FEATURE; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class TrainedModelDeploymentTaskTests extends ESTestCase { + + void assertTrackingComplete(Consumer method, String modelId) { + XPackLicenseState licenseState = mock(XPackLicenseState.class); + TrainedModelDeploymentTask task = new TrainedModelDeploymentTask( + 0, + TRAINED_MODEL_ALLOCATION_TASK_TYPE, + TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + modelId, + TaskId.EMPTY_TASK_ID, + Map.of(), + new StartTrainedModelDeploymentAction.TaskParams( + modelId, + randomLongBetween(1, Long.MAX_VALUE), + randomInt(5), + randomInt(5), + randomInt(5) + ), + mock(TrainedModelAllocationNodeService.class), + licenseState); + + task.init(new PassThroughConfig(null, null, null)); + verify(licenseState, times(1)).enableUsageTracking(ML_MODEL_INFERENCE_FEATURE, "model-" + modelId); + method.accept(task); + verify(licenseState, times(1)).disableUsageTracking(ML_MODEL_INFERENCE_FEATURE, "model-" + modelId); + } + + public void testOnStopWithoutNotification() { + assertTrackingComplete(t -> t.stopWithoutNotification("foo"), randomAlphaOfLength(10)); + } + + public void testOnStop() { + assertTrackingComplete(t -> t.stop("foo"), randomAlphaOfLength(10)); + } + + public void testCancelled() { + assertTrackingComplete(TrainedModelDeploymentTask::onCancelled, randomAlphaOfLength(10)); + } + +} From 6b115db3911f1ceac21021fe9982ec40cd6eee4b Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 26 Oct 2021 07:46:11 -0400 Subject: [PATCH 3/3] addressing PR comments --- .../elasticsearch/license/XPackLicenseState.java | 4 ++-- .../org/elasticsearch/xpack/ml/MachineLearning.java | 5 +++++ .../TrainedModelAllocationNodeService.java | 4 +++- .../deployment/TrainedModelDeploymentTask.java | 13 ++++++++----- .../deployment/TrainedModelDeploymentTaskTests.java | 11 +++++++---- 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java index c191fbd4ce5a..0af7e8dd1d35 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java @@ -406,13 +406,13 @@ void featureUsed(LicensedFeature feature) { usage.put(new FeatureUsage(feature, null), epochMillisProvider.getAsLong()); } - public void enableUsageTracking(LicensedFeature feature, String contextName) { + void enableUsageTracking(LicensedFeature feature, String contextName) { checkExpiry(); Objects.requireNonNull(contextName, "Context name cannot be null"); usage.put(new FeatureUsage(feature, contextName), -1L); } - public void disableUsageTracking(LicensedFeature feature, String contextName) { + void disableUsageTracking(LicensedFeature feature, String contextName) { Objects.requireNonNull(contextName, "Context name cannot be null"); usage.replace(new FeatureUsage(feature, contextName), -1L, epochMillisProvider.getAsLong()); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index a26f65bfeac9..d311e5f3aea1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -468,6 +468,11 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, "model-inference", License.OperationMode.PLATINUM ); + public static final LicensedFeature.Persistent ML_PYTORCH_MODEL_INFERENCE_FEATURE = LicensedFeature.persistent( + MachineLearningField.ML_FEATURE_FAMILY, + "pytorch-model-inference", + License.OperationMode.PLATINUM + ); @Override public Map getProcessors(Processor.Parameters parameters) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java index b712fea2c26e..1c07ed6d1bea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java @@ -53,6 +53,7 @@ import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX; import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE; +import static org.elasticsearch.xpack.ml.MachineLearning.ML_PYTORCH_MODEL_INFERENCE_FEATURE; public class TrainedModelAllocationNodeService implements ClusterStateListener { @@ -279,7 +280,8 @@ public Task createTask(long id, String type, String action, TaskId parentTaskId, headers, params, trainedModelAllocationNodeService, - licenseState + licenseState, + ML_PYTORCH_MODEL_INFERENCE_FEATURE ); } }; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index 5dfad5c05dea..a65e85ab05da 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -12,6 +12,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskId; @@ -27,7 +28,6 @@ import java.util.Map; import java.util.Optional; -import static org.elasticsearch.xpack.ml.MachineLearning.ML_MODEL_INFERENCE_FEATURE; public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher { @@ -39,6 +39,7 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start private final SetOnce stoppedReason = new SetOnce<>(); private final SetOnce inferenceConfig = new SetOnce<>(); private final XPackLicenseState licenseState; + private final LicensedFeature.Persistent licensedFeature; public TrainedModelDeploymentTask( long id, @@ -48,7 +49,8 @@ public TrainedModelDeploymentTask( Map headers, TaskParams taskParams, TrainedModelAllocationNodeService trainedModelAllocationNodeService, - XPackLicenseState licenseState + XPackLicenseState licenseState, + LicensedFeature.Persistent licensedFeature ) { super(id, type, action, MlTasks.trainedModelDeploymentTaskId(taskParams.getModelId()), parentTask, headers); this.params = taskParams; @@ -57,11 +59,12 @@ public TrainedModelDeploymentTask( "trainedModelAllocationNodeService" ); this.licenseState = licenseState; + this.licensedFeature = licensedFeature; } void init(InferenceConfig inferenceConfig) { this.inferenceConfig.set(inferenceConfig); - ML_MODEL_INFERENCE_FEATURE.startTracking(licenseState, "model-" + params.getModelId()); + licensedFeature.startTracking(licenseState, "model-" + params.getModelId()); } public String getModelId() { @@ -78,14 +81,14 @@ public TaskParams getParams() { public void stop(String reason) { logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); - ML_MODEL_INFERENCE_FEATURE.stopTracking(licenseState, "model-" + params.getModelId()); + licensedFeature.stopTracking(licenseState, "model-" + params.getModelId()); stopped = true; stoppedReason.trySet(reason); trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason); } public void stopWithoutNotification(String reason) { - ML_MODEL_INFERENCE_FEATURE.stopTracking(licenseState, "model-" + params.getModelId()); + licensedFeature.stopTracking(licenseState, "model-" + params.getModelId()); logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); stoppedReason.trySet(reason); stopped = true; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java index e3303f2e8374..7aa4f2c9a88c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.deployment; +import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; @@ -19,7 +20,6 @@ import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX; import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE; -import static org.elasticsearch.xpack.ml.MachineLearning.ML_MODEL_INFERENCE_FEATURE; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -28,6 +28,7 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase { void assertTrackingComplete(Consumer method, String modelId) { XPackLicenseState licenseState = mock(XPackLicenseState.class); + LicensedFeature.Persistent feature = mock(LicensedFeature.Persistent.class); TrainedModelDeploymentTask task = new TrainedModelDeploymentTask( 0, TRAINED_MODEL_ALLOCATION_TASK_TYPE, @@ -42,12 +43,14 @@ void assertTrackingComplete(Consumer method, String randomInt(5) ), mock(TrainedModelAllocationNodeService.class), - licenseState); + licenseState, + feature + ); task.init(new PassThroughConfig(null, null, null)); - verify(licenseState, times(1)).enableUsageTracking(ML_MODEL_INFERENCE_FEATURE, "model-" + modelId); + verify(feature, times(1)).startTracking(licenseState, "model-" + modelId); method.accept(task); - verify(licenseState, times(1)).disableUsageTracking(ML_MODEL_INFERENCE_FEATURE, "model-" + modelId); + verify(feature, times(1)).stopTracking(licenseState, "model-" + modelId); } public void testOnStopWithoutNotification() {