Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] track inference model feature usage per node #79752

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,13 @@ void featureUsed(LicensedFeature feature) {
usage.put(new FeatureUsage(feature, null), epochMillisProvider.getAsLong());
}

void enableUsageTracking(LicensedFeature feature, String contextName) {
public void enableUsageTracking(LicensedFeature feature, String contextName) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rjernst I made these public for testing purposes: I need to check whether tracking is enabled/disabled for inference from a different package.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a shame to do this just so TrainedModelDeploymentTaskTests can verify that TrainedModelDeploymentTask calls start/stopTracking. Instead of statically importing MachineLearning.ML_MODEL_INFERENCE_FEATURE, you could pass the LicensedFeature as a constructor parameter and mock a LicensedFeature in the tests, which can then be verified.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. I am now passing in the value as @davidkyle suggests, and these changes have been removed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for future reference: there is a test class called MockLicenseState that makes these package-private methods public for testing purposes, so there shouldn't ever be a need to make them public on XPackLicenseState.

checkExpiry();
Objects.requireNonNull(contextName, "Context name cannot be null");
usage.put(new FeatureUsage(feature, contextName), -1L);
}

void disableUsageTracking(LicensedFeature feature, String contextName) {
/**
 * Ends usage tracking of {@code feature} within the given context.
 *
 * <p>The usage entry keyed by {@code (feature, contextName)} is replaced with the value
 * currently supplied by {@code epochMillisProvider}, but only if that entry still holds
 * {@code -1L}, the value stored by {@code enableUsageTracking} for the same key.
 *
 * @param feature     the licensed feature whose tracking should end
 * @param contextName the context the feature was tracked under; must not be {@code null}
 * @throws NullPointerException if {@code contextName} is {@code null}
 */
public void disableUsageTracking(LicensedFeature feature, String contextName) {
    Objects.requireNonNull(contextName, "Context name cannot be null");
    final FeatureUsage trackedUsage = new FeatureUsage(feature, contextName);
    final long stoppedAtMillis = epochMillisProvider.getAsLong();
    // Conditional replace: only an entry still marked with -1L is rewritten.
    usage.replace(trackedUsage, -1L, stoppedAtMillis);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
Expand All @@ -40,7 +41,8 @@ public TransportCreateTrainedModelAllocationAction(
ClusterService clusterService,
ThreadPool threadPool,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver
IndexNameExpressionResolver indexNameExpressionResolver,
XPackLicenseState licenseState
) {
super(
CreateTrainedModelAllocationAction.NAME,
Expand All @@ -62,7 +64,8 @@ public TransportCreateTrainedModelAllocationAction(
clusterService,
deploymentManager,
transportService.getTaskManager(),
threadPool
threadPool,
licenseState
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskId;
Expand Down Expand Up @@ -65,6 +66,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
private final Map<String, TrainedModelDeploymentTask> modelIdToTask;
private final ThreadPool threadPool;
private final Deque<TrainedModelDeploymentTask> loadingModels;
private final XPackLicenseState licenseState;
private volatile Scheduler.Cancellable scheduledFuture;
private volatile boolean stopped;
private volatile String nodeId;
Expand All @@ -74,14 +76,16 @@ public TrainedModelAllocationNodeService(
ClusterService clusterService,
DeploymentManager deploymentManager,
TaskManager taskManager,
ThreadPool threadPool
ThreadPool threadPool,
XPackLicenseState licenseState
) {
this.trainedModelAllocationService = trainedModelAllocationService;
this.deploymentManager = deploymentManager;
this.taskManager = taskManager;
this.modelIdToTask = new ConcurrentHashMap<>();
this.loadingModels = new ConcurrentLinkedDeque<>();
this.threadPool = threadPool;
this.licenseState = licenseState;
clusterService.addLifecycleListener(new LifecycleListener() {
@Override
public void afterStart() {
Expand All @@ -102,7 +106,8 @@ public void beforeStop() {
DeploymentManager deploymentManager,
TaskManager taskManager,
ThreadPool threadPool,
String nodeId
String nodeId,
XPackLicenseState licenseState
) {
this.trainedModelAllocationService = trainedModelAllocationService;
this.deploymentManager = deploymentManager;
Expand All @@ -111,6 +116,7 @@ public void beforeStop() {
this.loadingModels = new ConcurrentLinkedDeque<>();
this.threadPool = threadPool;
this.nodeId = nodeId;
this.licenseState = licenseState;
clusterService.addLifecycleListener(new LifecycleListener() {
@Override
public void afterStart() {
Expand Down Expand Up @@ -265,7 +271,16 @@ public TaskId getParentTask() {

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, params, trainedModelAllocationNodeService);
return new TrainedModelDeploymentTask(
id,
type,
action,
parentTaskId,
headers,
params,
trainedModelAllocationNodeService,
licenseState
);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.MlTasks;
Expand All @@ -26,6 +27,8 @@
import java.util.Map;
import java.util.Optional;

import static org.elasticsearch.xpack.ml.MachineLearning.ML_MODEL_INFERENCE_FEATURE;

public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher {

private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class);
Expand All @@ -35,6 +38,7 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
private volatile boolean stopped;
private final SetOnce<String> stoppedReason = new SetOnce<>();
private final SetOnce<InferenceConfig> inferenceConfig = new SetOnce<>();
private final XPackLicenseState licenseState;

public TrainedModelDeploymentTask(
long id,
Expand All @@ -43,18 +47,21 @@ public TrainedModelDeploymentTask(
TaskId parentTask,
Map<String, String> headers,
TaskParams taskParams,
TrainedModelAllocationNodeService trainedModelAllocationNodeService
TrainedModelAllocationNodeService trainedModelAllocationNodeService,
XPackLicenseState licenseState
) {
super(id, type, action, MlTasks.trainedModelDeploymentTaskId(taskParams.getModelId()), parentTask, headers);
this.params = taskParams;
this.trainedModelAllocationNodeService = ExceptionsHelper.requireNonNull(
trainedModelAllocationNodeService,
"trainedModelAllocationNodeService"
);
this.licenseState = licenseState;
}

/**
 * Stores the inference configuration on this task and starts usage tracking of the
 * model inference feature under the context {@code "model-" + modelId}.
 *
 * @param inferenceConfig the inference configuration to record for this deployment
 */
void init(InferenceConfig inferenceConfig) {
    this.inferenceConfig.set(inferenceConfig);
    // The tracking context is derived from the model id carried by the task params.
    final String trackingContext = "model-" + params.getModelId();
    ML_MODEL_INFERENCE_FEATURE.startTracking(licenseState, trackingContext);
}

public String getModelId() {
Expand All @@ -71,12 +78,14 @@ public TaskParams getParams() {

/**
 * Stops this deployment task and asks the allocation node service to stop the
 * deployment and notify.
 *
 * <p>In order: logs the reason, stops usage tracking of the model inference feature for
 * the {@code "model-" + modelId} context, marks the task as stopped, records the reason
 * via {@code trySet}, and finally delegates to
 * {@code trainedModelAllocationNodeService.stopDeploymentAndNotify}.
 *
 * @param reason the reason the task is being stopped
 */
public void stop(String reason) {
    logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
    final String trackingContext = "model-" + params.getModelId();
    ML_MODEL_INFERENCE_FEATURE.stopTracking(licenseState, trackingContext);
    stopped = true;
    stoppedReason.trySet(reason);
    trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason);
}

public void stopWithoutNotification(String reason) {
ML_MODEL_INFERENCE_FEATURE.stopTracking(licenseState, "model-" + params.getModelId());
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
stoppedReason.trySet(reason);
stopped = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
Expand Down Expand Up @@ -507,7 +508,8 @@ private TrainedModelAllocationNodeService createService() {
deploymentManager,
taskManager,
threadPool,
NODE_ID
NODE_ID,
mock(XPackLicenseState.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.deployment;

import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;

import java.util.Map;
import java.util.function.Consumer;

import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX;
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
import static org.elasticsearch.xpack.ml.MachineLearning.ML_MODEL_INFERENCE_FEATURE;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

/**
 * Verifies that {@link TrainedModelDeploymentTask} enables usage tracking of the model
 * inference feature when initialised and disables it again when the task is stopped or
 * cancelled.
 */
public class TrainedModelDeploymentTaskTests extends ESTestCase {

    public void testOnStopWithoutNotification() {
        final String modelId = randomAlphaOfLength(10);
        assertTrackingComplete(task -> task.stopWithoutNotification("foo"), modelId);
    }

    public void testOnStop() {
        final String modelId = randomAlphaOfLength(10);
        assertTrackingComplete(task -> task.stop("foo"), modelId);
    }

    public void testCancelled() {
        final String modelId = randomAlphaOfLength(10);
        assertTrackingComplete(task -> task.onCancelled(), modelId);
    }

    /**
     * Builds a task for {@code modelId} against a mocked license state, initialises it,
     * and checks that tracking is enabled exactly once; then applies {@code method} and
     * checks that tracking is disabled exactly once for the same context.
     *
     * @param method  the task operation expected to end usage tracking
     * @param modelId the model id used to build the task and the tracking context name
     */
    void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String modelId) {
        final XPackLicenseState mockedLicense = mock(XPackLicenseState.class);
        final String expectedContext = "model-" + modelId;

        // Random values are drawn in the same left-to-right order as the constructor arguments.
        final StartTrainedModelDeploymentAction.TaskParams taskParams = new StartTrainedModelDeploymentAction.TaskParams(
            modelId,
            randomLongBetween(1, Long.MAX_VALUE),
            randomInt(5),
            randomInt(5),
            randomInt(5)
        );
        final TrainedModelDeploymentTask deploymentTask = new TrainedModelDeploymentTask(
            0,
            TRAINED_MODEL_ALLOCATION_TASK_TYPE,
            TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + modelId,
            TaskId.EMPTY_TASK_ID,
            Map.of(),
            taskParams,
            mock(TrainedModelAllocationNodeService.class),
            mockedLicense
        );

        deploymentTask.init(new PassThroughConfig(null, null, null));
        verify(mockedLicense, times(1)).enableUsageTracking(ML_MODEL_INFERENCE_FEATURE, expectedContext);

        method.accept(deploymentTask);
        verify(mockedLicense, times(1)).disableUsageTracking(ML_MODEL_INFERENCE_FEATURE, expectedContext);
    }

}