Skip to content

Commit

Permalink
[7.x][ML] Retry persisting DF Analytics results (elastic#52048)
Browse files Browse the repository at this point in the history
Employs `ResultsPersisterService` in `DataFrameRowsJoiner` in order
to add retries when a data frame analytics job is persisting its results
to the destination index.

Backport of elastic#52048
  • Loading branch information
dimitris-athanasiou committed Feb 10, 2020
1 parent 610f681 commit bd0dfdf
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster

// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
dataFrameAnalyticsAuditor, trainedModelProvider);
dataFrameAnalyticsAuditor, trainedModelProvider, resultsPersisterService);
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
new MemoryUsageEstimationProcessManager(
threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.util.List;
Expand All @@ -62,19 +63,22 @@ public class AnalyticsProcessManager {
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
private final DataFrameAnalyticsAuditor auditor;
private final TrainedModelProvider trainedModelProvider;
private final ResultsPersisterService resultsPersisterService;

public AnalyticsProcessManager(Client client,
ThreadPool threadPool,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
TrainedModelProvider trainedModelProvider,
ResultsPersisterService resultsPersisterService) {
this(
client,
threadPool.generic(),
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
analyticsProcessFactory,
auditor,
trainedModelProvider);
trainedModelProvider,
resultsPersisterService);
}

// Visible for testing
Expand All @@ -83,13 +87,15 @@ public AnalyticsProcessManager(Client client,
ExecutorService executorServiceForProcess,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
TrainedModelProvider trainedModelProvider,
ResultsPersisterService resultsPersisterService) {
this.client = Objects.requireNonNull(client);
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
this.auditor = Objects.requireNonNull(auditor);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
}

public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory) {
Expand Down Expand Up @@ -419,7 +425,7 @@ private AnalyticsProcessConfig createProcessConfig(DataFrameDataExtractor dataEx
private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
DataFrameDataExtractorFactory dataExtractorFactory) {
DataFrameRowsJoiner dataFrameRowsJoiner =
new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true));
new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
return new AnalyticsResultProcessor(
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.util.Collections;
Expand All @@ -38,16 +35,17 @@ class DataFrameRowsJoiner implements AutoCloseable {
private static final int RESULTS_BATCH_SIZE = 1000;

private final String analyticsId;
private final Client client;
private final DataFrameDataExtractor dataExtractor;
private final ResultsPersisterService resultsPersisterService;
private final Iterator<DataFrameDataExtractor.Row> dataFrameRowsIterator;
private LinkedList<RowResults> currentResults;
private volatile String failure;

DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) {
DataFrameRowsJoiner(String analyticsId, DataFrameDataExtractor dataExtractor,
ResultsPersisterService resultsPersisterService) {
this.analyticsId = Objects.requireNonNull(analyticsId);
this.client = Objects.requireNonNull(client);
this.dataExtractor = Objects.requireNonNull(dataExtractor);
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
this.dataFrameRowsIterator = new ResultMatchingDataFrameRows();
this.currentResults = new LinkedList<>();
}
Expand Down Expand Up @@ -88,7 +86,8 @@ private void joinCurrentResults() {
bulkRequest.add(createIndexRequest(result, row.getHit()));
}
if (bulkRequest.numberOfActions() > 0) {
executeBulkRequest(bulkRequest);
resultsPersisterService.bulkIndexWithHeadersWithRetry(
dataExtractor.getHeaders(), bulkRequest, analyticsId, () -> true, errorMsg -> {});
}
currentResults = new LinkedList<>();
}
Expand All @@ -113,14 +112,6 @@ private IndexRequest createIndexRequest(RowResults result, SearchHit hit) {
return indexRequest;
}

private void executeBulkRequest(BulkRequest bulkRequest) {
BulkResponse bulkResponse = ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), ClientHelper.ML_ORIGIN, client,
() -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet());
if (bulkResponse.hasFailures()) {
throw ExceptionsHelper.serverError("failures while writing results [" + bulkResponse.buildFailureMessage() + "]");
}
}

@Override
public void close() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
Expand All @@ -27,13 +28,16 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ClientHelper;

import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -95,9 +99,28 @@ public BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
String jobId,
Supplier<Boolean> shouldRetry,
Consumer<String> msgHandler) {
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
providedBulkRequest -> client.bulk(providedBulkRequest).actionGet());
}

public BulkResponse bulkIndexWithHeadersWithRetry(Map<String, String> headers,
BulkRequest bulkRequest,
String jobId,
Supplier<Boolean> shouldRetry,
Consumer<String> msgHandler) {
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
providedBulkRequest -> ClientHelper.executeWithHeaders(
headers, ClientHelper.ML_ORIGIN, client, () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()));
}

private BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
String jobId,
Supplier<Boolean> shouldRetry,
Consumer<String> msgHandler,
Function<BulkRequest, BulkResponse> actionExecutor) {
RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler);
while (true) {
BulkResponse bulkResponse = client.bulk(bulkRequest).actionGet();
BulkResponse bulkResponse = actionExecutor.apply(bulkRequest);
if (bulkResponse.hasFailures() == false) {
return bulkResponse;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import org.junit.Before;
import org.mockito.InOrder;

Expand Down Expand Up @@ -65,6 +66,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
private DataFrameAnalyticsConfig dataFrameAnalyticsConfig;
private DataFrameDataExtractorFactory dataExtractorFactory;
private DataFrameDataExtractor dataExtractor;
private ResultsPersisterService resultsPersisterService;
private AnalyticsProcessManager processManager;

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -97,8 +99,10 @@ public void setUpMocks() {
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));

processManager = new AnalyticsProcessManager(
client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider);
resultsPersisterService = mock(ResultsPersisterService.class);

processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor,
trainedModelProvider, resultsPersisterService);
}

public void testRunJob_TaskIsStopping() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,17 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process;

import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import org.junit.Before;
import org.mockito.ArgumentCaptor;

Expand All @@ -35,7 +30,8 @@
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.same;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -46,19 +42,22 @@ public class DataFrameRowsJoinerTests extends ESTestCase {

private static final String ANALYTICS_ID = "my_analytics";

private Client client;
private static final Map<String, String> HEADERS = Collections.singletonMap("foo", "bar");

private DataFrameDataExtractor dataExtractor;
private ResultsPersisterService resultsPersisterService;
private ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);

@Before
public void setUpMocks() {
client = mock(Client.class);
dataExtractor = mock(DataFrameDataExtractor.class);
when(dataExtractor.getHeaders()).thenReturn(HEADERS);
resultsPersisterService = mock(ResultsPersisterService.class);
}

public void testProcess_GivenNoResults() {
givenProcessResults(Collections.emptyList());
verifyNoMoreInteractions(client);
verifyNoMoreInteractions(resultsPersisterService);
}

public void testProcess_GivenSingleRowAndResult() throws IOException {
Expand Down Expand Up @@ -126,7 +125,7 @@ public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IO
RowResults result = new RowResults(2, resultFields);
givenProcessResults(Arrays.asList(result));

verifyNoMoreInteractions(client);
verifyNoMoreInteractions(resultsPersisterService);
}

public void testProcess_GivenSingleBatchWithSkippedRows() throws IOException {
Expand Down Expand Up @@ -204,7 +203,7 @@ public void testProcess_GivenMoreResultsThanRows() throws IOException {
RowResults result2 = new RowResults(2, resultFields);
givenProcessResults(Arrays.asList(result1, result2));

verifyNoMoreInteractions(client);
verifyNoMoreInteractions(resultsPersisterService);
}

public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException {
Expand All @@ -218,13 +217,13 @@ public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws

givenProcessResults(Collections.emptyList());

verifyNoMoreInteractions(client);
verifyNoMoreInteractions(resultsPersisterService);
verify(dataExtractor).cancel();
verify(dataExtractor, times(2)).next();
}

private void givenProcessResults(List<RowResults> results) {
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor)) {
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, dataExtractor, resultsPersisterService)) {
results.forEach(joiner::processRowResults);
}
}
Expand All @@ -251,14 +250,9 @@ private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values,
}

private void givenClientHasNoFailures() {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(threadContext);
@SuppressWarnings("unchecked")
ActionFuture<BulkResponse> responseFuture = mock(ActionFuture.class);
when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture);
when(client.threadPool()).thenReturn(threadPool);
when(resultsPersisterService.bulkIndexWithHeadersWithRetry(
eq(HEADERS), bulkRequestCaptor.capture(), eq(ANALYTICS_ID), any(), any()))
.thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
}

private static class DelegateStubDataExtractor {
Expand Down

0 comments on commit bd0dfdf

Please sign in to comment.