From b34279c66deb983cf083d64085c6e2c88a133b61 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 26 Oct 2021 08:32:56 -0400 Subject: [PATCH] [ML] track inference model feature usage per node (#79752) This adds feature usage tracking for deployed inference models. The models are tracked under the existing, inference feature and contain context related to the model ID. I decided to track the feature via the allocation task to keep the logic similar between allocation tasks and licensed persistent tasks. closes: https://github.com/elastic/elasticsearch/issues/76452 --- .../xpack/ml/MachineLearning.java | 5 ++ ...ortCreateTrainedModelAllocationAction.java | 7 +- .../TrainedModelAllocationNodeService.java | 23 ++++++- .../TrainedModelDeploymentTask.java | 14 +++- ...rainedModelAllocationNodeServiceTests.java | 4 +- .../TrainedModelDeploymentTaskTests.java | 68 +++++++++++++++++++ 6 files changed, 114 insertions(+), 7 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index a26f65bfeac94..d311e5f3aea13 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -468,6 +468,11 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, "model-inference", License.OperationMode.PLATINUM ); + public static final LicensedFeature.Persistent ML_PYTORCH_MODEL_INFERENCE_FEATURE = LicensedFeature.persistent( + MachineLearningField.ML_FEATURE_FAMILY, + "pytorch-model-inference", + License.OperationMode.PLATINUM + ); @Override public Map getProcessors(Processor.Parameters parameters) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java index 45861de90a87f..794d8de8229b0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAllocationAction.java @@ -16,6 +16,7 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -40,7 +41,8 @@ public TransportCreateTrainedModelAllocationAction( ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver + IndexNameExpressionResolver indexNameExpressionResolver, + XPackLicenseState licenseState ) { super( CreateTrainedModelAllocationAction.NAME, @@ -62,7 +64,8 @@ public TransportCreateTrainedModelAllocationAction( clusterService, deploymentManager, transportService.getTaskManager(), - threadPool + threadPool, + licenseState ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java index a40aad7a7f075..1c07ed6d1bea3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskAwareRequest; import org.elasticsearch.tasks.TaskId; @@ -52,6 +53,7 @@ import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX; import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE; +import static org.elasticsearch.xpack.ml.MachineLearning.ML_PYTORCH_MODEL_INFERENCE_FEATURE; public class TrainedModelAllocationNodeService implements ClusterStateListener { @@ -65,6 +67,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener { private final Map modelIdToTask; private final ThreadPool threadPool; private final Deque loadingModels; + private final XPackLicenseState licenseState; private volatile Scheduler.Cancellable scheduledFuture; private volatile boolean stopped; private volatile String nodeId; @@ -74,7 +77,8 @@ public TrainedModelAllocationNodeService( ClusterService clusterService, DeploymentManager deploymentManager, TaskManager taskManager, - ThreadPool threadPool + ThreadPool threadPool, + XPackLicenseState licenseState ) { this.trainedModelAllocationService = trainedModelAllocationService; this.deploymentManager = deploymentManager; @@ -82,6 +86,7 @@ public TrainedModelAllocationNodeService( this.modelIdToTask = new ConcurrentHashMap<>(); this.loadingModels = new ConcurrentLinkedDeque<>(); this.threadPool = threadPool; + this.licenseState = licenseState; clusterService.addLifecycleListener(new LifecycleListener() { @Override public void afterStart() { @@ -102,7 +107,8 @@ public void beforeStop() { DeploymentManager deploymentManager, TaskManager taskManager, ThreadPool threadPool, - String nodeId + String nodeId, + XPackLicenseState licenseState ) { this.trainedModelAllocationService = trainedModelAllocationService; this.deploymentManager = deploymentManager; @@ -111,6 +117,7 @@ public void beforeStop() { this.loadingModels = new ConcurrentLinkedDeque<>(); this.threadPool = threadPool; this.nodeId = nodeId; + this.licenseState = licenseState; clusterService.addLifecycleListener(new LifecycleListener() { @Override public void afterStart() { @@ -265,7 +272,17 @@ public TaskId getParentTask() { @Override public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, params, trainedModelAllocationNodeService); + return new TrainedModelDeploymentTask( + id, + type, + action, + parentTaskId, + headers, + params, + trainedModelAllocationNodeService, + licenseState, + ML_PYTORCH_MODEL_INFERENCE_FEATURE + ); } }; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index f7c1c674669a2..a65e85ab05da0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -12,6 +12,8 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.license.LicensedFeature; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ml.MlTasks; @@ -26,6 +28,7 @@ import java.util.Map; import java.util.Optional; + public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher { private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class); @@ -35,6 +38,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start private volatile boolean stopped; private final SetOnce stoppedReason = new SetOnce<>(); private final SetOnce inferenceConfig = new SetOnce<>(); + private final XPackLicenseState licenseState; + private final LicensedFeature.Persistent licensedFeature; public TrainedModelDeploymentTask( long id, @@ -43,7 +48,9 @@ public TrainedModelDeploymentTask( TaskId parentTask, Map headers, TaskParams taskParams, - TrainedModelAllocationNodeService trainedModelAllocationNodeService + TrainedModelAllocationNodeService trainedModelAllocationNodeService, + XPackLicenseState licenseState, + LicensedFeature.Persistent licensedFeature ) { super(id, type, action, MlTasks.trainedModelDeploymentTaskId(taskParams.getModelId()), parentTask, headers); this.params = taskParams; @@ -51,10 +58,13 @@ public TrainedModelDeploymentTask( trainedModelAllocationNodeService, "trainedModelAllocationNodeService" ); + this.licenseState = licenseState; + this.licensedFeature = licensedFeature; } void init(InferenceConfig inferenceConfig) { this.inferenceConfig.set(inferenceConfig); + licensedFeature.startTracking(licenseState, "model-" + params.getModelId()); } public String getModelId() { @@ -71,12 +81,14 @@ public TaskParams getParams() { public void stop(String reason) { logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); + licensedFeature.stopTracking(licenseState, "model-" + params.getModelId()); stopped = true; stoppedReason.trySet(reason); trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason); } public void stopWithoutNotification(String reason) { + licensedFeature.stopTracking(licenseState, "model-" + params.getModelId()); logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); stoppedReason.trySet(reason); stopped = true; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java index 0b8bffdf302aa..2a2d07581e2e6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ScalingExecutorBuilder; @@ -507,7 +508,8 @@ private TrainedModelAllocationNodeService createService() { deploymentManager, taskManager, threadPool, - NODE_ID + NODE_ID, + mock(XPackLicenseState.class) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java new file mode 100644 index 0000000000000..7aa4f2c9a88ca --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.deployment; + +import org.elasticsearch.license.LicensedFeature; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; +import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService; + +import java.util.Map; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX; +import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class TrainedModelDeploymentTaskTests extends ESTestCase { + + void assertTrackingComplete(Consumer method, String modelId) { + XPackLicenseState licenseState = mock(XPackLicenseState.class); + LicensedFeature.Persistent feature = mock(LicensedFeature.Persistent.class); + TrainedModelDeploymentTask task = new TrainedModelDeploymentTask( + 0, + TRAINED_MODEL_ALLOCATION_TASK_TYPE, + TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + modelId, + TaskId.EMPTY_TASK_ID, + Map.of(), + new StartTrainedModelDeploymentAction.TaskParams( + modelId, + randomLongBetween(1, Long.MAX_VALUE), + randomInt(5), + randomInt(5), + randomInt(5) + ), + mock(TrainedModelAllocationNodeService.class), + licenseState, + feature + ); + + task.init(new PassThroughConfig(null, null, null)); + verify(feature, times(1)).startTracking(licenseState, "model-" + modelId); + method.accept(task); + verify(feature, times(1)).stopTracking(licenseState, "model-" + modelId); + } + + public void testOnStopWithoutNotification() { + assertTrackingComplete(t -> t.stopWithoutNotification("foo"), randomAlphaOfLength(10)); + } + + public void testOnStop() { + assertTrackingComplete(t -> t.stop("foo"), randomAlphaOfLength(10)); + } + + public void testCancelled() { + assertTrackingComplete(TrainedModelDeploymentTask::onCancelled, randomAlphaOfLength(10)); + } + +}