From f4167c2e956532d6c1790658126901ec1945e9cd Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 7 Feb 2020 16:11:32 +0200 Subject: [PATCH] [ML] Retry persisting DF Analytics results Employs `ResultsPersisterService` from `DataFrameRowsJoiner` in order to add retries when a data frame analytics job is persisting the results to the destination data frame. --- .../xpack/ml/MachineLearning.java | 2 +- .../process/AnalyticsProcessManager.java | 14 +++++-- .../process/DataFrameRowsJoiner.java | 23 ++++------- .../persistence/ResultsPersisterService.java | 25 +++++++++++- .../process/AnalyticsProcessManagerTests.java | 8 +++- .../process/DataFrameRowsJoinerTests.java | 38 ++++++++----------- 6 files changed, 64 insertions(+), 46 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 5469da943f272..d50899ea78ebb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -635,7 +635,7 @@ public Collection createComponents(Client client, ClusterService cluster // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, - dataFrameAnalyticsAuditor, trainedModelProvider); + dataFrameAnalyticsAuditor, trainedModelProvider, resultsPersisterService); MemoryUsageEstimationProcessManager memoryEstimationProcessManager = new MemoryUsageEstimationProcessManager( threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index b9eaa6ff407f1..d66973bb7777f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import java.io.IOException; import java.util.List; @@ -62,19 +63,22 @@ public class AnalyticsProcessManager { private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); private final DataFrameAnalyticsAuditor auditor; private final TrainedModelProvider trainedModelProvider; + private final ResultsPersisterService resultsPersisterService; public AnalyticsProcessManager(Client client, ThreadPool threadPool, AnalyticsProcessFactory analyticsProcessFactory, DataFrameAnalyticsAuditor auditor, - TrainedModelProvider trainedModelProvider) { + TrainedModelProvider trainedModelProvider, + ResultsPersisterService resultsPersisterService) { this( client, threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), analyticsProcessFactory, auditor, - trainedModelProvider); + trainedModelProvider, + resultsPersisterService); } // Visible for testing @@ -83,13 +87,15 @@ public AnalyticsProcessManager(Client client, ExecutorService executorServiceForProcess, AnalyticsProcessFactory analyticsProcessFactory, DataFrameAnalyticsAuditor auditor, - TrainedModelProvider trainedModelProvider) { + TrainedModelProvider trainedModelProvider, + ResultsPersisterService resultsPersisterService) { this.client = Objects.requireNonNull(client); this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob); this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess); this.processFactory = Objects.requireNonNull(analyticsProcessFactory); this.auditor = Objects.requireNonNull(auditor); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); + this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); } public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory) { @@ -419,7 +425,7 @@ private AnalyticsProcessConfig createProcessConfig(DataFrameDataExtractor dataEx private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task, DataFrameDataExtractorFactory dataExtractorFactory) { DataFrameRowsJoiner dataFrameRowsJoiner = - new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); + new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService); return new AnalyticsResultProcessor( config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames()); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java index dbbb7f3cf2313..93c28c8a8d32b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java @@ -9,17 +9,14 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.bulk.BulkAction; import org.elasticsearch.action.bulk.BulkRequest; -import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import java.io.IOException; import java.util.Collections; @@ -38,16 +35,17 @@ class DataFrameRowsJoiner implements AutoCloseable { private static final int RESULTS_BATCH_SIZE = 1000; private final String analyticsId; - private final Client client; private final DataFrameDataExtractor dataExtractor; + private final ResultsPersisterService resultsPersisterService; private final Iterator dataFrameRowsIterator; private LinkedList currentResults; private volatile String failure; - DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) { + DataFrameRowsJoiner(String analyticsId, DataFrameDataExtractor dataExtractor, + ResultsPersisterService resultsPersisterService) { this.analyticsId = Objects.requireNonNull(analyticsId); - this.client = Objects.requireNonNull(client); this.dataExtractor = Objects.requireNonNull(dataExtractor); + this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); this.dataFrameRowsIterator = new ResultMatchingDataFrameRows(); this.currentResults = new LinkedList<>(); } @@ -88,7 +86,8 @@ private void joinCurrentResults() { bulkRequest.add(createIndexRequest(result, row.getHit())); } if (bulkRequest.numberOfActions() > 0) { - executeBulkRequest(bulkRequest); + resultsPersisterService.bulkIndexWithHeadersWithRetry( + dataExtractor.getHeaders(), bulkRequest, analyticsId, () -> true, errorMsg -> {}); } currentResults = new LinkedList<>(); } @@ -113,14 +112,6 @@ private IndexRequest createIndexRequest(RowResults result, SearchHit hit) { return indexRequest; } - private void executeBulkRequest(BulkRequest bulkRequest) { - BulkResponse bulkResponse = ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), ClientHelper.ML_ORIGIN, client, - () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()); - if (bulkResponse.hasFailures()) { - throw ExceptionsHelper.serverError("failures while writing results [" + bulkResponse.buildFailureMessage() + "]"); - } - } - @Override public void close() { try { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java index a775342b880b4..5ae76f9491861 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.bulk.BulkAction; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; @@ -27,13 +28,16 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ClientHelper; import java.io.IOException; import java.time.Duration; import java.util.Arrays; +import java.util.Map; import java.util.Random; import java.util.Set; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -95,9 +99,28 @@ public BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest, String jobId, Supplier shouldRetry, Consumer msgHandler) { + return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler, + providedBulkRequest -> client.bulk(providedBulkRequest).actionGet()); + } + + public BulkResponse bulkIndexWithHeadersWithRetry(Map headers, + BulkRequest bulkRequest, + String jobId, + Supplier shouldRetry, + Consumer msgHandler) { + return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler, + providedBulkRequest -> ClientHelper.executeWithHeaders( + headers, ClientHelper.ML_ORIGIN, client, () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet())); + } + + private BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest, + String jobId, + Supplier shouldRetry, + Consumer msgHandler, + Function actionExecutor) { RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler); while (true) { - BulkResponse bulkResponse = client.bulk(bulkRequest).actionGet(); + BulkResponse bulkResponse = actionExecutor.apply(bulkRequest); if (bulkResponse.hasFailures() == false) { return bulkResponse; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 915d6c29efb4d..ae1f7f4bcbade 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import org.junit.Before; import org.mockito.InOrder; @@ -64,6 +65,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase { private DataFrameAnalyticsConfig dataFrameAnalyticsConfig; private DataFrameDataExtractorFactory dataExtractorFactory; private DataFrameDataExtractor dataExtractor; + private ResultsPersisterService resultsPersisterService; private AnalyticsProcessManager processManager; @SuppressWarnings("unchecked") @@ -94,8 +96,10 @@ public void setUpMocks() { when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor); when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class)); - processManager = new AnalyticsProcessManager( - client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider); + resultsPersisterService = mock(ResultsPersisterService.class); + + processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, + trainedModelProvider, resultsPersisterService); } public void testRunJob_TaskIsStopping() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java index 6a4230acf643a..7b06a447e7893 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java @@ -5,21 +5,16 @@ */ package org.elasticsearch.xpack.ml.dataframe.process; -import org.elasticsearch.action.ActionFuture; -import org.elasticsearch.action.bulk.BulkAction; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.client.Client; import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -34,7 +29,8 @@ import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Matchers.same; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -45,19 +41,22 @@ public class DataFrameRowsJoinerTests extends ESTestCase { private static final String ANALYTICS_ID = "my_analytics"; - private Client client; + private static final Map HEADERS = Collections.singletonMap("foo", "bar"); + private DataFrameDataExtractor dataExtractor; + private ResultsPersisterService resultsPersisterService; private ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); @Before public void setUpMocks() { - client = mock(Client.class); dataExtractor = mock(DataFrameDataExtractor.class); + when(dataExtractor.getHeaders()).thenReturn(HEADERS); + resultsPersisterService = mock(ResultsPersisterService.class); } public void testProcess_GivenNoResults() { givenProcessResults(Collections.emptyList()); - verifyNoMoreInteractions(client); + verifyNoMoreInteractions(resultsPersisterService); } public void testProcess_GivenSingleRowAndResult() throws IOException { @@ -125,7 +124,7 @@ public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IO RowResults result = new RowResults(2, resultFields); givenProcessResults(Arrays.asList(result)); - verifyNoMoreInteractions(client); + verifyNoMoreInteractions(resultsPersisterService); } public void testProcess_GivenSingleBatchWithSkippedRows() throws IOException { @@ -203,7 +202,7 @@ public void testProcess_GivenMoreResultsThanRows() throws IOException { RowResults result2 = new RowResults(2, resultFields); givenProcessResults(Arrays.asList(result1, result2)); - verifyNoMoreInteractions(client); + verifyNoMoreInteractions(resultsPersisterService); } public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException { @@ -217,13 +216,13 @@ public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws givenProcessResults(Collections.emptyList()); - verifyNoMoreInteractions(client); + verifyNoMoreInteractions(resultsPersisterService); verify(dataExtractor).cancel(); verify(dataExtractor, times(2)).next(); } private void givenProcessResults(List results) { - try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor)) { + try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, dataExtractor, resultsPersisterService)) { results.forEach(joiner::processRowResults); } } @@ -250,14 +249,9 @@ private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values, } private void givenClientHasNoFailures() { - ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - ThreadPool threadPool = mock(ThreadPool.class); - when(threadPool.getThreadContext()).thenReturn(threadContext); - @SuppressWarnings("unchecked") - ActionFuture responseFuture = mock(ActionFuture.class); - when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0)); - when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture); - when(client.threadPool()).thenReturn(threadPool); + when(resultsPersisterService.bulkIndexWithHeadersWithRetry( + eq(HEADERS), bulkRequestCaptor.capture(), eq(ANALYTICS_ID), any(), any())) + .thenReturn(new BulkResponse(new BulkItemResponse[0], 0)); } private static class DelegateStubDataExtractor {