Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] track inference model feature usage per node #79752

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
"model-inference",
License.OperationMode.PLATINUM
);
/**
 * Persistent licensed feature for PyTorch model inference. It is registered under
 * {@code MachineLearningField.ML_FEATURE_FAMILY} with the feature name
 * {@code "pytorch-model-inference"} and {@code License.OperationMode.PLATINUM}.
 */
public static final LicensedFeature.Persistent ML_PYTORCH_MODEL_INFERENCE_FEATURE = LicensedFeature.persistent(
MachineLearningField.ML_FEATURE_FAMILY,
"pytorch-model-inference",
License.OperationMode.PLATINUM
);

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
Expand All @@ -40,7 +41,8 @@ public TransportCreateTrainedModelAllocationAction(
ClusterService clusterService,
ThreadPool threadPool,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver
IndexNameExpressionResolver indexNameExpressionResolver,
XPackLicenseState licenseState
) {
super(
CreateTrainedModelAllocationAction.NAME,
Expand All @@ -62,7 +64,8 @@ public TransportCreateTrainedModelAllocationAction(
clusterService,
deploymentManager,
transportService.getTaskManager(),
threadPool
threadPool,
licenseState
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskId;
Expand Down Expand Up @@ -52,6 +53,7 @@

import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX;
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
import static org.elasticsearch.xpack.ml.MachineLearning.ML_PYTORCH_MODEL_INFERENCE_FEATURE;

public class TrainedModelAllocationNodeService implements ClusterStateListener {

Expand All @@ -65,6 +67,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
private final Map<String, TrainedModelDeploymentTask> modelIdToTask;
private final ThreadPool threadPool;
private final Deque<TrainedModelDeploymentTask> loadingModels;
private final XPackLicenseState licenseState;
private volatile Scheduler.Cancellable scheduledFuture;
private volatile boolean stopped;
private volatile String nodeId;
Expand All @@ -74,14 +77,16 @@ public TrainedModelAllocationNodeService(
ClusterService clusterService,
DeploymentManager deploymentManager,
TaskManager taskManager,
ThreadPool threadPool
ThreadPool threadPool,
XPackLicenseState licenseState
) {
this.trainedModelAllocationService = trainedModelAllocationService;
this.deploymentManager = deploymentManager;
this.taskManager = taskManager;
this.modelIdToTask = new ConcurrentHashMap<>();
this.loadingModels = new ConcurrentLinkedDeque<>();
this.threadPool = threadPool;
this.licenseState = licenseState;
clusterService.addLifecycleListener(new LifecycleListener() {
@Override
public void afterStart() {
Expand All @@ -102,7 +107,8 @@ public void beforeStop() {
DeploymentManager deploymentManager,
TaskManager taskManager,
ThreadPool threadPool,
String nodeId
String nodeId,
XPackLicenseState licenseState
) {
this.trainedModelAllocationService = trainedModelAllocationService;
this.deploymentManager = deploymentManager;
Expand All @@ -111,6 +117,7 @@ public void beforeStop() {
this.loadingModels = new ConcurrentLinkedDeque<>();
this.threadPool = threadPool;
this.nodeId = nodeId;
this.licenseState = licenseState;
clusterService.addLifecycleListener(new LifecycleListener() {
@Override
public void afterStart() {
Expand Down Expand Up @@ -265,7 +272,17 @@ public TaskId getParentTask() {

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, params, trainedModelAllocationNodeService);
return new TrainedModelDeploymentTask(
id,
type,
action,
parentTaskId,
headers,
params,
trainedModelAllocationNodeService,
licenseState,
ML_PYTORCH_MODEL_INFERENCE_FEATURE
);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.MlTasks;
Expand All @@ -26,6 +28,7 @@
import java.util.Map;
import java.util.Optional;


public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher {

private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class);
Expand All @@ -35,6 +38,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
private volatile boolean stopped;
private final SetOnce<String> stoppedReason = new SetOnce<>();
private final SetOnce<InferenceConfig> inferenceConfig = new SetOnce<>();
private final XPackLicenseState licenseState;
private final LicensedFeature.Persistent licensedFeature;

public TrainedModelDeploymentTask(
long id,
Expand All @@ -43,18 +48,23 @@ public TrainedModelDeploymentTask(
TaskId parentTask,
Map<String, String> headers,
TaskParams taskParams,
TrainedModelAllocationNodeService trainedModelAllocationNodeService
TrainedModelAllocationNodeService trainedModelAllocationNodeService,
XPackLicenseState licenseState,
LicensedFeature.Persistent licensedFeature
) {
super(id, type, action, MlTasks.trainedModelDeploymentTaskId(taskParams.getModelId()), parentTask, headers);
this.params = taskParams;
this.trainedModelAllocationNodeService = ExceptionsHelper.requireNonNull(
trainedModelAllocationNodeService,
"trainedModelAllocationNodeService"
);
this.licenseState = licenseState;
this.licensedFeature = licensedFeature;
}

/**
 * Stores the inference configuration for this deployment and then starts
 * license-usage tracking for the deployed model.
 *
 * @param inferenceConfig the configuration to record in the set-once holder
 */
void init(InferenceConfig inferenceConfig) {
    this.inferenceConfig.set(inferenceConfig);
    // The tracking context identifies the model; the stop paths use the same "model-<id>" form.
    final String trackingContext = "model-" + params.getModelId();
    licensedFeature.startTracking(licenseState, trackingContext);
}

public String getModelId() {
Expand All @@ -71,12 +81,14 @@ public TaskParams getParams() {

/**
 * Stops this deployment task: ends license-usage tracking for the model, records the
 * stop reason, marks the task as stopped and asks the allocation node service to stop
 * the deployment and notify.
 *
 * @param reason human-readable explanation of why the task is being stopped; only the
 *               first reason offered to the set-once holder is kept
 */
public void stop(String reason) {
    logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
    licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
    // Record the reason BEFORE raising the volatile flag, matching stopWithoutNotification:
    // a concurrent reader that observes stopped == true must also be able to see the reason.
    stoppedReason.trySet(reason);
    stopped = true;
    trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason);
}

public void stopWithoutNotification(String reason) {
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
stoppedReason.trySet(reason);
stopped = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
Expand Down Expand Up @@ -507,7 +508,8 @@ private TrainedModelAllocationNodeService createService() {
deploymentManager,
taskManager,
threadPool,
NODE_ID
NODE_ID,
mock(XPackLicenseState.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.deployment;

import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;

import java.util.Map;
import java.util.function.Consumer;

import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX;
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

/**
 * Checks that {@link TrainedModelDeploymentTask} starts license feature tracking when it is
 * initialised and stops it again on each path that ends the task.
 */
public class TrainedModelDeploymentTaskTests extends ESTestCase {

    /**
     * Builds a task around mocked license collaborators, initialises it, applies
     * {@code method} to it and verifies that tracking was started and stopped exactly once
     * with the {@code "model-<modelId>"} context.
     *
     * @param method  the task operation expected to end tracking
     * @param modelId the model id used for the task and the expected tracking context
     */
    void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String modelId) {
        final XPackLicenseState mockedLicenseState = mock(XPackLicenseState.class);
        final LicensedFeature.Persistent mockedFeature = mock(LicensedFeature.Persistent.class);
        final String expectedContext = "model-" + modelId;

        final StartTrainedModelDeploymentAction.TaskParams taskParams = new StartTrainedModelDeploymentAction.TaskParams(
            modelId,
            randomLongBetween(1, Long.MAX_VALUE),
            randomInt(5),
            randomInt(5),
            randomInt(5)
        );
        final TrainedModelDeploymentTask deploymentTask = new TrainedModelDeploymentTask(
            0,
            TRAINED_MODEL_ALLOCATION_TASK_TYPE,
            TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + modelId,
            TaskId.EMPTY_TASK_ID,
            Map.of(),
            taskParams,
            mock(TrainedModelAllocationNodeService.class),
            mockedLicenseState,
            mockedFeature
        );

        // Initialising the task must start tracking for this model.
        deploymentTask.init(new PassThroughConfig(null, null, null));
        verify(mockedFeature, times(1)).startTracking(mockedLicenseState, expectedContext);

        // The supplied operation must stop tracking for the same model.
        method.accept(deploymentTask);
        verify(mockedFeature, times(1)).stopTracking(mockedLicenseState, expectedContext);
    }

    public void testOnStopWithoutNotification() {
        final String modelId = randomAlphaOfLength(10);
        assertTrackingComplete(task -> task.stopWithoutNotification("foo"), modelId);
    }

    public void testOnStop() {
        final String modelId = randomAlphaOfLength(10);
        assertTrackingComplete(task -> task.stop("foo"), modelId);
    }

    public void testCancelled() {
        final String modelId = randomAlphaOfLength(10);
        assertTrackingComplete(TrainedModelDeploymentTask::onCancelled, modelId);
    }

}