Commit f7809dd

Preserve parent task id for data frame analytics (#55046)
This change makes sure that all internal client requests spawned by the
data frame analytics persistent task executor, and that use the end-user
security credentials, have the parent task id assigned. The objective is
to permit auditing (as well as tracking, for debugging purposes) of all
the requests that persistent tasks execute on the end user's behalf.
Because the data frame analytics task already implements graceful shutdown
of its child tasks, this change opts out of the persistent task framework's
cancellation of child tasks so as not to interfere with that shutdown.

Relates #54943 #52314
albertzaharovits authored Apr 10, 2020
1 parent 16e9433 commit f7809dd
Showing 9 changed files with 76 additions and 41 deletions.
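
The recurring pattern in this diff is to wrap the node `Client` in a `ParentTaskAssigningClient`, so that every internal request sent through the wrapper carries the spawning task's id. A minimal sketch of that usage, assuming the `org.elasticsearch.client.ParentTaskAssigningClient` API as it appears in the hunks below; the class, method, and index names in the sketch are illustrative:

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.tasks.Task;

class ParentTaskClientSketch {

    // Wrap the node client once; the wrapper sets the parent task id on every
    // request before sending it, so child requests show up under the spawning
    // task in the task management API and can be correlated in audit logs.
    void searchAsChildOfParentTask(Client client, Task task) {
        Client parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());

        parentTaskClient.search(new SearchRequest("source-index"), // illustrative index name
            ActionListener.wrap(
                (SearchResponse response) -> { /* inspect hits */ },
                e -> { /* handle failure */ }));
    }
}
```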
@@ -78,7 +78,8 @@ public static String configIndexName() {
* Creates the .ml-state-000001 index (if necessary)
* Creates the .ml-state-write alias for the .ml-state-000001 index (if necessary)
*/
-public static void createStateIndexAndAliasIfNecessary(Client client, ClusterState state, IndexNameExpressionResolver resolver,
+public static void createStateIndexAndAliasIfNecessary(Client client, ClusterState state,
+                                                       IndexNameExpressionResolver resolver,
final ActionListener<Boolean> finalListener) {
MlIndexAndAlias.createIndexAndAliasIfNecessary(
client,
@@ -9,6 +9,7 @@
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -83,7 +84,8 @@ protected void doExecute(Task task,

private void explain(Task task, PutDataFrameAnalyticsAction.Request request,
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
-ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(client);
+ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory =
+    new ExtractedFieldsDetectorFactory(new ParentTaskAssigningClient(client, task.getParentTaskId()));
extractedFieldsDetectorFactory.createFromSource(
request.getConfig(),
ActionListener.wrap(
@@ -115,7 +117,7 @@ private void estimateMemoryUsage(Task task,
ActionListener<MemoryEstimation> listener) {
final String estimateMemoryTaskId = "memory_usage_estimation_" + task.getId();
DataFrameDataExtractorFactory extractorFactory = DataFrameDataExtractorFactory.createForSourceIndices(
-    client, estimateMemoryTaskId, request.getConfig(), extractedFields);
+    new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, request.getConfig(), extractedFields);
processManager.runJobAsync(
estimateMemoryTaskId,
request.getConfig(),
@@ -19,6 +19,7 @@
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
@@ -200,7 +201,7 @@ public void onFailure(Exception e) {
);

// Get start context
-getStartContext(request.getId(), startContextListener);
+getStartContext(request.getId(), task, startContextListener);
}

private void estimateMemoryUsageAndUpdateMemoryTracker(StartContext startContext, ActionListener<StartContext> listener) {
@@ -240,8 +241,9 @@ private void estimateMemoryUsageAndUpdateMemoryTracker(StartContext startContext

}

-private void getStartContext(String id, ActionListener<StartContext> finalListener) {
+private void getStartContext(String id, Task task, ActionListener<StartContext> finalListener) {

+    ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
// Step 7. Validate that there are analyzable data in the source index
ActionListener<StartContext> validateMappingsMergeListener = ActionListener.wrap(
startContext -> validateSourceIndexHasRows(startContext, finalListener),
@@ -250,7 +252,7 @@ private void getStartContext(String id, ActionListener<StartContext> finalListen

// Step 6. Validate mappings can be merged
ActionListener<StartContext> toValidateMappingsListener = ActionListener.wrap(
-startContext -> MappingsMerger.mergeMappings(client, startContext.config.getHeaders(),
+startContext -> MappingsMerger.mergeMappings(parentTaskClient, startContext.config.getHeaders(),
startContext.config.getSource(), ActionListener.wrap(
mappings -> validateMappingsMergeListener.onResponse(startContext), finalListener::onFailure)),
finalListener::onFailure
@@ -261,7 +263,7 @@ private void getStartContext(String id, ActionListener<StartContext> finalListen
startContext -> {
switch (startContext.startingState) {
case FIRST_TIME:
-checkDestIndexIsEmptyIfExists(startContext, toValidateMappingsListener);
+checkDestIndexIsEmptyIfExists(parentTaskClient, startContext, toValidateMappingsListener);
break;
case RESUMING_REINDEXING:
case RESUMING_ANALYZING:
@@ -283,7 +285,7 @@ private void getStartContext(String id, ActionListener<StartContext> finalListen
// Step 4. Check data extraction is possible
ActionListener<StartContext> toValidateExtractionPossibleListener = ActionListener.wrap(
startContext -> {
-new ExtractedFieldsDetectorFactory(client).createFromSource(startContext.config, ActionListener.wrap(
+new ExtractedFieldsDetectorFactory(parentTaskClient).createFromSource(startContext.config, ActionListener.wrap(
extractedFieldsDetector -> {
startContext.extractedFields = extractedFieldsDetector.detect().v1();
toValidateDestEmptyListener.onResponse(startContext);
@@ -361,13 +363,14 @@ private void getProgress(DataFrameAnalyticsConfig config, ActionListener<List<Ph
));
}

-private void checkDestIndexIsEmptyIfExists(StartContext startContext, ActionListener<StartContext> listener) {
+private void checkDestIndexIsEmptyIfExists(ParentTaskAssigningClient parentTaskClient, StartContext startContext,
+                                           ActionListener<StartContext> listener) {
String destIndex = startContext.config.getDest().getIndex();
SearchRequest destEmptySearch = new SearchRequest(destIndex);
destEmptySearch.source().size(0);
destEmptySearch.allowPartialSearchResults(false);
-ClientHelper.executeWithHeadersAsync(startContext.config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE,
-    destEmptySearch, ActionListener.wrap(
+ClientHelper.executeWithHeadersAsync(startContext.config.getHeaders(), ClientHelper.ML_ORIGIN, parentTaskClient,
+    SearchAction.INSTANCE, destEmptySearch, ActionListener.wrap(
searchResponse -> {
if (searchResponse.getHits().getTotalHits().value > 0) {
listener.onFailure(ExceptionsHelper.badRequestException("dest index [{}] must be empty", destIndex));
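
The hunk above also shows how the wrapped client composes with ML's `ClientHelper.executeWithHeadersAsync`, so a child request runs with the end user's stored security headers and is registered under the parent task. A rough sketch mirroring that call shape; the `ClientHelper` signature is assumed from its usage in this diff, and the other names are illustrative:

```java
import java.util.Map;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ClientHelper;

class HeadersAndParentTaskSketch {

    // Run a search with the end user's security headers AND the parent task id,
    // matching the shape of the checkDestIndexIsEmptyIfExists change above.
    void auditedChildSearch(Client client, TaskId parentTaskId, Map<String, String> userHeaders) {
        Client parentTaskClient = new ParentTaskAssigningClient(client, parentTaskId);
        SearchRequest search = new SearchRequest("dest-index"); // illustrative index name

        ClientHelper.executeWithHeadersAsync(userHeaders, ClientHelper.ML_ORIGIN, parentTaskClient,
            SearchAction.INSTANCE, search,
            ActionListener.wrap(
                response -> { /* e.g. check total hits */ },
                e -> { /* handle failure */ }));
    }
}
```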
@@ -20,6 +20,8 @@
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
@@ -113,15 +115,16 @@ public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState current

// Make sure the stats index and alias exist
ActionListener<Boolean> stateAliasListener = ActionListener.wrap(
-    aBoolean -> createStatsIndexAndUpdateMappingsIfNecessary(clusterState, statsIndexListener),
-    configListener::onFailure
+    aBoolean -> createStatsIndexAndUpdateMappingsIfNecessary(new ParentTaskAssigningClient(client, task.getParentTaskId()),
+        clusterState, statsIndexListener), configListener::onFailure
);

// Make sure the state index and alias exist
-AnomalyDetectorsIndex.createStateIndexAndAliasIfNecessary(client, clusterState, expressionResolver, stateAliasListener);
+AnomalyDetectorsIndex.createStateIndexAndAliasIfNecessary(new ParentTaskAssigningClient(client, task.getParentTaskId()),
+    clusterState, expressionResolver, stateAliasListener);
}

-private void createStatsIndexAndUpdateMappingsIfNecessary(ClusterState clusterState, ActionListener<Boolean> listener) {
+private void createStatsIndexAndUpdateMappingsIfNecessary(Client client, ClusterState clusterState, ActionListener<Boolean> listener) {
ActionListener<Boolean> createIndexListener = ActionListener.wrap(
aBoolean -> ElasticsearchMappings.addDocMappingIfMissing(
MlStatsIndex.writeAlias(),
@@ -175,7 +178,7 @@ private void executeJobInMiddleOfReindexing(DataFrameAnalyticsTask task, DataFra
task.markAsCompleted();
return;
}
-ClientHelper.executeAsyncWithOrigin(client,
+ClientHelper.executeAsyncWithOrigin(new ParentTaskAssigningClient(client, task.getParentTaskId()),
ML_ORIGIN,
DeleteIndexAction.INSTANCE,
new DeleteIndexRequest(config.getDest().getIndex()),
Expand All @@ -200,6 +203,8 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
return;
}

+final ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
+
// Reindexing is complete; start analytics
ActionListener<BulkByScrollResponse> reindexCompletedListener = ActionListener.wrap(
reindexResponse -> {
@@ -239,8 +244,9 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
reindexRequest.getSearchRequest().source().fetchSource(config.getSource().getSourceFiltering());
reindexRequest.setDestIndex(config.getDest().getIndex());
reindexRequest.setScript(new Script("ctx._source." + DestinationIndex.ID_COPY + " = ctx._id"));
+reindexRequest.setParentTask(task.getParentTaskId());

-final ThreadContext threadContext = client.threadPool().getThreadContext();
+final ThreadContext threadContext = parentTaskClient.threadPool().getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(ML_ORIGIN)) {
LOGGER.info("[{}] Started reindexing", config.getId());
@@ -261,7 +267,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
config.getId(),
Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_REUSING_DEST_INDEX, indexResponse.indices()[0]));
LOGGER.info("[{}] Using existing destination index [{}]", config.getId(), indexResponse.indices()[0]);
-DestinationIndex.updateMappingsToDestIndex(client, config, indexResponse, ActionListener.wrap(
+DestinationIndex.updateMappingsToDestIndex(parentTaskClient, config, indexResponse, ActionListener.wrap(
acknowledgedResponse -> copyIndexCreatedListener.onResponse(null),
copyIndexCreatedListener::onFailure
));
@@ -272,14 +278,14 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
config.getId(),
Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_CREATING_DEST_INDEX, config.getDest().getIndex()));
LOGGER.info("[{}] Creating destination index [{}]", config.getId(), config.getDest().getIndex());
-DestinationIndex.createDestinationIndex(client, Clock.systemUTC(), config, copyIndexCreatedListener);
+DestinationIndex.createDestinationIndex(parentTaskClient, Clock.systemUTC(), config, copyIndexCreatedListener);
} else {
copyIndexCreatedListener.onFailure(e);
}
}
);

-ClientHelper.executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, client, GetIndexAction.INSTANCE,
+ClientHelper.executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, parentTaskClient, GetIndexAction.INSTANCE,
new GetIndexRequest().indices(config.getDest().getIndex()), destIndexListener);
}

@@ -289,6 +295,7 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi
task.markAsCompleted();
return;
}
+final ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
// Update state to ANALYZING and start process
ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
dataExtractorFactory -> {
@@ -325,14 +332,14 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi
// TODO This could fail with errors. In that case we get stuck with the copied index.
// We could delete the index in case of failure or we could try building the factory before reindexing
// to catch the error early on.
-DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener);
+DataFrameDataExtractorFactory.createForDestinationIndex(parentTaskClient, config, dataExtractorFactoryListener);
},
dataExtractorFactoryListener::onFailure
);

// First we need to refresh the dest index to ensure data is searchable in case the job
// was stopped after reindexing was complete but before the index was refreshed.
-executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, client, RefreshAction.INSTANCE,
+executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, parentTaskClient, RefreshAction.INSTANCE,
new RefreshRequest(config.getDest().getIndex()), refreshListener);
}

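
One place above does not go through a wrapped client: the `ReindexRequest` gets its parent explicitly via `setParentTask`, which every transport request supports. A minimal sketch of that variant; the request setup and names are illustrative:

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.reindex.ReindexAction;
import org.elasticsearch.index.reindex.ReindexRequest;
import org.elasticsearch.tasks.TaskId;

class ExplicitParentTaskSketch {

    // Set the parent task id on a single request by hand; equivalent in effect
    // to sending the request through a ParentTaskAssigningClient.
    void reindexAsChild(Client client, TaskId parentTaskId) {
        ReindexRequest reindexRequest = new ReindexRequest();
        reindexRequest.setSourceIndices("source-index"); // illustrative indices
        reindexRequest.setDestIndex("dest-index");
        reindexRequest.setParentTask(parentTaskId); // ties the spawned reindex task to its parent

        client.execute(ReindexAction.INSTANCE, reindexRequest,
            ActionListener.wrap(
                bulkByScrollResponse -> { /* reindex finished */ },
                e -> { /* handle failure */ }));
    }
}
```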
@@ -22,6 +22,7 @@
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.unit.TimeValue;
@@ -75,7 +76,7 @@ public DataFrameAnalyticsTask(long id, String type, String action, TaskId parent
Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager,
DataFrameAnalyticsAuditor auditor, StartDataFrameAnalyticsAction.TaskParams taskParams) {
super(id, type, action, MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + taskParams.getId(), parentTask, headers);
-this.client = Objects.requireNonNull(client);
+this.client = new ParentTaskAssigningClient(Objects.requireNonNull(client), parentTask);
this.clusterService = Objects.requireNonNull(clusterService);
this.analyticsManager = Objects.requireNonNull(analyticsManager);
this.auditor = Objects.requireNonNull(auditor);
@@ -109,6 +110,12 @@ protected void onCancelled() {
markAsCompleted();
}

+@Override
+public boolean shouldCancelChildrenOnCancellation() {
+    // onCancelled implements graceful shutdown of children
+    return false;
+}

@Override
public void markAsCompleted() {
// It is possible that the stop API has been called in the meantime and that
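
This last hunk is the opt-out described in the commit message: returning `false` from `shouldCancelChildrenOnCancellation()` keeps the task framework from force-cancelling the children, leaving `onCancelled()` to stop them gracefully. A sketch of the shape of such a task, assuming the 7.x `CancellableTask` API; the class name and the body of `onCancelled()` are illustrative:

```java
import java.util.Map;

import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;

// A cancellable task that shuts its children down itself instead of
// letting the task framework cancel them abruptly.
class GracefulShutdownTask extends CancellableTask {

    GracefulShutdownTask(long id, String type, String action, String description,
                         TaskId parentTaskId, Map<String, String> headers) {
        super(id, type, action, description, parentTaskId, headers);
    }

    @Override
    public boolean shouldCancelChildrenOnCancellation() {
        // onCancelled() drives an orderly stop of the child tasks, so the
        // framework must not cancel them out from under it.
        return false;
    }

    @Override
    protected void onCancelled() {
        // begin graceful shutdown of child work here (illustrative)
    }
}
```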
