Skip to content

Commit

Permalink
Improve async search's tasks cancellation
Browse files Browse the repository at this point in the history
This commit explicitly cancels the search task when the initial async search
submit task is cancelled (i.e. the connection was closed by the user).
Previously this relied on the cancellation of the parent task, but the task
cancellation mechanism does not handle grand-children yet, so the search task
has to be cancelled manually to ensure that the shard actions are cancelled too.
This change should be considered a workaround until elastic#50990 is fixed.
  • Loading branch information
jimczi committed Mar 19, 2020
1 parent 112ae9c commit d781dc7
Show file tree
Hide file tree
Showing 10 changed files with 376 additions and 189 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.Scheduler.Cancellable;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -42,6 +43,7 @@
* Task that tracks the progress of a currently running {@link SearchRequest}.
*/
final class AsyncSearchTask extends SearchTask {
private final CancellableTask submitTask;
private final AsyncSearchId searchId;
private final Client client;
private final ThreadPool threadPool;
Expand All @@ -67,7 +69,7 @@ final class AsyncSearchTask extends SearchTask {
* @param id The id of the task.
* @param type The type of the task.
* @param action The action name.
* @param parentTaskId The parent task id.
* @param submitTask The task that submitted the async search.
* @param originHeaders All the request context headers.
* @param taskHeaders The filtered request headers for the task.
* @param searchId The {@link AsyncSearchId} of the task.
Expand All @@ -77,15 +79,16 @@ final class AsyncSearchTask extends SearchTask {
AsyncSearchTask(long id,
String type,
String action,
TaskId parentTaskId,
CancellableTask submitTask,
TimeValue keepAlive,
Map<String, String> originHeaders,
Map<String, String> taskHeaders,
AsyncSearchId searchId,
Client client,
ThreadPool threadPool,
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
super(id, type, action, "async_search", parentTaskId, taskHeaders);
super(id, type, action, "async_search", TaskId.EMPTY_TASK_ID, taskHeaders);
this.submitTask = submitTask;
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
this.originHeaders = originHeaders;
this.searchId = searchId;
Expand Down Expand Up @@ -212,13 +215,13 @@ private void internalAddCompletionListener(ActionListener<AsyncSearchResponse> l

final Cancellable cancellable;
try {
cancellable = threadPool.schedule(() -> {
cancellable = threadPool.schedule(threadPool.preserveContext(() -> {
if (hasRun.compareAndSet(false, true)) {
// timeout occurred before completion
removeCompletionListener(id);
listener.onResponse(getResponse());
}
}, waitForCompletion, "generic");
}), waitForCompletion, "generic");
} catch (EsRejectedExecutionException exc) {
listener.onFailure(exc);
return;
Expand Down Expand Up @@ -291,41 +294,45 @@ private AsyncSearchResponse getResponse() {
return searchResponse.get().toAsyncSearchResponse(this, expirationTimeMillis);
}

// cancels the task if it expired
private void checkExpiration() {
// checks if the search task should be cancelled
private void checkCancellation() {
long now = System.currentTimeMillis();
if (expirationTimeMillis < now) {
if (expirationTimeMillis < now || submitTask.isCancelled()) {
// we cancel the search task if the initial submit task was cancelled,
// this is needed because the task cancellation mechanism doesn't
// handle the cancellation of grand-children.
cancelTask(() -> {});
}
}

class Listener extends SearchProgressActionListener {
@Override
protected void onQueryResult(int shardIndex) {
checkExpiration();
checkCancellation();
}

@Override
protected void onFetchResult(int shardIndex) {
checkExpiration();
checkCancellation();
}

@Override
protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
// best effort to cancel expired tasks
checkExpiration();
searchResponse.get().addShardFailure(shardIndex, new ShardSearchFailure(exc, shardTarget));
checkCancellation();
searchResponse.get().addShardFailure(shardIndex,
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
}

@Override
protected void onFetchFailure(int shardIndex, Exception exc) {
checkExpiration();
checkCancellation();
}

@Override
protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped, Clusters clusters, boolean fetchPhase) {
// best effort to cancel expired tasks
checkExpiration();
checkCancellation();
searchResponse.compareAndSet(null,
new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, aggReduceContextSupplier));
executeInitListeners();
Expand All @@ -334,7 +341,7 @@ protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped,
@Override
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
// best effort to cancel expired tasks
checkExpiration();
checkCancellation();
searchResponse.get().updatePartialResponse(shards.size(),
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
null, null, false, null, reducePhase), aggs == null);
Expand All @@ -343,7 +350,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
@Override
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
// best effort to cancel expired tasks
checkExpiration();
checkCancellation();
searchResponse.get().updatePartialResponse(shards.size(),
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
null, null, false, null, reducePhase), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public TransportSubmitAsyncSearchAction(ClusterService clusterService,
@Override
protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionListener<AsyncSearchResponse> submitListener) {
CancellableTask submitTask = (CancellableTask) task;
final SearchRequest searchRequest = createSearchRequest(request, submitTask.getId(), request.getKeepAlive());
final SearchRequest searchRequest = createSearchRequest(request, submitTask, request.getKeepAlive());
AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest);
searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener());
searchTask.addCompletionListener(
Expand All @@ -81,7 +81,7 @@ public void onResponse(AsyncSearchResponse searchResponse) {
// the user cancelled the submit so we don't store anything
// and propagate the failure
Exception cause = new TaskCancelledException(submitTask.getReasonCancelled());
onFatalFailure(searchTask, cause, false, submitListener);
onFatalFailure(searchTask, cause, searchResponse.isRunning(), submitListener);
} else {
final String docId = searchTask.getSearchId().getDocId();
// creates the fallback response if the node crashes/restarts in the middle of the request
Expand Down Expand Up @@ -129,7 +129,7 @@ public void onFailure(Exception exc) {
}, request.getWaitForCompletion());
}

private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, long parentTaskId, TimeValue keepAlive) {
private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, CancellableTask submitTask, TimeValue keepAlive) {
String docID = UUIDs.randomBase64UUID();
Map<String, String> originHeaders = nodeClient.threadPool().getThreadContext().getHeaders();
SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) {
Expand All @@ -138,16 +138,15 @@ public AsyncSearchTask createTask(long id, String type, String action, TaskId pa
AsyncSearchId searchId = new AsyncSearchId(docID, new TaskId(nodeClient.getLocalNodeId(), id));
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier =
() -> requestToAggReduceContextBuilder.apply(request.getSearchRequest());
return new AsyncSearchTask(id, type, action, parentTaskId, keepAlive, originHeaders, taskHeaders, searchId,
store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier);
return new AsyncSearchTask(id, type, action, submitTask, keepAlive, originHeaders,
taskHeaders, searchId, store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier);
}
};
searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), parentTaskId));
return searchRequest;
}

private void onFatalFailure(AsyncSearchTask task, Exception error, boolean shouldCancel, ActionListener<AsyncSearchResponse> listener) {
if (shouldCancel) {
if (shouldCancel && task.isCancelled() == false) {
task.cancelTask(() -> {
try {
task.addCompletionListener(finalResponse -> taskManager.unregister(task));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,28 @@ public void testNoIndex() throws Exception {
ElasticsearchException exc = response.getFailure();
assertThat(exc.getMessage(), containsString("no such index"));
}

/**
 * Submits an async search whose only aggregation is a {@code CancellingAggregationBuilder}
 * with a 1ms wait-for-completion, and checks that both the submit response and a
 * subsequent get response report a running search with no successful or failed shards.
 * The search is then deleted and the test waits for the task to be removed.
 *
 * NOTE(review): presumably {@code CancellingAggregationBuilder} keeps the search from
 * completing until it is cancelled - confirm against its implementation.
 */
public void testCancellation() throws Exception {
    SubmitAsyncSearchRequest submitRequest = new SubmitAsyncSearchRequest(indexName);
    SearchSourceBuilder source = new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test"));
    submitRequest.getSearchRequest().source(source);
    submitRequest.setWaitForCompletion(TimeValue.timeValueMillis(1));

    // the submit call returns after the (tiny) wait-for-completion timeout
    AsyncSearchResponse submitted = submitAsyncSearch(submitRequest);
    assertRunningWithoutShardResults(submitted);

    // fetching the same search again must report the same state
    AsyncSearchResponse fetched = getAsyncSearch(submitted.getId());
    assertRunningWithoutShardResults(fetched);

    deleteAsyncSearch(fetched.getId());
    ensureTaskRemoval(fetched.getId());
}

/**
 * Asserts that the given response carries a search response, is still running,
 * targets {@code numShards} shards and has neither successful nor failed shards.
 */
private void assertRunningWithoutShardResults(AsyncSearchResponse asyncResponse) {
    assertNotNull(asyncResponse.getSearchResponse());
    assertTrue(asyncResponse.isRunning());
    assertThat(asyncResponse.getSearchResponse().getTotalShards(), equalTo(numShards));
    assertThat(asyncResponse.getSearchResponse().getSuccessfulShards(), equalTo(0));
    assertThat(asyncResponse.getSearchResponse().getFailedShards(), equalTo(0));
}
}
Loading

0 comments on commit d781dc7

Please sign in to comment.