diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java index 479daffc832b5..dbef148a798b2 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java @@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.internal.InternalSearchResponse; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.Scheduler.Cancellable; import org.elasticsearch.threadpool.ThreadPool; @@ -42,6 +43,7 @@ * Task that tracks the progress of a currently running {@link SearchRequest}. */ final class AsyncSearchTask extends SearchTask { + private final CancellableTask submitTask; private final AsyncSearchId searchId; private final Client client; private final ThreadPool threadPool; @@ -67,7 +69,7 @@ final class AsyncSearchTask extends SearchTask { * @param id The id of the task. * @param type The type of the task. * @param action The action name. - * @param parentTaskId The parent task id. + * @param submitTask The task that submitted the async search. * @param originHeaders All the request context headers. * @param taskHeaders The filtered request headers for the task. * @param searchId The {@link AsyncSearchId} of the task. @@ -77,7 +79,7 @@ final class AsyncSearchTask extends SearchTask { AsyncSearchTask(long id, String type, String action, - TaskId parentTaskId, + CancellableTask submitTask, TimeValue keepAlive, Map originHeaders, Map taskHeaders, @@ -85,7 +87,8 @@ final class AsyncSearchTask extends SearchTask { Client client, ThreadPool threadPool, Supplier aggReduceContextSupplier) { - super(id, type, action, "async_search", parentTaskId, taskHeaders); + super(id, type, action, "async_search", TaskId.EMPTY_TASK_ID, taskHeaders); + this.submitTask = submitTask; this.expirationTimeMillis = getStartTime() + keepAlive.getMillis(); this.originHeaders = originHeaders; this.searchId = searchId; @@ -212,13 +215,13 @@ private void internalAddCompletionListener(ActionListener l final Cancellable cancellable; try { - cancellable = threadPool.schedule(() -> { + cancellable = threadPool.schedule(threadPool.preserveContext(() -> { if (hasRun.compareAndSet(false, true)) { // timeout occurred before completion removeCompletionListener(id); listener.onResponse(getResponse()); } - }, waitForCompletion, "generic"); + }), waitForCompletion, "generic"); } catch (EsRejectedExecutionException exc) { listener.onFailure(exc); return; @@ -291,10 +294,13 @@ private AsyncSearchResponse getResponse() { return searchResponse.get().toAsyncSearchResponse(this, expirationTimeMillis); } - // cancels the task if it expired - private void checkExpiration() { + // checks if the search task should be cancelled + private void checkCancellation() { long now = System.currentTimeMillis(); - if (expirationTimeMillis < now) { + if (expirationTimeMillis < now || submitTask.isCancelled()) { + // we cancel the search task if the initial submit task was cancelled, + // this is needed because the task cancellation mechanism doesn't + // handle the cancellation of grand-children. cancelTask(() -> {}); } } @@ -302,30 +308,31 @@ private void checkExpiration() { class Listener extends SearchProgressActionListener { @Override protected void onQueryResult(int shardIndex) { - checkExpiration(); + checkCancellation(); } @Override protected void onFetchResult(int shardIndex) { - checkExpiration(); + checkCancellation(); } @Override protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { // best effort to cancel expired tasks - checkExpiration(); - searchResponse.get().addShardFailure(shardIndex, new ShardSearchFailure(exc, shardTarget)); + checkCancellation(); + searchResponse.get().addShardFailure(shardIndex, + new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null)); } @Override protected void onFetchFailure(int shardIndex, Exception exc) { - checkExpiration(); + checkCancellation(); } @Override protected void onListShards(List shards, List skipped, Clusters clusters, boolean fetchPhase) { // best effort to cancel expired tasks - checkExpiration(); + checkCancellation(); searchResponse.compareAndSet(null, new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, aggReduceContextSupplier)); executeInitListeners(); @@ -334,7 +341,7 @@ protected void onListShards(List shards, List skipped, @Override public void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { // best effort to cancel expired tasks - checkExpiration(); + checkCancellation(); searchResponse.get().updatePartialResponse(shards.size(), new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs, null, null, false, null, reducePhase), aggs == null); @@ -343,7 +350,7 @@ public void onPartialReduce(List shards, TotalHits totalHits, Inter @Override public void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { // best effort to cancel expired tasks - checkExpiration(); + checkCancellation(); searchResponse.get().updatePartialResponse(shards.size(), new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs, null, null, false, null, reducePhase), true); diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java index 032c6114a1a44..641e7ea8f8dfb 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java @@ -66,7 +66,7 @@ public TransportSubmitAsyncSearchAction(ClusterService clusterService, @Override protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionListener submitListener) { CancellableTask submitTask = (CancellableTask) task; - final SearchRequest searchRequest = createSearchRequest(request, submitTask.getId(), request.getKeepAlive()); + final SearchRequest searchRequest = createSearchRequest(request, submitTask, request.getKeepAlive()); AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest); searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener()); searchTask.addCompletionListener( @@ -81,7 +81,7 @@ public void onResponse(AsyncSearchResponse searchResponse) { // the user cancelled the submit so we don't store anything // and propagate the failure Exception cause = new TaskCancelledException(submitTask.getReasonCancelled()); - onFatalFailure(searchTask, cause, false, submitListener); + onFatalFailure(searchTask, cause, searchResponse.isRunning(), submitListener); } else { final String docId = searchTask.getSearchId().getDocId(); // creates the fallback response if the node crashes/restarts in the middle of the request @@ -129,7 +129,7 @@ public void onFailure(Exception exc) { }, request.getWaitForCompletion()); } - private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, long parentTaskId, TimeValue keepAlive) { + private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, CancellableTask submitTask, TimeValue keepAlive) { String docID = UUIDs.randomBase64UUID(); Map originHeaders = nodeClient.threadPool().getThreadContext().getHeaders(); SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) { @@ -138,16 +138,15 @@ public AsyncSearchTask createTask(long id, String type, String action, TaskId pa AsyncSearchId searchId = new AsyncSearchId(docID, new TaskId(nodeClient.getLocalNodeId(), id)); Supplier aggReduceContextSupplier = () -> requestToAggReduceContextBuilder.apply(request.getSearchRequest()); - return new AsyncSearchTask(id, type, action, parentTaskId, keepAlive, originHeaders, taskHeaders, searchId, - store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier); + return new AsyncSearchTask(id, type, action, submitTask, keepAlive, originHeaders, + taskHeaders, searchId, store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier); } }; - searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), parentTaskId)); return searchRequest; } private void onFatalFailure(AsyncSearchTask task, Exception error, boolean shouldCancel, ActionListener listener) { - if (shouldCancel) { + if (shouldCancel && task.isCancelled() == false) { task.cancelTask(() -> { try { task.addCompletionListener(finalResponse -> taskManager.unregister(task)); diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java index 373030fe12efa..ffacf2398e107 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java @@ -253,4 +253,28 @@ public void testNoIndex() throws Exception { ElasticsearchException exc = response.getFailure(); assertThat(exc.getMessage(), containsString("no such index")); } + + public void testCancellation() throws Exception { + SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(indexName); + request.getSearchRequest().source( + new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test")) + ); + request.setWaitForCompletion(TimeValue.timeValueMillis(1)); + AsyncSearchResponse response = submitAsyncSearch(request); + assertNotNull(response.getSearchResponse()); + assertTrue(response.isRunning()); + assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards)); + assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0)); + assertThat(response.getSearchResponse().getFailedShards(), equalTo(0)); + + response = getAsyncSearch(response.getId()); + assertNotNull(response.getSearchResponse()); + assertTrue(response.isRunning()); + assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards)); + assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0)); + assertThat(response.getSearchResponse().getFailedShards(), equalTo(0)); + + deleteAsyncSearch(response.getId()); + ensureTaskRemoval(response.getId()); + } } diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java index e5e65a832db8e..9c4c6757d67cd 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java @@ -5,11 +5,7 @@ */ package org.elasticsearch.xpack.search; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.Weight; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskResponse; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup; @@ -18,22 +14,12 @@ import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.ParsingException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.xcontent.ObjectParser; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.AbstractQueryBuilder; -import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.index.reindex.ReindexPlugin; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.PluginsService; -import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.tasks.TaskId; @@ -48,15 +34,11 @@ import org.elasticsearch.xpack.ilm.IndexLifecycle; import java.io.Closeable; -import java.io.IOException; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.Comparator; import java.util.Iterator; -import java.util.List; import java.util.Map; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -73,7 +55,7 @@ interface SearchResponseIterator extends Iterator, Closeabl @Override protected Collection> nodePlugins() { return Arrays.asList(LocalStateCompositeXPackPlugin.class, AsyncSearch.class, IndexLifecycle.class, - QueryBlockPlugin.class, ReindexPlugin.class); + SearchTestPlugin.class, ReindexPlugin.class); } /** @@ -152,14 +134,14 @@ protected SearchResponseIterator assertBlockingIterator(String indexName, .collect( Collectors.toMap( Function.identity(), - id -> new ShardIdLatch(id, new CountDownLatch(1), failures.decrementAndGet() >= 0) + id -> new ShardIdLatch(id, failures.decrementAndGet() >= 0) ) ); ShardIdLatch[] shardLatchArray = shardLatchMap.values().stream() - .sorted(Comparator.comparing(ShardIdLatch::shard)) + .sorted(Comparator.comparing(ShardIdLatch::shardId)) .toArray(ShardIdLatch[]::new); resetPluginsLatch(shardLatchMap); - request.getSearchRequest().source().query(new BlockQueryBuilder(shardLatchMap)); + request.getSearchRequest().source().query(new BlockingQueryBuilder(shardLatchMap)); final AsyncSearchResponse initial = client().execute(SubmitAsyncSearchAction.INSTANCE, request).get(); @@ -197,7 +179,7 @@ private AsyncSearchResponse doNext() throws Exception { int step = shardIndex == 0 ? progressStep+1 : progressStep-1; int index = 0; while (index < step && shardIndex < shardLatchArray.length) { - if (shardLatchArray[shardIndex].shouldFail == false) { + if (shardLatchArray[shardIndex].shouldFail() == false) { ++index; } shardLatchArray[shardIndex++].countDown(); @@ -242,8 +224,8 @@ private AsyncSearchResponse doNext() throws Exception { @Override public void close() { Arrays.stream(shardLatchArray).forEach(shard -> { - if (shard.latch.getCount() == 1) { - shard.latch.countDown(); + if (shard.getCount() == 1) { + shard.countDown(); } }); } @@ -252,143 +234,7 @@ public void close() { private void resetPluginsLatch(Map newLatch) { for (PluginsService pluginsService : internalCluster().getDataNodeInstances(PluginsService.class)) { - pluginsService.filterPlugins(QueryBlockPlugin.class).forEach(p -> p.reset(newLatch)); - } - } - - public static class QueryBlockPlugin extends Plugin implements SearchPlugin { - private Map shardsLatch; - - public QueryBlockPlugin() { - this.shardsLatch = null; - } - - public void reset(Map newLatch) { - shardsLatch = newLatch; - } - - @Override - public List> getQueries() { - return Collections.singletonList( - new QuerySpec<>("block_match_all", - in -> new BlockQueryBuilder(in, shardsLatch), - p -> BlockQueryBuilder.fromXContent(p, shardsLatch)) - ); - } - } - - private static class BlockQueryBuilder extends AbstractQueryBuilder { - public static final String NAME = "block_match_all"; - private final Map shardsLatch; - - private BlockQueryBuilder(Map shardsLatch) { - super(); - this.shardsLatch = shardsLatch; - } - - BlockQueryBuilder(StreamInput in, Map shardsLatch) throws IOException { - super(in); - this.shardsLatch = shardsLatch; - } - - private BlockQueryBuilder() { - this.shardsLatch = null; - } - - @Override - protected void doWriteTo(StreamOutput out) {} - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(NAME); - builder.endObject(); - } - - private static final ObjectParser PARSER = new ObjectParser<>(NAME, BlockQueryBuilder::new); - - public static BlockQueryBuilder fromXContent(XContentParser parser, Map shardsLatch) { - try { - PARSER.apply(parser, null); - return new BlockQueryBuilder(shardsLatch); - } catch (IllegalArgumentException e) { - throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); - } - } - - @Override - protected Query doToQuery(QueryShardContext context) { - final Query delegate = Queries.newMatchAllQuery(); - return new Query() { - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - if (shardsLatch != null) { - try { - final ShardIdLatch latch = shardsLatch.get(new ShardId(context.index(), context.getShardId())); - latch.await(); - if (latch.shouldFail) { - throw new IOException("boum"); - } - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - return delegate.createWeight(searcher, scoreMode, boost); - } - - @Override - public String toString(String field) { - return delegate.toString(field); - } - - @Override - public boolean equals(Object obj) { - return false; - } - - @Override - public int hashCode() { - return 0; - } - }; - } - - @Override - protected boolean doEquals(BlockQueryBuilder other) { - return false; - } - - @Override - protected int doHashCode() { - return 0; - } - - @Override - public String getWriteableName() { - return NAME; - } - } - - private static class ShardIdLatch { - private final ShardId shard; - private final CountDownLatch latch; - private final boolean shouldFail; - - private ShardIdLatch(ShardId shard, CountDownLatch latch, boolean shouldFail) { - this.shard = shard; - this.latch = latch; - this.shouldFail = shouldFail; - } - - ShardId shard() { - return shard; - } - - void countDown() { - latch.countDown(); - } - - void await() throws InterruptedException { - latch.await(); + pluginsService.filterPlugins(SearchTestPlugin.class).forEach(p -> p.resetQueryLatch(newLatch)); } } } diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java index 8c3f57883ec14..c795f02239759 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.internal.InternalSearchResponse; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; @@ -27,6 +28,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -35,6 +37,23 @@ public class AsyncSearchTaskTests extends ESTestCase { private ThreadPool threadPool; + private static class TestTask extends CancellableTask { + private TestTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { + super(id, type, action, description, parentTaskId, headers); + } + + @Override + public boolean shouldCancelChildrenOnCancellation() { + return false; + } + } + + private static TestTask createSubmitTask() { + return new TestTask(0L, "", "", "test", new TaskId("node1", 0), Collections.emptyMap()); + } + + + @Before public void beforeTest() { threadPool = new TestThreadPool(getTestName()); @@ -46,7 +65,7 @@ public void afterTest() { } public void testWaitForInit() throws InterruptedException { - AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1), + AsyncSearchTask task = new AsyncSearchTask(0L, "", "", createSubmitTask(), TimeValue.timeValueHours(1), Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)), new NoOpClient(threadPool), threadPool, null); int numShards = randomIntBetween(0, 10); @@ -86,7 +105,7 @@ public void onFailure(Exception e) { } public void testWithFailure() throws InterruptedException { - AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1), + AsyncSearchTask task = new AsyncSearchTask(0L, "", "", createSubmitTask(), TimeValue.timeValueHours(1), Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)), new NoOpClient(threadPool), threadPool, null); int numThreads = randomIntBetween(1, 10); @@ -114,7 +133,7 @@ public void onFailure(Exception e) { } public void testWaitForCompletion() throws InterruptedException { - AsyncSearchTask task = new AsyncSearchTask(0L, "", "", new TaskId("node1", 0), TimeValue.timeValueHours(1), + AsyncSearchTask task = new AsyncSearchTask(0L, "", "", createSubmitTask(), TimeValue.timeValueHours(1), Collections.emptyMap(), Collections.emptyMap(), new AsyncSearchId("0", new TaskId("node1", 1)), new NoOpClient(threadPool), threadPool, null); int numShards = randomIntBetween(0, 10); diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/BlockingQueryBuilder.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/BlockingQueryBuilder.java new file mode 100644 index 0000000000000..5939cc10458c3 --- /dev/null +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/BlockingQueryBuilder.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.search; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryShardContext; +import org.elasticsearch.index.shard.ShardId; + +import java.io.IOException; +import java.util.Map; + +/** + * A query builder that blocks shard execution based on the provided {@link ShardIdLatch}. + */ +class BlockingQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "block"; + private final Map shardsLatch; + + BlockingQueryBuilder(Map shardsLatch) { + super(); + this.shardsLatch = shardsLatch; + } + + BlockingQueryBuilder(StreamInput in, Map shardsLatch) throws IOException { + super(in); + this.shardsLatch = shardsLatch; + } + + BlockingQueryBuilder() { + this.shardsLatch = null; + } + + @Override + protected void doWriteTo(StreamOutput out) {} + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.endObject(); + } + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, BlockingQueryBuilder::new); + + public static BlockingQueryBuilder fromXContent(XContentParser parser, Map shardsLatch) { + try { + PARSER.apply(parser, null); + return new BlockingQueryBuilder(shardsLatch); + } catch (IllegalArgumentException e) { + throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); + } + } + + @Override + protected Query doToQuery(QueryShardContext context) { + final Query delegate = Queries.newMatchAllQuery(); + return new Query() { + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + if (shardsLatch != null) { + try { + final ShardIdLatch latch = shardsLatch.get(new ShardId(context.index(), context.getShardId())); + latch.await(); + if (latch.shouldFail()) { + throw new IOException("boum"); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + return delegate.createWeight(searcher, scoreMode, boost); + } + + @Override + public String toString(String field) { + return delegate.toString(field); + } + + @Override + public boolean equals(Object obj) { + return false; + } + + @Override + public int hashCode() { + return 0; + } + }; + } + + @Override + protected boolean doEquals(BlockingQueryBuilder other) { + return false; + } + + @Override + protected int doHashCode() { + return 0; + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/CancellingAggregationBuilder.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/CancellingAggregationBuilder.java new file mode 100644 index 0000000000000..354c3c4ec340a --- /dev/null +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/CancellingAggregationBuilder.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.search; + +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.QueryShardContext; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +/** + * An aggregation builder that blocks shard search action until the task is cancelled. + */ +public class CancellingAggregationBuilder extends AbstractAggregationBuilder { + static final String NAME = "cancel"; + static final int SLEEP_TIME = 10; + + public CancellingAggregationBuilder(String name) { + super(name); + } + + public CancellingAggregationBuilder(StreamInput in) throws IOException { + super(in); + } + + @Override + protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metaData) { + return new CancellingAggregationBuilder(name); + } + + @Override + public String getType() { + return NAME; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, false, (args, name) -> new CancellingAggregationBuilder(name)); + + + static CancellingAggregationBuilder fromXContent(String aggName, XContentParser parser) { + try { + return PARSER.apply(parser, aggName); + } catch (IllegalArgumentException e) { + throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); + } + } + + @Override + @SuppressWarnings("unchecked") + protected AggregatorFactory doBuild(QueryShardContext queryShardContext, AggregatorFactory parent, + AggregatorFactories.Builder subfactoriesBuilder) throws IOException { + final FilterAggregationBuilder filterAgg = new FilterAggregationBuilder(name, QueryBuilders.matchAllQuery()); + filterAgg.subAggregations(subfactoriesBuilder); + final AggregatorFactory factory = filterAgg.build(queryShardContext, parent); + return new AggregatorFactory(name, queryShardContext, parent, subfactoriesBuilder, metaData) { + @Override + protected Aggregator createInternal(SearchContext searchContext, + Aggregator parent, + boolean collectsFromSingleBucket, + List pipelineAggregators, + Map metaData) throws IOException { + while (searchContext.isCancelled() == false) { + try { + Thread.sleep(SLEEP_TIME); + } catch (InterruptedException e) { + throw new IOException(e); + } + } + return factory.create(searchContext, parent, collectsFromSingleBucket); + } + }; + } +} diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/SearchTestPlugin.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/SearchTestPlugin.java new file mode 100644 index 0000000000000..88daac239fad1 --- /dev/null +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/SearchTestPlugin.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.search; + +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class SearchTestPlugin extends Plugin implements SearchPlugin { + private Map shardsLatch; + + public SearchTestPlugin() { + this.shardsLatch = null; + } + + public void resetQueryLatch(Map newLatch) { + shardsLatch = newLatch; + } + + @Override + public List> getQueries() { + return Collections.singletonList( + new QuerySpec<>(BlockingQueryBuilder.NAME, + in -> new BlockingQueryBuilder(in, shardsLatch), + p -> BlockingQueryBuilder.fromXContent(p, shardsLatch)) + ); + } + + @Override + public List getAggregations() { + return Collections.singletonList(new AggregationSpec(CancellingAggregationBuilder.NAME, CancellingAggregationBuilder::new, + CancellingAggregationBuilder.PARSER).addResultReader(InternalFilter::new)); + } +} diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/ShardIdLatch.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/ShardIdLatch.java new file mode 100644 index 0000000000000..dd171b3102894 --- /dev/null +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/ShardIdLatch.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.search; + +import org.elasticsearch.index.shard.ShardId; + +import java.util.concurrent.CountDownLatch; + +class ShardIdLatch extends CountDownLatch { + private final ShardId shard; + private final boolean shouldFail; + + ShardIdLatch(ShardId shard, boolean shouldFail) { + super(1); + this.shard = shard; + this.shouldFail = shouldFail; + } + + ShardId shardId() { + return shard; + } + + boolean shouldFail() { + return shouldFail; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/SubmitAsyncSearchRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/SubmitAsyncSearchRequest.java index a6397cbf08c3b..6717c766f03e8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/SubmitAsyncSearchRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/SubmitAsyncSearchRequest.java @@ -150,7 +150,8 @@ public Task createTask(long id, String type, String action, TaskId parentTaskId, return new CancellableTask(id, type, action, toString(), parentTaskId, headers) { @Override public boolean shouldCancelChildrenOnCancellation() { - return true; + // we cancel the underlying search action explicitly in the submit action + return false; } }; }