diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java index a7b2bbfadd2b5..174823e4a738e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheAction.java @@ -22,16 +22,20 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; public class TransportFlushTrainedModelCacheAction extends AcknowledgedTransportMasterNodeAction { + private final TrainedModelCacheMetadataService modelCacheMetadataService; + @Inject public TransportFlushTrainedModelCacheAction( TransportService transportService, ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver + IndexNameExpressionResolver indexNameExpressionResolver, + TrainedModelCacheMetadataService modelCacheMetadataService ) { super( FlushTrainedModelCacheAction.NAME, @@ -43,6 +47,7 @@ public TransportFlushTrainedModelCacheAction( indexNameExpressionResolver, EsExecutors.DIRECT_EXECUTOR_SERVICE ); + this.modelCacheMetadataService = modelCacheMetadataService; } @Override @@ -52,7 +57,7 @@ protected void masterOperation( ClusterState state, ActionListener listener ) { - // TODO + modelCacheMetadataService.refreshCacheVersion(listener); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java index 3a836684c0c83..80205b30c90d3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelCacheMetadataService.java @@ -33,13 +33,13 @@ public class TrainedModelCacheMetadataService implements ClusterStateListener { private static final Logger LOGGER = LogManager.getLogger(TrainedModelCacheMetadataService.class); private static final String TASK_QUEUE_NAME = "trained-models-cache-metadata-management"; - private final MasterServiceTaskQueue metadataUpdateTaskQueue; + private final MasterServiceTaskQueue metadataUpdateTaskQueue; private final Client client; private volatile boolean isMasterNode = false; public TrainedModelCacheMetadataService(ClusterService clusterService, Client client) { this.client = new OriginSettingClient(client, ML_ORIGIN); - TrainedModelCacheMetadataUpdateTaskExecutor metadataUpdateTaskExecutor = new TrainedModelCacheMetadataUpdateTaskExecutor(); + CacheMetadataUpdateTaskExecutor metadataUpdateTaskExecutor = new CacheMetadataUpdateTaskExecutor(); this.metadataUpdateTaskQueue = clusterService.createTaskQueue(TASK_QUEUE_NAME, Priority.IMMEDIATE, metadataUpdateTaskExecutor); clusterService.addListener(this); } @@ -50,7 +50,7 @@ public void refreshCacheVersion(ActionListener listener) { return; } - TrainedModelCacheMetadataUpdateTask updateMetadataTask = new RefreshTrainedModeCacheMetadataVersionTask(listener); + CacheMetadataUpdateTask updateMetadataTask = new RefreshCacheMetadataVersionTask(listener); this.metadataUpdateTaskQueue.submitTask(updateMetadataTask.getDescription(), updateMetadataTask, null); } @@ -62,16 +62,16 @@ public void clusterChanged(ClusterChangedEvent event) { this.isMasterNode = event.localNodeMaster(); } - private abstract static class TrainedModelCacheMetadataUpdateTask implements ClusterStateTaskListener { + private abstract static class CacheMetadataUpdateTask implements ClusterStateTaskListener { protected final ActionListener listener; - TrainedModelCacheMetadataUpdateTask(ActionListener listener) { + CacheMetadataUpdateTask(ActionListener listener) { this.listener = listener; } protected abstract TrainedModelCacheMetadata execute( TrainedModelCacheMetadata currentCacheMetadata, - TaskContext taskContext + TaskContext taskContext ); protected abstract String getDescription(); @@ -83,15 +83,15 @@ public void onFailure(@Nullable Exception e) { } } - private static class RefreshTrainedModeCacheMetadataVersionTask extends TrainedModelCacheMetadataUpdateTask { - RefreshTrainedModeCacheMetadataVersionTask(ActionListener listener) { + private static class RefreshCacheMetadataVersionTask extends CacheMetadataUpdateTask { + RefreshCacheMetadataVersionTask(ActionListener listener) { super(listener); } @Override protected TrainedModelCacheMetadata execute( TrainedModelCacheMetadata currentCacheMetadata, - TaskContext taskContext + TaskContext taskContext ) { long newVersion = currentCacheMetadata.version() < Long.MAX_VALUE ? currentCacheMetadata.version() + 1 : 1L; taskContext.success(() -> listener.onResponse(AcknowledgedResponse.TRUE)); @@ -104,11 +104,9 @@ protected String getDescription() { } } - private static class TrainedModelCacheMetadataUpdateTaskExecutor - implements - ClusterStateTaskExecutor { + private static class CacheMetadataUpdateTaskExecutor implements ClusterStateTaskExecutor { @Override - public ClusterState execute(BatchExecutionContext batchExecutionContext) { + public ClusterState execute(BatchExecutionContext batchExecutionContext) { final var initialState = batchExecutionContext.initialState(); XPackPlugin.checkReadyForXPackCustomMetadata(initialState); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheActionTests.java new file mode 100644 index 0000000000000..2cdf9aebef6d5 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportFlushTrainedModelCacheActionTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService; +import org.junit.Before; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReference; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportFlushTrainedModelCacheActionTests extends ESTestCase { + + private ThreadPool threadPool; + private TrainedModelCacheMetadataService modelCacheMetadataService; + + @Before + @SuppressWarnings({ "unchecked", "rawtypes" }) + private void setupMocks() { + ExecutorService executorService = mock(ExecutorService.class); + threadPool = mock(ThreadPool.class); + org.mockito.Mockito.doAnswer(invocation -> { + invocation.getArgument(0, Runnable.class).run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + when(threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + + modelCacheMetadataService = mock(TrainedModelCacheMetadataService.class); + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(0, ActionListener.class); + listener.onResponse(AcknowledgedResponse.TRUE); + return null; + }).when(modelCacheMetadataService).refreshCacheVersion(any(ActionListener.class)); + } + + public void testOperation() { + ClusterService clusterService = mock(ClusterService.class); + TransportFlushTrainedModelCacheAction action = createAction(clusterService); + + ClusterState clusterState = ClusterState.builder(new ClusterName("flush-trained-model-cache-metadata-tests")).build(); + + FlushTrainedModelCacheAction.Request request = new FlushTrainedModelCacheAction.Request(); + AtomicReference ack = new AtomicReference<>(); + ActionListener listener = ActionTestUtils.assertNoFailureListener(ack::set); + + action.masterOperation(null, request, clusterState, listener); + + assertTrue(ack.get().isAcknowledged()); + verify(modelCacheMetadataService).refreshCacheVersion(listener); + } + + private TransportFlushTrainedModelCacheAction createAction(ClusterService clusterService) { + return new TransportFlushTrainedModelCacheAction( + mock(TransportService.class), + clusterService, + threadPool, + mock(ActionFilters.class), + mock(IndexNameExpressionResolver.class), + modelCacheMetadataService + ); + } +}