Skip to content

Commit

Permalink
Add more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Apr 12, 2024
1 parent 706bf1e commit ad3c97b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
import java.util.Objects;

public class TrainedModelCacheMetadata extends AbstractNamedDiffable<Metadata.Custom> implements Metadata.Custom {

public static final String NAME = "trained_model_cache_metadata";

public static final TrainedModelCacheMetadata EMPTY = new TrainedModelCacheMetadata(0L);
private static final ParseField VERSION_FIELD = new ParseField("version");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ public void testRefreshCacheVersionOnMasterNode() {
final ActionListener<AcknowledgedResponse> listener = mock(ActionListener.class);
modelCacheMetadataService.refreshCacheVersion(listener);

// Verify a cluster state update task were submitted.
ArgumentCaptor<CacheMetadataUpdateTask> updateTaskCaptor = ArgumentCaptor.forClass(RefreshCacheMetadataVersionTask.class);
verify(taskQueue).submitTask(any(String.class), updateTaskCaptor.capture(), isNull());
assertThat(updateTaskCaptor.getValue().listener, is(listener));
Expand All @@ -104,13 +105,15 @@ public void testRefreshCacheVersionOnNonMasterNode() {
final ActionListener<AcknowledgedResponse> listener = mock(ActionListener.class);
modelCacheMetadataService.refreshCacheVersion(listener);

// Check a FlushTrainedModelCacheAction request is emitted to the master node, that will flush the cache.
verify(client).execute(
eq(FlushTrainedModelCacheAction.INSTANCE),
any(FlushTrainedModelCacheAction.Request.class),
any(ActionListener.class)
);
verify(listener).onResponse(eq(AcknowledgedResponse.TRUE));

// Verify no cluster state update task were submitted on a non-master node.
verify(taskQueue, never()).submitTask(any(String.class), any(RefreshCacheMetadataVersionTask.class), any(TimeValue.class));
}

Expand Down Expand Up @@ -150,6 +153,7 @@ public void testRefreshCacheMetadataVersionTaskExecutionWithMaxVersion() {
final TaskContext<CacheMetadataUpdateTask> taskContext = mock(TaskContext.class);
final TrainedModelCacheMetadata updatedCacheMetadata = task.execute(currentCacheMetadata, taskContext);

// Check the version counter is reset to 1
assertThat(updatedCacheMetadata.version(), equalTo( 1L));
}

Expand Down

0 comments on commit ad3c97b

Please sign in to comment.