Skip to content

Commit

Permalink
[ML] Fix tests randomly failing on CI (#51142) (#51150)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Jan 17, 2020
1 parent b70ebde commit da73c91
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {

public static Response randomResponse() {
int listSize = randomInt(10);
public static Response randomResponse(int listSize) {
List<Response.Stats> analytics = new ArrayList<>(listSize);
for (int j = 0; j < listSize; j++) {
String failureReason = randomBoolean() ? null : randomAlphaOfLength(10);
Expand All @@ -37,7 +36,7 @@ public static Response randomResponse() {

@Override
protected Response createTestInstance() {
return randomResponse();
return randomResponse(randomInt(10));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ public void markAsCompleted() {
isMarkAsCompletedCalled = true;
}

persistProgress(() -> super.markAsCompleted());
persistProgress(client, taskParams.getId(), () -> super.markAsCompleted());
}

@Override
public void markAsFailed(Exception e) {
persistProgress(() -> super.markAsFailed(e));
persistProgress(client, taskParams.getId(), () -> super.markAsFailed(e));
}

public void stop(String reason, TimeValue timeout) {
Expand Down Expand Up @@ -244,21 +244,22 @@ private TaskId getReindexTaskId() {
}
}

private void persistProgress(Runnable runnable) {
LOGGER.debug("[{}] Persisting progress", taskParams.getId());
// Visible for testing
static void persistProgress(Client client, String jobId, Runnable runnable) {
LOGGER.debug("[{}] Persisting progress", jobId);

String progressDocId = StoredProgress.documentId(taskParams.getId());
String progressDocId = StoredProgress.documentId(jobId);
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();

// Step 4: Run the runnable provided as the argument
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
indexResponse -> {
LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
LOGGER.debug("[{}] Successfully indexed progress document", jobId);
runnable.run();
},
indexError -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
"[{}] cannot persist progress as an error occurred while indexing", jobId), indexError);
runnable.run();
}
);
Expand All @@ -283,7 +284,7 @@ private void persistProgress(Runnable runnable) {
},
e -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while retrieving former progress document", taskParams.getId()), e);
"[{}] cannot persist progress as an error occurred while retrieving former progress document", jobId), e);
runnable.run();
}
);
Expand All @@ -302,13 +303,13 @@ private void persistProgress(Runnable runnable) {
},
e -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
"[{}] cannot persist progress as an error occurred while retrieving stats", jobId), e);
runnable.run();
}
);

// Step 1: Fetch progress to be persisted
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(jobId);
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
*/
package org.elasticsearch.xpack.ml.dataframe;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.search.SearchHit;
Expand All @@ -22,10 +20,8 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.TaskParams;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.stubbing.Answer;
Expand Down Expand Up @@ -115,13 +111,13 @@ public void testDetermineStartingState_GivenEmptyProgress() {
assertThat(startingState, equalTo(StartingState.FINISHED));
}

private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAlias) {
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) {
Client client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(client.threadPool()).thenReturn(threadPool);

GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse();
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse(1);
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());

SearchResponse searchResponse = mock(SearchResponse.class);
Expand All @@ -131,40 +127,30 @@ private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAl
IndexResponse indexResponse = mock(IndexResponse.class);
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());

TaskParams taskParams = new TaskParams("task_id", Version.CURRENT, Collections.emptyList(), false);
DataFrameAnalyticsTask task =
new DataFrameAnalyticsTask(
0,
"",
"",
null,
null,
client,
mock(ClusterService.class),
mock(DataFrameAnalyticsManager.class),
mock(DataFrameAnalyticsAuditor.class),
taskParams);
task.markAsCompleted();
Runnable runnable = mock(Runnable.class);

DataFrameAnalyticsTask.persistProgress(client, "task_id", runnable);

ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);

InOrder inOrder = inOrder(client);
InOrder inOrder = inOrder(client, runnable);
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
inOrder.verify(runnable).run();
inOrder.verifyNoMoreInteractions();

IndexRequest indexRequest = indexRequestCaptor.getValue();
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));
}

public void testMarkAsCompleted_ProgressDocumentCreated() {
testMarkAsCompleted(SearchHits.empty(), ".ml-state-write");
public void testPersistProgress_ProgressDocumentCreated() {
testPersistProgress(SearchHits.empty(), ".ml-state-write");
}

public void testMarkAsCompleted_ProgressDocumentUpdated() {
testMarkAsCompleted(
public void testPersistProgress_ProgressDocumentUpdated() {
testPersistProgress(
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f),
".ml-state-dummy");
}
Expand Down

0 comments on commit da73c91

Please sign in to comment.