Skip to content

Commit

Permalink
Expose the logic to cancel task when the rest channel is closed (#51423)
Browse files Browse the repository at this point in the history
This commit moves the logic that cancels search requests when the rest channel is closed
to a generic client that can be used by other APIs. This will be useful for any rest action
that wants to cancel the execution of a task if the underlying rest channel is closed by the
client before completion.

Relates #49931
Relates #50990
Relates #50990
  • Loading branch information
jimczi authored Jan 28, 2020
1 parent 9d2c579 commit 7e9153b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,84 @@
* under the License.
*/

package org.elasticsearch.rest.action.search;
package org.elasticsearch.rest.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.FilterClient;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;

/**
* This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from,
* so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed.
* A {@linkplain Client} that cancels tasks executed locally when the provided {@link HttpChannel}
* is closed before completion.
*/
public final class HttpChannelTaskHandler {
public class RestCancellableNodeClient extends FilterClient {
private static final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();

private final NodeClient client;
private final HttpChannel httpChannel;

public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler();
//package private for testing
final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
public RestCancellableNodeClient(NodeClient client, HttpChannel httpChannel) {
super(client);
this.client = client;
this.httpChannel = httpChannel;
}

/**
* Returns the number of channels tracked globally.
*/
public static int getNumChannels() {
return httpChannels.size();
}

private HttpChannelTaskHandler() {
/**
* Returns the number of tasks tracked globally.
*/
static int getNumTasks() {
return httpChannels.values().stream()
.mapToInt(CloseListener::getNumTasks)
.sum();
}

<Response extends ActionResponse> void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request,
ActionType<Response> actionType, ActionListener<Response> listener) {
/**
* Returns the number of tasks tracked by the provided {@link HttpChannel}.
*/
static int getNumTasks(HttpChannel channel) {
CloseListener listener = httpChannels.get(channel);
return listener == null ? 0 : listener.getNumTasks();
}

CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client));
@Override
public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action, Request request, ActionListener<Response> listener) {
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener());
TaskHolder taskHolder = new TaskHolder();
Task task = client.executeLocally(actionType, request,
Task task = client.executeLocally(action, request,
new ActionListener<>() {
@Override
public void onResponse(Response searchResponse) {
public void onResponse(Response response) {
try {
closeListener.unregisterTask(taskHolder);
} finally {
listener.onResponse(searchResponse);
listener.onResponse(response);
}
}

Expand All @@ -77,32 +107,35 @@ public void onFailure(Exception e) {
}
}
});
closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId()));
final TaskId taskId = new TaskId(client.getLocalNodeId(), task.getId());
closeListener.registerTask(taskHolder, taskId);
closeListener.maybeRegisterChannel(httpChannel);
}

public int getNumChannels() {
return httpChannels.size();
private void cancelTask(TaskId taskId) {
CancelTasksRequest req = new CancelTasksRequest()
.setTaskId(taskId)
.setReason("channel closed");
// force the origin to execute the cancellation as a system user
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
}

final class CloseListener implements ActionListener<Void> {
private final Client client;
private class CloseListener implements ActionListener<Void> {
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
private final Set<TaskId> taskIds = new HashSet<>();
private final Set<TaskId> tasks = new HashSet<>();

CloseListener(Client client) {
this.client = client;
CloseListener() {
}

int getNumTasks() {
return taskIds.size();
synchronized int getNumTasks() {
return tasks.size();
}

void maybeRegisterChannel(HttpChannel httpChannel) {
if (channel.compareAndSet(null, httpChannel)) {
//In case the channel is already closed when we register the listener, the listener will be immediately executed which will
//remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it
//with the channel. This guarantees that the close listener is already in the map when the it gets registered to its
//with the channel. This guarantees that the close listener is already in the map when it gets registered to its
//corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed.
httpChannel.addCloseListener(this);
}
Expand All @@ -111,34 +144,31 @@ void maybeRegisterChannel(HttpChannel httpChannel) {
synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
taskHolder.taskId = taskId;
if (taskHolder.completed == false) {
this.taskIds.add(taskId);
this.tasks.add(taskId);
}
}

synchronized void unregisterTask(TaskHolder taskHolder) {
if (taskHolder.taskId != null) {
this.taskIds.remove(taskHolder.taskId);
this.tasks.remove(taskHolder.taskId);
}
taskHolder.completed = true;
}

@Override
public synchronized void onResponse(Void aVoid) {
//When the channel gets closed it won't be reused: we can remove it from the map and forget about it.
CloseListener closeListener = httpChannels.remove(channel.get());
public void onResponse(Void aVoid) {
final HttpChannel httpChannel = channel.get();
assert httpChannel != null : "channel not registered";
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
CloseListener closeListener = httpChannels.remove(httpChannel);
assert closeListener != null : "channel not found in the map of tracked channels";
for (TaskId taskId : taskIds) {
ThreadContext threadContext = client.threadPool().getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any existing context information
threadContext.markAsSystemContext();
ContextPreservingActionListener<CancelTasksResponse> contextPreservingListener = new ContextPreservingActionListener<>(
threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {}));
CancelTasksRequest cancelTasksRequest = new CancelTasksRequest();
cancelTasksRequest.setTaskId(taskId);
//We don't wait for cancel tasks to come back. Task cancellation is just best effort.
client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener);
}
final List<TaskId> toCancel;
synchronized (this) {
toCancel = new ArrayList<>(tasks);
tasks.clear();
}
for (TaskId taskId : toCancel) {
cancelTask(taskId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.Booleans;
Expand All @@ -32,6 +31,7 @@
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestActions;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestStatusToXContentListener;
import org.elasticsearch.search.Scroll;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -100,8 +100,8 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
parseSearchRequest(searchRequest, request, parser, setSize));

return channel -> {
RestStatusToXContentListener<SearchResponse> listener = new RestStatusToXContentListener<>(channel);
HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener);
RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/

package org.elasticsearch.rest.action.search;
package org.elasticsearch.rest.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
Expand Down Expand Up @@ -45,7 +45,6 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
Expand All @@ -56,13 +55,13 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

public class HttpChannelTaskHandlerTests extends ESTestCase {
public class RestCancellableNodeClientTests extends ESTestCase {

private ThreadPool threadPool;

@Before
public void createThreadPool() {
threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName());
threadPool = new TestThreadPool(RestCancellableNodeClientTests.class.getName());
}

@After
Expand All @@ -77,8 +76,7 @@ public void stopThreadPool() {
*/
public void testCompletedTasks() throws Exception {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int totalSearches = 0;
List<Future<?>> futures = new ArrayList<>();
int numChannels = randomIntBetween(1, 30);
Expand All @@ -88,19 +86,17 @@ public void testCompletedTasks() throws Exception {
totalSearches += numTasks;
for (int j = 0; j < numTasks; j++) {
PlainListenableActionFuture<SearchResponse> actionFuture = PlainListenableActionFuture.newListenableFuture();
threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(),
SearchAction.INSTANCE, actionFuture));
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
threadPool.generic().submit(() -> client.execute(SearchAction.INSTANCE, new SearchRequest(), actionFuture));
futures.add(actionFuture);
}
}
for (Future<?> future : futures) {
future.get();
}
//no channels get closed in this test, hence we expect as many channels as we created in the map
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
for (Map.Entry<HttpChannel, HttpChannelTaskHandler.CloseListener> entry : httpChannelTaskHandler.httpChannels.entrySet()) {
assertEquals(0, entry.getValue().getNumTasks());
}
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(0, RestCancellableNodeClient.getNumTasks());
assertEquals(totalSearches, testClient.searchRequests.get());
}
}
Expand All @@ -110,9 +106,8 @@ public void testCompletedTasks() throws Exception {
* removed and all of its corresponding tasks get cancelled.
*/
public void testCancelledTasks() throws Exception {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
try (TestClient nodeClient = new TestClient(Settings.EMPTY, threadPool, true)) {
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int numChannels = randomIntBetween(1, 30);
int totalSearches = 0;
List<TestHttpChannel> channels = new ArrayList<>(numChannels);
Expand All @@ -121,18 +116,19 @@ public void testCancelledTasks() throws Exception {
channels.add(channel);
int numTasks = randomIntBetween(1, 30);
totalSearches += numTasks;
RestCancellableNodeClient client = new RestCancellableNodeClient(nodeClient, channel);
for (int j = 0; j < numTasks; j++) {
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
}
assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks());
assertEquals(numTasks, RestCancellableNodeClient.getNumTasks(channel));
}
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
for (TestHttpChannel channel : channels) {
channel.awaitClose();
}
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(totalSearches, testClient.searchRequests.get());
assertEquals(totalSearches, testClient.cancelledTasks.size());
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(totalSearches, nodeClient.searchRequests.get());
assertEquals(totalSearches, nodeClient.cancelledTasks.size());
}
}

Expand All @@ -144,8 +140,7 @@ public void testCancelledTasks() throws Exception {
*/
public void testChannelAlreadyClosed() {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int numChannels = randomIntBetween(1, 30);
int totalSearches = 0;
for (int i = 0; i < numChannels; i++) {
Expand All @@ -154,12 +149,13 @@ public void testChannelAlreadyClosed() {
channel.close();
int numTasks = randomIntBetween(1, 5);
totalSearches += numTasks;
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
for (int j = 0; j < numTasks; j++) {
//here the channel will be first registered, then straight-away removed from the map as the close listener is invoked
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
}
}
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(totalSearches, testClient.searchRequests.get());
assertEquals(totalSearches, testClient.cancelledTasks.size());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.rest.action.search.HttpChannelTaskHandler;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.MockSearchService;
import org.elasticsearch.search.SearchHit;
Expand Down Expand Up @@ -511,9 +511,11 @@ private static void clearClusters() throws Exception {
restClient.close();
restClient = null;
}
assertBusy(() -> assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " +
HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0,
HttpChannelTaskHandler.INSTANCE.getNumChannels()));
assertBusy(() -> {
int numChannels = RestCancellableNodeClient.getNumChannels();
assertEquals( numChannels+ " channels still being tracked in " + RestCancellableNodeClient.class.getSimpleName()
+ " while there should be none", 0, numChannels);
});
}

private void afterInternal(boolean afterClass) throws Exception {
Expand Down

0 comments on commit 7e9153b

Please sign in to comment.