Skip to content

Commit

Permalink
Improve async search's tasks cancellation (#53799)
Browse files Browse the repository at this point in the history
This commit explicitly cancels the search task if
the initial async search submit task is cancelled (connection closed by the user).
Previously, this was done by cancelling the parent task, but we don't
handle the cancellation of grand-children yet, so we have to cancel the search task manually
to ensure that the shard actions are cancelled too.
This change is a workaround until #50990 is fixed.
  • Loading branch information
jimczi authored Mar 24, 2020
1 parent aed8ce7 commit 68f4297
Show file tree
Hide file tree
Showing 10 changed files with 377 additions and 186 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
* Task that tracks the progress of a currently running {@link SearchRequest}.
*/
final class AsyncSearchTask extends SearchTask {
private final BooleanSupplier checkSubmitCancellation;
private final AsyncSearchId searchId;
private final Client client;
private final ThreadPool threadPool;
Expand All @@ -68,6 +70,7 @@ final class AsyncSearchTask extends SearchTask {
* @param type The type of the task.
* @param action The action name.
* @param parentTaskId The parent task id.
* @param checkSubmitCancellation A boolean supplier that checks if the submit task has been cancelled.
* @param originHeaders All the request context headers.
* @param taskHeaders The filtered request headers for the task.
* @param searchId The {@link AsyncSearchId} of the task.
Expand All @@ -78,6 +81,7 @@ final class AsyncSearchTask extends SearchTask {
String type,
String action,
TaskId parentTaskId,
BooleanSupplier checkSubmitCancellation,
TimeValue keepAlive,
Map<String, String> originHeaders,
Map<String, String> taskHeaders,
Expand All @@ -86,6 +90,7 @@ final class AsyncSearchTask extends SearchTask {
ThreadPool threadPool,
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
super(id, type, action, "async_search", parentTaskId, taskHeaders);
this.checkSubmitCancellation = checkSubmitCancellation;
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
this.originHeaders = originHeaders;
this.searchId = searchId;
Expand Down Expand Up @@ -212,13 +217,13 @@ private void internalAddCompletionListener(ActionListener<AsyncSearchResponse> l

final Cancellable cancellable;
try {
cancellable = threadPool.schedule(() -> {
cancellable = threadPool.schedule(threadPool.preserveContext(() -> {
if (hasRun.compareAndSet(false, true)) {
// timeout occurred before completion
removeCompletionListener(id);
listener.onResponse(getResponse());
}
}, waitForCompletion, "generic");
}), waitForCompletion, "generic");
} catch (EsRejectedExecutionException exc) {
listener.onFailure(exc);
return;
Expand Down Expand Up @@ -291,41 +296,45 @@ private AsyncSearchResponse getResponse() {
return searchResponse.get().toAsyncSearchResponse(this, expirationTimeMillis);
}

// cancels the task if it expired
private void checkExpiration() {
// checks if the search task should be cancelled
private void checkCancellation() {
long now = System.currentTimeMillis();
if (expirationTimeMillis < now) {
if (expirationTimeMillis < now || checkSubmitCancellation.getAsBoolean()) {
// we cancel the search task if the initial submit task was cancelled,
// this is needed because the task cancellation mechanism doesn't
// handle the cancellation of grand-children.
cancelTask(() -> {});
}
}

class Listener extends SearchProgressActionListener {
@Override
protected void onQueryResult(int shardIndex) {
checkExpiration();
checkCancellation();
}

@Override
protected void onFetchResult(int shardIndex) {
checkExpiration();
checkCancellation();
}

@Override
protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
// best effort to cancel expired tasks
checkExpiration();
searchResponse.get().addShardFailure(shardIndex, new ShardSearchFailure(exc, shardTarget));
checkCancellation();
searchResponse.get().addShardFailure(shardIndex,
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
}

@Override
protected void onFetchFailure(int shardIndex, Exception exc) {
checkExpiration();
checkCancellation();
}

@Override
protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped, Clusters clusters, boolean fetchPhase) {
// best effort to cancel expired tasks
checkExpiration();
checkCancellation();
searchResponse.compareAndSet(null,
new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, aggReduceContextSupplier));
executeInitListeners();
Expand All @@ -334,7 +343,7 @@ protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped,
@Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
// best effort to cancel expired tasks
checkExpiration();
checkCancellation();
searchResponse.get().updatePartialResponse(shards.size(),
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
null, null, false, null, reducePhase), aggs == null);
Expand All @@ -343,7 +352,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
@Override
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
// best effort to cancel expired tasks
checkExpiration();
checkCancellation();
searchResponse.get().updatePartialResponse(shards.size(),
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
null, null, false, null, reducePhase), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public TransportSubmitAsyncSearchAction(ClusterService clusterService,
@Override
protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionListener<AsyncSearchResponse> submitListener) {
CancellableTask submitTask = (CancellableTask) task;
final SearchRequest searchRequest = createSearchRequest(request, submitTask.getId(), request.getKeepAlive());
final SearchRequest searchRequest = createSearchRequest(request, submitTask, request.getKeepAlive());
AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest);
searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener());
searchTask.addCompletionListener(
Expand All @@ -81,7 +81,7 @@ public void onResponse(AsyncSearchResponse searchResponse) {
// the user cancelled the submit so we don't store anything
// and propagate the failure
Exception cause = new TaskCancelledException(submitTask.getReasonCancelled());
onFatalFailure(searchTask, cause, false, submitListener);
onFatalFailure(searchTask, cause, searchResponse.isRunning(), submitListener);
} else {
final String docId = searchTask.getSearchId().getDocId();
// creates the fallback response if the node crashes/restarts in the middle of the request
Expand Down Expand Up @@ -129,7 +129,7 @@ public void onFailure(Exception exc) {
}, request.getWaitForCompletion());
}

private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, long parentTaskId, TimeValue keepAlive) {
private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, CancellableTask submitTask, TimeValue keepAlive) {
String docID = UUIDs.randomBase64UUID();
Map<String, String> originHeaders = nodeClient.threadPool().getThreadContext().getHeaders();
SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) {
Expand All @@ -138,16 +138,17 @@ public AsyncSearchTask createTask(long id, String type, String action, TaskId pa
AsyncSearchId searchId = new AsyncSearchId(docID, new TaskId(nodeClient.getLocalNodeId(), id));
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier =
() -> requestToAggReduceContextBuilder.apply(request.getSearchRequest());
return new AsyncSearchTask(id, type, action, parentTaskId, keepAlive, originHeaders, taskHeaders, searchId,
store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier);
return new AsyncSearchTask(id, type, action, parentTaskId,
() -> submitTask.isCancelled(), keepAlive, originHeaders, taskHeaders, searchId, store.getClient(),
nodeClient.threadPool(), aggReduceContextSupplier);
}
};
searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), parentTaskId));
searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), submitTask.getId()));
return searchRequest;
}

private void onFatalFailure(AsyncSearchTask task, Exception error, boolean shouldCancel, ActionListener<AsyncSearchResponse> listener) {
if (shouldCancel) {
if (shouldCancel && task.isCancelled() == false) {
task.cancelTask(() -> {
try {
task.addCompletionListener(finalResponse -> taskManager.unregister(task));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,28 @@ public void testNoIndex() throws Exception {
ElasticsearchException exc = response.getFailure();
assertThat(exc.getMessage(), containsString("no such index"));
}

/**
 * Submits an async search whose only aggregation is a {@code CancellingAggregationBuilder}
 * with a 1ms wait-for-completion, and verifies that both the submit response and a
 * follow-up get report a running search with no successful or failed shards.
 * The search is then deleted and the test waits for its task to be removed.
 */
public void testCancellation() throws Exception {
    final SearchSourceBuilder source = new SearchSourceBuilder();
    source.aggregation(new CancellingAggregationBuilder("test"));

    final SubmitAsyncSearchRequest submitRequest = new SubmitAsyncSearchRequest(indexName);
    submitRequest.getSearchRequest().source(source);
    // return almost immediately so that the response is observed while the search is in flight
    submitRequest.setWaitForCompletion(TimeValue.timeValueMillis(1));

    final AsyncSearchResponse submitted = submitAsyncSearch(submitRequest);
    assertRunningWithoutShardResults(submitted);

    final AsyncSearchResponse fetched = getAsyncSearch(submitted.getId());
    assertRunningWithoutShardResults(fetched);

    // clean up: delete the search, then make sure its task is gone
    deleteAsyncSearch(fetched.getId());
    ensureTaskRemoval(fetched.getId());
}

/**
 * Asserts that the response carries a search response, is still running, covers
 * {@code numShards} shards in total, and reports neither successful nor failed shards.
 */
private void assertRunningWithoutShardResults(AsyncSearchResponse asyncResponse) {
    assertNotNull(asyncResponse.getSearchResponse());
    assertTrue(asyncResponse.isRunning());
    assertThat(asyncResponse.getSearchResponse().getTotalShards(), equalTo(numShards));
    assertThat(asyncResponse.getSearchResponse().getSuccessfulShards(), equalTo(0));
    assertThat(asyncResponse.getSearchResponse().getFailedShards(), equalTo(0));
}
}
Loading

0 comments on commit 68f4297

Please sign in to comment.