From a279a9e90bddcf043892ae4fc2362628da3db07b Mon Sep 17 00:00:00 2001 From: jimczi Date: Fri, 24 Jan 2020 15:51:25 +0100 Subject: [PATCH 1/2] Expose the logic to cancel task when the rest channel is closed This commit moves the logic that cancels search requests when the rest channel is closed to a generic client that can be used by other APIs. This will be useful for any rest action that wants to cancel the execution of a task if the underlying rest channel is closed by the client before completion. Relates #49931 Relates #50990 Relates #50990 --- ...er.java => RestCancellableNodeClient.java} | 113 +++++++++++------- .../rest/action/search/RestSearchAction.java | 6 +- ...va => RestCancellableNodeClientTests.java} | 46 ++++--- .../elasticsearch/test/ESIntegTestCase.java | 10 +- 4 files changed, 99 insertions(+), 76 deletions(-) rename server/src/main/java/org/elasticsearch/rest/action/{search/HttpChannelTaskHandler.java => RestCancellableNodeClient.java} (53%) rename server/src/test/java/org/elasticsearch/rest/action/{search/HttpChannelTaskHandlerTests.java => RestCancellableNodeClientTests.java} (84%) diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java similarity index 53% rename from server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java rename to server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java index 5864551854fca..ef296de66b8e1 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java +++ b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java @@ -17,54 +17,84 @@ * under the License. */ -package org.elasticsearch.rest.action.search; +package org.elasticsearch.rest.action; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; -import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; -import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.client.Client; +import org.elasticsearch.client.FilterClient; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.client.node.NodeClient; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; +import java.util.ArrayList; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN; + /** - * This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from, - * so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed. + * A {@linkplain Client} that cancels tasks executed locally when the provided {@link HttpChannel} + * is closed before completion. */ -public final class HttpChannelTaskHandler { +public class RestCancellableNodeClient extends FilterClient { + private static final Map httpChannels = new ConcurrentHashMap<>(); - public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler(); - //package private for testing - final Map httpChannels = new ConcurrentHashMap<>(); + private final NodeClient client; + private final HttpChannel httpChannel; - private HttpChannelTaskHandler() { + public RestCancellableNodeClient(NodeClient client, HttpChannel httpChannel) { + super(client); + this.client = client; + this.httpChannel = httpChannel; } - void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request, - ActionType actionType, ActionListener listener) { + /** + * Returns the number of channels tracked globally. + */ + public static int getNumChannels() { + return httpChannels.size(); + } + + /** + * Returns the number of tasks tracked globally. + */ + static int getNumTasks() { + return httpChannels.values().stream() + .mapToInt(CloseListener::getNumTasks) + .sum(); + } - CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client)); + /** + * Returns the number of tasks tracked by the provided {@link HttpChannel}. + */ + static int getNumTasks(HttpChannel channel) { + CloseListener listener = httpChannels.get(channel); + return listener == null ? 0 : listener.getNumTasks(); + } + + @Override + public void doExecute( + ActionType action, Request request, ActionListener listener) { + CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener()); TaskHolder taskHolder = new TaskHolder(); - Task task = client.executeLocally(actionType, request, + Task task = client.executeLocally(action, request, new ActionListener<>() { @Override - public void onResponse(Response searchResponse) { + public void onResponse(Response response) { try { closeListener.unregisterTask(taskHolder); } finally { - listener.onResponse(searchResponse); + listener.onResponse(response); } } @@ -77,25 +107,28 @@ public void onFailure(Exception e) { } } }); - closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId())); + final TaskId taskId = new TaskId(client.getLocalNodeId(), task.getId()); + closeListener.registerTask(taskHolder, taskId); closeListener.maybeRegisterChannel(httpChannel); } - public int getNumChannels() { - return httpChannels.size(); + private void cancelTask(TaskId taskId) { + CancelTasksRequest req = new CancelTasksRequest() + .setTaskId(taskId) + .setReason("channel closed"); + // force the origin to execute the cancellation as a system user + new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {})); } - final class CloseListener implements ActionListener { - private final Client client; + private class CloseListener implements ActionListener { private final AtomicReference channel = new AtomicReference<>(); - private final Set taskIds = new HashSet<>(); + private final Set tasks = new HashSet<>(); - CloseListener(Client client) { - this.client = client; + CloseListener() { } int getNumTasks() { - return taskIds.size(); + return tasks.size(); } void maybeRegisterChannel(HttpChannel httpChannel) { @@ -111,35 +144,27 @@ void maybeRegisterChannel(HttpChannel httpChannel) { synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) { taskHolder.taskId = taskId; if (taskHolder.completed == false) { - this.taskIds.add(taskId); + this.tasks.add(taskId); } } synchronized void unregisterTask(TaskHolder taskHolder) { if (taskHolder.taskId != null) { - this.taskIds.remove(taskHolder.taskId); + this.tasks.remove(taskHolder.taskId); } taskHolder.completed = true; } @Override - public synchronized void onResponse(Void aVoid) { - //When the channel gets closed it won't be reused: we can remove it from the map and forget about it. - CloseListener closeListener = httpChannels.remove(channel.get()); - assert closeListener != null : "channel not found in the map of tracked channels"; - for (TaskId taskId : taskIds) { - ThreadContext threadContext = client.threadPool().getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - // we stash any context here since this is an internal execution and should not leak any existing context information - threadContext.markAsSystemContext(); - ContextPreservingActionListener contextPreservingListener = new ContextPreservingActionListener<>( - threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {})); - CancelTasksRequest cancelTasksRequest = new CancelTasksRequest(); - cancelTasksRequest.setTaskId(taskId); - //We don't wait for cancel tasks to come back. Task cancellation is just best effort. - client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener); - } + public void onResponse(Void aVoid) { + final List toCancel; + synchronized (this) { + // when the channel gets closed it won't be reused: we can remove it from the map and forget about it. + CloseListener closeListener = httpChannels.remove(channel.get()); + assert closeListener != null : "channel not found in the map of tracked channels"; + toCancel = new ArrayList<>(tasks); } + toCancel.stream().forEach(taskId -> cancelTask(taskId)); } @Override diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java index 11dc9f89de532..5bdc0f3fffadf 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java @@ -21,7 +21,6 @@ import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.Booleans; @@ -32,6 +31,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestActions; +import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.rest.action.RestStatusToXContentListener; import org.elasticsearch.search.Scroll; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -100,8 +100,8 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC parseSearchRequest(searchRequest, request, parser, setSize)); return channel -> { - RestStatusToXContentListener listener = new RestStatusToXContentListener<>(channel); - HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener); + RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel()); + cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel)); }; } diff --git a/server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java similarity index 84% rename from server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java rename to server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java index 103981abdc41e..8121b31547599 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java @@ -17,7 +17,7 @@ * under the License. */ -package org.elasticsearch.rest.action.search; +package org.elasticsearch.rest.action; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; @@ -45,7 +45,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.CountDownLatch; @@ -56,13 +55,13 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -public class HttpChannelTaskHandlerTests extends ESTestCase { +public class RestCancellableNodeClientTests extends ESTestCase { private ThreadPool threadPool; @Before public void createThreadPool() { - threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName()); + threadPool = new TestThreadPool(RestCancellableNodeClientTests.class.getName()); } @After @@ -77,8 +76,7 @@ public void stopThreadPool() { */ public void testCompletedTasks() throws Exception { try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) { - HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; - int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); int totalSearches = 0; List> futures = new ArrayList<>(); int numChannels = randomIntBetween(1, 30); @@ -88,8 +86,8 @@ public void testCompletedTasks() throws Exception { totalSearches += numTasks; for (int j = 0; j < numTasks; j++) { PlainListenableActionFuture actionFuture = PlainListenableActionFuture.newListenableFuture(); - threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), - SearchAction.INSTANCE, actionFuture)); + RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel); + threadPool.generic().submit(() -> client.execute(SearchAction.INSTANCE, new SearchRequest(), actionFuture)); futures.add(actionFuture); } } @@ -97,10 +95,8 @@ public void testCompletedTasks() throws Exception { future.get(); } //no channels get closed in this test, hence we expect as many channels as we created in the map - assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels()); - for (Map.Entry entry : httpChannelTaskHandler.httpChannels.entrySet()) { - assertEquals(0, entry.getValue().getNumTasks()); - } + assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels()); + assertEquals(0, RestCancellableNodeClient.getNumTasks()); assertEquals(totalSearches, testClient.searchRequests.get()); } } @@ -110,9 +106,8 @@ public void testCompletedTasks() throws Exception { * removed and all of its corresponding tasks get cancelled. */ public void testCancelledTasks() throws Exception { - try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) { - HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; - int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + try (TestClient nodeClient = new TestClient(Settings.EMPTY, threadPool, true)) { + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); int numChannels = randomIntBetween(1, 30); int totalSearches = 0; List channels = new ArrayList<>(numChannels); @@ -121,18 +116,19 @@ public void testCancelledTasks() throws Exception { channels.add(channel); int numTasks = randomIntBetween(1, 30); totalSearches += numTasks; + RestCancellableNodeClient client = new RestCancellableNodeClient(nodeClient, channel); for (int j = 0; j < numTasks; j++) { - httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null); + client.execute(SearchAction.INSTANCE, new SearchRequest(), null); } - assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks()); + assertEquals(numTasks, RestCancellableNodeClient.getNumTasks(channel)); } - assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels()); + assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels()); for (TestHttpChannel channel : channels) { channel.awaitClose(); } - assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels()); - assertEquals(totalSearches, testClient.searchRequests.get()); - assertEquals(totalSearches, testClient.cancelledTasks.size()); + assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels()); + assertEquals(totalSearches, nodeClient.searchRequests.get()); + assertEquals(totalSearches, nodeClient.cancelledTasks.size()); } } @@ -144,8 +140,7 @@ public void testCancelledTasks() throws Exception { */ public void testChannelAlreadyClosed() { try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) { - HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE; - int initialHttpChannels = httpChannelTaskHandler.getNumChannels(); + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); int numChannels = randomIntBetween(1, 30); int totalSearches = 0; for (int i = 0; i < numChannels; i++) { @@ -154,12 +149,13 @@ public void testChannelAlreadyClosed() { channel.close(); int numTasks = randomIntBetween(1, 5); totalSearches += numTasks; + RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel); for (int j = 0; j < numTasks; j++) { //here the channel will be first registered, then straight-away removed from the map as the close listener is invoked - httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null); + client.execute(SearchAction.INSTANCE, new SearchRequest(), null); } } - assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels()); + assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels()); assertEquals(totalSearches, testClient.searchRequests.get()); assertEquals(totalSearches, testClient.cancelledTasks.size()); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index dd4f937039afc..b144cc8621b13 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -113,7 +113,7 @@ import org.elasticsearch.plugins.NetworkPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.rest.action.search.HttpChannelTaskHandler; +import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.MockSearchService; import org.elasticsearch.search.SearchHit; @@ -511,9 +511,11 @@ private static void clearClusters() throws Exception { restClient.close(); restClient = null; } - assertBusy(() -> assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " + - HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0, - HttpChannelTaskHandler.INSTANCE.getNumChannels())); + assertBusy(() -> { + int numChannels = RestCancellableNodeClient.getNumChannels(); + assertEquals( numChannels+ " channels still being tracked in " + RestCancellableNodeClient.class.getSimpleName() + + " while there should be none", 0, numChannels); + }); } private void afterInternal(boolean afterClass) throws Exception { From d92322ca9b224d6d4a832ae93d1f05c16ead91be Mon Sep 17 00:00:00 2001 From: jimczi Date: Mon, 27 Jan 2020 09:40:22 +0100 Subject: [PATCH 2/2] address review --- .../rest/action/RestCancellableNodeClient.java | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java index ef296de66b8e1..224c266bd2b18 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java +++ b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java @@ -127,7 +127,7 @@ private class CloseListener implements ActionListener { CloseListener() { } - int getNumTasks() { + synchronized int getNumTasks() { return tasks.size(); } @@ -135,7 +135,7 @@ void maybeRegisterChannel(HttpChannel httpChannel) { if (channel.compareAndSet(null, httpChannel)) { //In case the channel is already closed when we register the listener, the listener will be immediately executed which will //remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it - //with the channel. This guarantees that the close listener is already in the map when the it gets registered to its + //with the channel. This guarantees that the close listener is already in the map when it gets registered to its //corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed. httpChannel.addCloseListener(this); } @@ -157,14 +157,19 @@ synchronized void unregisterTask(TaskHolder taskHolder) { @Override public void onResponse(Void aVoid) { + final HttpChannel httpChannel = channel.get(); + assert httpChannel != null : "channel not registered"; + // when the channel gets closed it won't be reused: we can remove it from the map and forget about it. + CloseListener closeListener = httpChannels.remove(httpChannel); + assert closeListener != null : "channel not found in the map of tracked channels"; final List toCancel; synchronized (this) { - // when the channel gets closed it won't be reused: we can remove it from the map and forget about it. - CloseListener closeListener = httpChannels.remove(channel.get()); - assert closeListener != null : "channel not found in the map of tracked channels"; toCancel = new ArrayList<>(tasks); + tasks.clear(); + } + for (TaskId taskId : toCancel) { + cancelTask(taskId); } - toCancel.stream().forEach(taskId -> cancelTask(taskId)); } @Override