From ff72b7748de0bdd72a48d97fc4dbed63c7fc6611 Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Wed, 10 Apr 2024 09:14:03 +0200 Subject: [PATCH] Implement model cache flush from any node. --- .../action/FlushTrainedModelCacheAction.java | 35 +++++++++++ .../xpack/ml/MachineLearning.java | 8 ++- ...TransportFlushTrainedModelCacheAction.java | 62 +++++++++++++++++++ .../loadingservice/ModelLoadingService.java | 2 +- .../TrainedModelCacheMetadataService.java | 30 +++++---- 5 files changed, 124 insertions(+), 13 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheAction.java new file mode 100644 index 0000000000000..0f0866a8834e2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/FlushTrainedModelCacheAction.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; + +import java.io.IOException; + +public class FlushTrainedModelCacheAction extends ActionType { + + public static final FlushTrainedModelCacheAction INSTANCE = new FlushTrainedModelCacheAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/clear_model_cache"; + + private FlushTrainedModelCacheAction() { + super(NAME); + } + + public static class Request extends AcknowledgedRequest { + public Request() { + super(); + } + + public Request(StreamInput in) throws IOException { + super(in); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index b1dc2c829260e..567671beb35ad 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -121,6 +121,7 @@ import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; import org.elasticsearch.xpack.core.ml.action.FlushJobAction; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; import org.elasticsearch.xpack.core.ml.action.ForecastJobAction; import org.elasticsearch.xpack.core.ml.action.GetBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetCalendarEventsAction; @@ -228,6 +229,7 @@ import org.elasticsearch.xpack.ml.action.TransportExternalInferModelAction; import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction; import org.elasticsearch.xpack.ml.action.TransportFlushJobAction; +import 
org.elasticsearch.xpack.ml.action.TransportFlushTrainedModelCacheAction; import org.elasticsearch.xpack.ml.action.TransportForecastJobAction; import org.elasticsearch.xpack.ml.action.TransportGetBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetCalendarEventsAction; @@ -1134,7 +1136,10 @@ public Collection createComponents(PluginServices services) { clusterService, threadPool ); - final TrainedModelCacheMetadataService trainedModelCacheMetadataService = new TrainedModelCacheMetadataService(clusterService); + final TrainedModelCacheMetadataService trainedModelCacheMetadataService = new TrainedModelCacheMetadataService( + clusterService, + client + ); final TrainedModelProvider trainedModelProvider = new TrainedModelProvider( client, trainedModelCacheMetadataService, @@ -1569,6 +1574,7 @@ public List getRestHandlers( actionHandlers.add( new ActionHandler<>(PutTrainedModelDefinitionPartAction.INSTANCE, TransportPutTrainedModelDefinitionPartAction.class) ); + actionHandlers.add(new ActionHandler<>(FlushTrainedModelCacheAction.INSTANCE, TransportFlushTrainedModelCacheAction.class)); actionHandlers.add(new ActionHandler<>(InferModelAction.INSTANCE, TransportInternalInferModelAction.class)); actionHandlers.add(new ActionHandler<>(InferModelAction.EXTERNAL_INSTANCE, TransportExternalInferModelAction.class)); actionHandlers.add(new ActionHandler<>(GetDeploymentStatsAction.INSTANCE, TransportGetDeploymentStatsAction.class)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java new file mode 100644 index 0000000000000..a7b2bbfadd2b5 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; + +public class TransportFlushTrainedModelCacheAction extends AcknowledgedTransportMasterNodeAction { + + @Inject + public TransportFlushTrainedModelCacheAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + FlushTrainedModelCacheAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + FlushTrainedModelCacheAction.Request::new, + indexNameExpressionResolver, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + } + + @Override + protected void masterOperation( + Task task, + FlushTrainedModelCacheAction.Request request, + ClusterState state, + ActionListener listener + ) { + // TODO + } + + @Override + protected ClusterBlockException checkBlock(FlushTrainedModelCacheAction.Request request, 
ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 196dd7dff485c..a4ac320066f81 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -753,7 +753,7 @@ private void cacheEvictionListener(RemovalNotification public void clusterChanged(ClusterChangedEvent event) { if (event.changedCustomMetadataSet().contains(TrainedModelCacheMetadata.NAME)) { // Flush all models cache since we are detecting some changes. - logger.debug("Trained model cache invalidated on node [{}]", event.state().nodes().getLocalNodeId()); + logger.trace("Trained model cache invalidated on node [{}]", event.state().nodes().getLocalNodeId()); localModelCache.invalidateAll(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java index 69857eef3b8ac..844d58f420c4f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java @@ -9,6 +9,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterChangedEvent; import 
org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; @@ -23,30 +25,34 @@ import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + public class TrainedModelCacheMetadataService implements ClusterStateListener { private static final Logger LOGGER = LogManager.getLogger(TrainedModelCacheMetadataService.class); - private final MasterServiceTaskQueue modelCacheMetadataUpdateTaskQueue; + private static final String TASK_QUEUE_NAME = "trained-models-cache-metadata-management"; + private final MasterServiceTaskQueue metadataUpdateTaskQueue; + private final Client client; private volatile boolean isMasterNode = false; - public TrainedModelCacheMetadataService(ClusterService clusterService) { - this.modelCacheMetadataUpdateTaskQueue = clusterService.createTaskQueue( - "trained-models-cache-metadata-management", - Priority.IMMEDIATE, - new TrainedModelCacheMetadataTaskExecutor() - ); + public TrainedModelCacheMetadataService(ClusterService clusterService, Client client) { + this.client = new OriginSettingClient(client, ML_ORIGIN); + ; + TrainedModelCacheMetadataUpdateTaskExecutor metadataUpdateTaskExecutor = new TrainedModelCacheMetadataUpdateTaskExecutor(); + this.metadataUpdateTaskQueue = clusterService.createTaskQueue(TASK_QUEUE_NAME, Priority.IMMEDIATE, metadataUpdateTaskExecutor); clusterService.addListener(this); } public void refreshCacheVersion(ActionListener listener) { if (this.isMasterNode == false) { - // TODO: Use an internal action to update the cache version - listener.onResponse(AcknowledgedResponse.FALSE); + client.execute(FlushTrainedModelCacheAction.INSTANCE, new FlushTrainedModelCacheAction.Request(), listener); return; 
} + TrainedModelCacheMetadataUpdateTask updateMetadataTask = new RefreshTrainedModeCacheMetadataVersionTask(listener); - this.modelCacheMetadataUpdateTaskQueue.submitTask(updateMetadataTask.getDescription(), updateMetadataTask, null); + this.metadataUpdateTaskQueue.submitTask(updateMetadataTask.getDescription(), updateMetadataTask, null); } @Override @@ -99,7 +105,9 @@ protected String getDescription() { } } - private static class TrainedModelCacheMetadataTaskExecutor implements ClusterStateTaskExecutor { + private static class TrainedModelCacheMetadataUpdateTaskExecutor + implements + ClusterStateTaskExecutor { @Override public ClusterState execute(BatchExecutionContext batchExecutionContext) { final var initialState = batchExecutionContext.initialState();