Skip to content

Commit

Permalink
Add tests for the internal transport action used to flush the cache.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Apr 10, 2024
1 parent 62e9839 commit f1db7fd
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService;

public class TransportFlushTrainedModelCacheAction extends AcknowledgedTransportMasterNodeAction<FlushTrainedModelCacheAction.Request> {

private final TrainedModelCacheMetadataService modelCacheMetadataService;

@Inject
public TransportFlushTrainedModelCacheAction(
TransportService transportService,
ClusterService clusterService,
ThreadPool threadPool,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver
IndexNameExpressionResolver indexNameExpressionResolver,
TrainedModelCacheMetadataService modelCacheMetadataService
) {
super(
FlushTrainedModelCacheAction.NAME,
Expand All @@ -43,6 +47,7 @@ public TransportFlushTrainedModelCacheAction(
indexNameExpressionResolver,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
this.modelCacheMetadataService = modelCacheMetadataService;
}

@Override
Expand All @@ -52,7 +57,7 @@ protected void masterOperation(
ClusterState state,
ActionListener<AcknowledgedResponse> listener
) {
// TODO
modelCacheMetadataService.refreshCacheVersion(listener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@
public class TrainedModelCacheMetadataService implements ClusterStateListener {
private static final Logger LOGGER = LogManager.getLogger(TrainedModelCacheMetadataService.class);
private static final String TASK_QUEUE_NAME = "trained-models-cache-metadata-management";
private final MasterServiceTaskQueue<TrainedModelCacheMetadataUpdateTask> metadataUpdateTaskQueue;
private final MasterServiceTaskQueue<CacheMetadataUpdateTask> metadataUpdateTaskQueue;
private final Client client;
private volatile boolean isMasterNode = false;

public TrainedModelCacheMetadataService(ClusterService clusterService, Client client) {
this.client = new OriginSettingClient(client, ML_ORIGIN);
TrainedModelCacheMetadataUpdateTaskExecutor metadataUpdateTaskExecutor = new TrainedModelCacheMetadataUpdateTaskExecutor();
CacheMetadataUpdateTaskExecutor metadataUpdateTaskExecutor = new CacheMetadataUpdateTaskExecutor();
this.metadataUpdateTaskQueue = clusterService.createTaskQueue(TASK_QUEUE_NAME, Priority.IMMEDIATE, metadataUpdateTaskExecutor);
clusterService.addListener(this);
}
Expand All @@ -50,7 +50,7 @@ public void refreshCacheVersion(ActionListener<AcknowledgedResponse> listener) {
return;
}

TrainedModelCacheMetadataUpdateTask updateMetadataTask = new RefreshTrainedModeCacheMetadataVersionTask(listener);
CacheMetadataUpdateTask updateMetadataTask = new RefreshCacheMetadataVersionTask(listener);
this.metadataUpdateTaskQueue.submitTask(updateMetadataTask.getDescription(), updateMetadataTask, null);
}

Expand All @@ -62,16 +62,16 @@ public void clusterChanged(ClusterChangedEvent event) {
this.isMasterNode = event.localNodeMaster();
}

private abstract static class TrainedModelCacheMetadataUpdateTask implements ClusterStateTaskListener {
private abstract static class CacheMetadataUpdateTask implements ClusterStateTaskListener {
protected final ActionListener<AcknowledgedResponse> listener;

TrainedModelCacheMetadataUpdateTask(ActionListener<AcknowledgedResponse> listener) {
CacheMetadataUpdateTask(ActionListener<AcknowledgedResponse> listener) {
this.listener = listener;
}

protected abstract TrainedModelCacheMetadata execute(
TrainedModelCacheMetadata currentCacheMetadata,
TaskContext<TrainedModelCacheMetadataUpdateTask> taskContext
TaskContext<CacheMetadataUpdateTask> taskContext
);

protected abstract String getDescription();
Expand All @@ -83,15 +83,15 @@ public void onFailure(@Nullable Exception e) {
}
}

private static class RefreshTrainedModeCacheMetadataVersionTask extends TrainedModelCacheMetadataUpdateTask {
RefreshTrainedModeCacheMetadataVersionTask(ActionListener<AcknowledgedResponse> listener) {
private static class RefreshCacheMetadataVersionTask extends CacheMetadataUpdateTask {
RefreshCacheMetadataVersionTask(ActionListener<AcknowledgedResponse> listener) {
super(listener);
}

@Override
protected TrainedModelCacheMetadata execute(
TrainedModelCacheMetadata currentCacheMetadata,
TaskContext<TrainedModelCacheMetadataUpdateTask> taskContext
TaskContext<CacheMetadataUpdateTask> taskContext
) {
long newVersion = currentCacheMetadata.version() < Long.MAX_VALUE ? currentCacheMetadata.version() + 1 : 1L;
taskContext.success(() -> listener.onResponse(AcknowledgedResponse.TRUE));
Expand All @@ -104,11 +104,9 @@ protected String getDescription() {
}
}

private static class TrainedModelCacheMetadataUpdateTaskExecutor
implements
ClusterStateTaskExecutor<TrainedModelCacheMetadataUpdateTask> {
private static class CacheMetadataUpdateTaskExecutor implements ClusterStateTaskExecutor<CacheMetadataUpdateTask> {
@Override
public ClusterState execute(BatchExecutionContext<TrainedModelCacheMetadataUpdateTask> batchExecutionContext) {
public ClusterState execute(BatchExecutionContext<CacheMetadataUpdateTask> batchExecutionContext) {
final var initialState = batchExecutionContext.initialState();
XPackPlugin.checkReadyForXPackCustomMetadata(initialState);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelCacheMetadataService;
import org.junit.Before;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class TransportFlushTrainedModelCacheActionTests extends ESTestCase {

private ThreadPool threadPool;
private TrainedModelCacheMetadataService modelCacheMetadataService;

@Before
@SuppressWarnings({ "unchecked", "rawtypes" })
private void setupMocks() {
ExecutorService executorService = mock(ExecutorService.class);
threadPool = mock(ThreadPool.class);
org.mockito.Mockito.doAnswer(invocation -> {
invocation.getArgument(0, Runnable.class).run();
return null;
}).when(executorService).execute(any(Runnable.class));
when(threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)).thenReturn(executorService);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));

modelCacheMetadataService = mock(TrainedModelCacheMetadataService.class);
doAnswer(invocationOnMock -> {
ActionListener listener = invocationOnMock.getArgument(0, ActionListener.class);
listener.onResponse(AcknowledgedResponse.TRUE);
return null;
}).when(modelCacheMetadataService).refreshCacheVersion(any(ActionListener.class));
}

public void testOperation() {
ClusterService clusterService = mock(ClusterService.class);
TransportFlushTrainedModelCacheAction action = createAction(clusterService);

ClusterState clusterState = ClusterState.builder(new ClusterName("flush-trained-model-cache-metadata-tests")).build();

FlushTrainedModelCacheAction.Request request = new FlushTrainedModelCacheAction.Request();
AtomicReference<AcknowledgedResponse> ack = new AtomicReference<>();
ActionListener<AcknowledgedResponse> listener = ActionTestUtils.assertNoFailureListener(ack::set);

action.masterOperation(null, request, clusterState, listener);

assertTrue(ack.get().isAcknowledged());
verify(modelCacheMetadataService).refreshCacheVersion(listener);
}

private TransportFlushTrainedModelCacheAction createAction(ClusterService clusterService) {
return new TransportFlushTrainedModelCacheAction(
mock(TransportService.class),
clusterService,
threadPool,
mock(ActionFilters.class),
mock(IndexNameExpressionResolver.class),
modelCacheMetadataService
);
}
}

0 comments on commit f1db7fd

Please sign in to comment.