Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Retry persisting DF Analytics results #52048

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster

// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
dataFrameAnalyticsAuditor, trainedModelProvider);
dataFrameAnalyticsAuditor, trainedModelProvider, resultsPersisterService);
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
new MemoryUsageEstimationProcessManager(
threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.util.List;
Expand All @@ -62,19 +63,22 @@ public class AnalyticsProcessManager {
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
private final DataFrameAnalyticsAuditor auditor;
private final TrainedModelProvider trainedModelProvider;
private final ResultsPersisterService resultsPersisterService;

public AnalyticsProcessManager(Client client,
ThreadPool threadPool,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
TrainedModelProvider trainedModelProvider,
ResultsPersisterService resultsPersisterService) {
this(
client,
threadPool.generic(),
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
analyticsProcessFactory,
auditor,
trainedModelProvider);
trainedModelProvider,
resultsPersisterService);
}

// Visible for testing
Expand All @@ -83,13 +87,15 @@ public AnalyticsProcessManager(Client client,
ExecutorService executorServiceForProcess,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
TrainedModelProvider trainedModelProvider,
ResultsPersisterService resultsPersisterService) {
this.client = Objects.requireNonNull(client);
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
this.auditor = Objects.requireNonNull(auditor);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
}

public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory) {
Expand Down Expand Up @@ -419,7 +425,7 @@ private AnalyticsProcessConfig createProcessConfig(DataFrameDataExtractor dataEx
private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
DataFrameDataExtractorFactory dataExtractorFactory) {
DataFrameRowsJoiner dataFrameRowsJoiner =
new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true));
new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
return new AnalyticsResultProcessor(
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.util.Collections;
Expand All @@ -38,16 +35,17 @@ class DataFrameRowsJoiner implements AutoCloseable {
private static final int RESULTS_BATCH_SIZE = 1000;

private final String analyticsId;
private final Client client;
private final DataFrameDataExtractor dataExtractor;
private final ResultsPersisterService resultsPersisterService;
private final Iterator<DataFrameDataExtractor.Row> dataFrameRowsIterator;
private LinkedList<RowResults> currentResults;
private volatile String failure;

DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) {
DataFrameRowsJoiner(String analyticsId, DataFrameDataExtractor dataExtractor,
ResultsPersisterService resultsPersisterService) {
this.analyticsId = Objects.requireNonNull(analyticsId);
this.client = Objects.requireNonNull(client);
this.dataExtractor = Objects.requireNonNull(dataExtractor);
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
this.dataFrameRowsIterator = new ResultMatchingDataFrameRows();
this.currentResults = new LinkedList<>();
}
Expand Down Expand Up @@ -88,7 +86,8 @@ private void joinCurrentResults() {
bulkRequest.add(createIndexRequest(result, row.getHit()));
}
if (bulkRequest.numberOfActions() > 0) {
executeBulkRequest(bulkRequest);
resultsPersisterService.bulkIndexWithHeadersWithRetry(
dataExtractor.getHeaders(), bulkRequest, analyticsId, () -> true, errorMsg -> {});
}
currentResults = new LinkedList<>();
}
Expand All @@ -113,14 +112,6 @@ private IndexRequest createIndexRequest(RowResults result, SearchHit hit) {
return indexRequest;
}

/**
 * Sends the given bulk request through the client, using the headers supplied by the
 * data extractor and the ML origin, and waits for the response via {@code actionGet()}.
 *
 * If the response reports any failures, the exception built by
 * {@code ExceptionsHelper.serverError} is thrown with the response's combined failure message.
 *
 * @param bulkRequest the bulk request to execute
 */
private void executeBulkRequest(BulkRequest bulkRequest) {
BulkResponse bulkResponse = ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), ClientHelper.ML_ORIGIN, client,
() -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet());
// Any item failure fails the whole call; this method makes a single attempt and does not retry.
if (bulkResponse.hasFailures()) {
throw ExceptionsHelper.serverError("failures while writing results [" + bulkResponse.buildFailureMessage() + "]");
}
}

@Override
public void close() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
Expand All @@ -27,13 +28,16 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ClientHelper;

import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -95,9 +99,28 @@ public BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
String jobId,
Supplier<Boolean> shouldRetry,
Consumer<String> msgHandler) {
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
providedBulkRequest -> client.bulk(providedBulkRequest).actionGet());
}

/**
 * Bulk indexes the given request with the supplied headers under the ML origin, delegating to the
 * shared retry loop in {@code bulkIndexWithRetry}.
 *
 * @param headers     headers to execute the bulk action with
 * @param bulkRequest the initial bulk request to execute
 * @param jobId       identifier handed to the retry context
 * @param shouldRetry supplier handed to the retry context; presumably consulted before each retry
 * @param msgHandler  consumer handed to the retry context; presumably receives retry messages
 * @return the bulk response returned by the retry loop
 */
public BulkResponse bulkIndexWithHeadersWithRetry(Map<String, String> headers,
                                                  BulkRequest bulkRequest,
                                                  String jobId,
                                                  Supplier<Boolean> shouldRetry,
                                                  Consumer<String> msgHandler) {
    return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
        // Send the request handed in by the retry loop, not the captured outer bulkRequest:
        // the executor must act on whatever request the loop passes for the current attempt.
        providedBulkRequest -> ClientHelper.executeWithHeaders(
            headers, ClientHelper.ML_ORIGIN, client,
            () -> client.execute(BulkAction.INSTANCE, providedBulkRequest).actionGet()));
}

private BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
String jobId,
Supplier<Boolean> shouldRetry,
Consumer<String> msgHandler,
Function<BulkRequest, BulkResponse> actionExecutor) {
RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler);
while (true) {
BulkResponse bulkResponse = client.bulk(bulkRequest).actionGet();
BulkResponse bulkResponse = actionExecutor.apply(bulkRequest);
if (bulkResponse.hasFailures() == false) {
return bulkResponse;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import org.junit.Before;
import org.mockito.InOrder;

Expand Down Expand Up @@ -64,6 +65,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
private DataFrameAnalyticsConfig dataFrameAnalyticsConfig;
private DataFrameDataExtractorFactory dataExtractorFactory;
private DataFrameDataExtractor dataExtractor;
private ResultsPersisterService resultsPersisterService;
private AnalyticsProcessManager processManager;

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -94,8 +96,10 @@ public void setUpMocks() {
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));

processManager = new AnalyticsProcessManager(
client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider);
resultsPersisterService = mock(ResultsPersisterService.class);

processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor,
trainedModelProvider, resultsPersisterService);
}

public void testRunJob_TaskIsStopping() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process;

import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import org.junit.Before;
import org.mockito.ArgumentCaptor;

Expand All @@ -34,7 +29,8 @@
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.same;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -45,19 +41,22 @@ public class DataFrameRowsJoinerTests extends ESTestCase {

private static final String ANALYTICS_ID = "my_analytics";

private Client client;
private static final Map<String, String> HEADERS = Collections.singletonMap("foo", "bar");

private DataFrameDataExtractor dataExtractor;
private ResultsPersisterService resultsPersisterService;
private ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);

@Before
public void setUpMocks() {
client = mock(Client.class);
dataExtractor = mock(DataFrameDataExtractor.class);
when(dataExtractor.getHeaders()).thenReturn(HEADERS);
resultsPersisterService = mock(ResultsPersisterService.class);
}

public void testProcess_GivenNoResults() {
givenProcessResults(Collections.emptyList());
verifyNoMoreInteractions(client);
verifyNoMoreInteractions(resultsPersisterService);
}

public void testProcess_GivenSingleRowAndResult() throws IOException {
Expand Down Expand Up @@ -125,7 +124,7 @@ public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IO
RowResults result = new RowResults(2, resultFields);
givenProcessResults(Arrays.asList(result));

verifyNoMoreInteractions(client);
verifyNoMoreInteractions(resultsPersisterService);
}

public void testProcess_GivenSingleBatchWithSkippedRows() throws IOException {
Expand Down Expand Up @@ -203,7 +202,7 @@ public void testProcess_GivenMoreResultsThanRows() throws IOException {
RowResults result2 = new RowResults(2, resultFields);
givenProcessResults(Arrays.asList(result1, result2));

verifyNoMoreInteractions(client);
verifyNoMoreInteractions(resultsPersisterService);
}

public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException {
Expand All @@ -217,13 +216,13 @@ public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws

givenProcessResults(Collections.emptyList());

verifyNoMoreInteractions(client);
verifyNoMoreInteractions(resultsPersisterService);
verify(dataExtractor).cancel();
verify(dataExtractor, times(2)).next();
}

private void givenProcessResults(List<RowResults> results) {
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor)) {
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, dataExtractor, resultsPersisterService)) {
results.forEach(joiner::processRowResults);
}
}
Expand All @@ -250,14 +249,9 @@ private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values,
}

private void givenClientHasNoFailures() {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(threadContext);
@SuppressWarnings("unchecked")
ActionFuture<BulkResponse> responseFuture = mock(ActionFuture.class);
when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture);
when(client.threadPool()).thenReturn(threadPool);
when(resultsPersisterService.bulkIndexWithHeadersWithRetry(
eq(HEADERS), bulkRequestCaptor.capture(), eq(ANALYTICS_ID), any(), any()))
.thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
}

private static class DelegateStubDataExtractor {
Expand Down