Skip to content

Commit

Permalink
Expose the logic to cancel task when the rest channel is closed
Browse files Browse the repository at this point in the history
This commit moves the logic that cancels search requests when the rest channel is closed
to a generic client that can be used by other APIs. This will be useful for any rest action
that wants to cancel the execution of a task if the underlying rest channel is closed by the
client before completion.

Relates elastic#49931
Relates elastic#50990
Relates elastic#50990
  • Loading branch information
jimczi committed Jan 24, 2020
1 parent da450f1 commit a279a9e
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,84 @@
* under the License.
*/

package org.elasticsearch.rest.action.search;
package org.elasticsearch.rest.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.FilterClient;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;

/**
* This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from,
* so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed.
* A {@linkplain Client} that cancels tasks executed locally when the provided {@link HttpChannel}
* is closed before completion.
*/
public final class HttpChannelTaskHandler {
public class RestCancellableNodeClient extends FilterClient {
private static final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();

public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler();
//package private for testing
final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
private final NodeClient client;
private final HttpChannel httpChannel;

private HttpChannelTaskHandler() {
public RestCancellableNodeClient(NodeClient client, HttpChannel httpChannel) {
super(client);
this.client = client;
this.httpChannel = httpChannel;
}

<Response extends ActionResponse> void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request,
ActionType<Response> actionType, ActionListener<Response> listener) {
/**
* Returns the number of channels tracked globally.
*/
public static int getNumChannels() {
return httpChannels.size();
}

/**
* Returns the number of tasks tracked globally.
*/
static int getNumTasks() {
return httpChannels.values().stream()
.mapToInt(CloseListener::getNumTasks)
.sum();
}

CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client));
/**
* Returns the number of tasks tracked by the provided {@link HttpChannel}.
*/
static int getNumTasks(HttpChannel channel) {
CloseListener listener = httpChannels.get(channel);
return listener == null ? 0 : listener.getNumTasks();
}

@Override
public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action, Request request, ActionListener<Response> listener) {
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener());
TaskHolder taskHolder = new TaskHolder();
Task task = client.executeLocally(actionType, request,
Task task = client.executeLocally(action, request,
new ActionListener<>() {
@Override
public void onResponse(Response searchResponse) {
public void onResponse(Response response) {
try {
closeListener.unregisterTask(taskHolder);
} finally {
listener.onResponse(searchResponse);
listener.onResponse(response);
}
}

Expand All @@ -77,25 +107,28 @@ public void onFailure(Exception e) {
}
}
});
closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId()));
final TaskId taskId = new TaskId(client.getLocalNodeId(), task.getId());
closeListener.registerTask(taskHolder, taskId);
closeListener.maybeRegisterChannel(httpChannel);
}

public int getNumChannels() {
return httpChannels.size();
private void cancelTask(TaskId taskId) {
CancelTasksRequest req = new CancelTasksRequest()
.setTaskId(taskId)
.setReason("channel closed");
// force the origin to execute the cancellation as a system user
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
}

final class CloseListener implements ActionListener<Void> {
private final Client client;
private class CloseListener implements ActionListener<Void> {
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
private final Set<TaskId> taskIds = new HashSet<>();
private final Set<TaskId> tasks = new HashSet<>();

CloseListener(Client client) {
this.client = client;
CloseListener() {
}

int getNumTasks() {
return taskIds.size();
return tasks.size();
}

void maybeRegisterChannel(HttpChannel httpChannel) {
Expand All @@ -111,35 +144,27 @@ void maybeRegisterChannel(HttpChannel httpChannel) {
synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
taskHolder.taskId = taskId;
if (taskHolder.completed == false) {
this.taskIds.add(taskId);
this.tasks.add(taskId);
}
}

synchronized void unregisterTask(TaskHolder taskHolder) {
if (taskHolder.taskId != null) {
this.taskIds.remove(taskHolder.taskId);
this.tasks.remove(taskHolder.taskId);
}
taskHolder.completed = true;
}

@Override
public synchronized void onResponse(Void aVoid) {
//When the channel gets closed it won't be reused: we can remove it from the map and forget about it.
CloseListener closeListener = httpChannels.remove(channel.get());
assert closeListener != null : "channel not found in the map of tracked channels";
for (TaskId taskId : taskIds) {
ThreadContext threadContext = client.threadPool().getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any existing context information
threadContext.markAsSystemContext();
ContextPreservingActionListener<CancelTasksResponse> contextPreservingListener = new ContextPreservingActionListener<>(
threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {}));
CancelTasksRequest cancelTasksRequest = new CancelTasksRequest();
cancelTasksRequest.setTaskId(taskId);
//We don't wait for cancel tasks to come back. Task cancellation is just best effort.
client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener);
}
public void onResponse(Void aVoid) {
final List<TaskId> toCancel;
synchronized (this) {
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
CloseListener closeListener = httpChannels.remove(channel.get());
assert closeListener != null : "channel not found in the map of tracked channels";
toCancel = new ArrayList<>(tasks);
}
toCancel.stream().forEach(taskId -> cancelTask(taskId));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.Booleans;
Expand All @@ -32,6 +31,7 @@
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestActions;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestStatusToXContentListener;
import org.elasticsearch.search.Scroll;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -100,8 +100,8 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
parseSearchRequest(searchRequest, request, parser, setSize));

return channel -> {
RestStatusToXContentListener<SearchResponse> listener = new RestStatusToXContentListener<>(channel);
HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener);
RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/

package org.elasticsearch.rest.action.search;
package org.elasticsearch.rest.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
Expand Down Expand Up @@ -45,7 +45,6 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
Expand All @@ -56,13 +55,13 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

public class HttpChannelTaskHandlerTests extends ESTestCase {
public class RestCancellableNodeClientTests extends ESTestCase {

private ThreadPool threadPool;

@Before
public void createThreadPool() {
threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName());
threadPool = new TestThreadPool(RestCancellableNodeClientTests.class.getName());
}

@After
Expand All @@ -77,8 +76,7 @@ public void stopThreadPool() {
*/
public void testCompletedTasks() throws Exception {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int totalSearches = 0;
List<Future<?>> futures = new ArrayList<>();
int numChannels = randomIntBetween(1, 30);
Expand All @@ -88,19 +86,17 @@ public void testCompletedTasks() throws Exception {
totalSearches += numTasks;
for (int j = 0; j < numTasks; j++) {
PlainListenableActionFuture<SearchResponse> actionFuture = PlainListenableActionFuture.newListenableFuture();
threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(),
SearchAction.INSTANCE, actionFuture));
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
threadPool.generic().submit(() -> client.execute(SearchAction.INSTANCE, new SearchRequest(), actionFuture));
futures.add(actionFuture);
}
}
for (Future<?> future : futures) {
future.get();
}
//no channels get closed in this test, hence we expect as many channels as we created in the map
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
for (Map.Entry<HttpChannel, HttpChannelTaskHandler.CloseListener> entry : httpChannelTaskHandler.httpChannels.entrySet()) {
assertEquals(0, entry.getValue().getNumTasks());
}
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(0, RestCancellableNodeClient.getNumTasks());
assertEquals(totalSearches, testClient.searchRequests.get());
}
}
Expand All @@ -110,9 +106,8 @@ public void testCompletedTasks() throws Exception {
* removed and all of its corresponding tasks get cancelled.
*/
public void testCancelledTasks() throws Exception {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
try (TestClient nodeClient = new TestClient(Settings.EMPTY, threadPool, true)) {
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int numChannels = randomIntBetween(1, 30);
int totalSearches = 0;
List<TestHttpChannel> channels = new ArrayList<>(numChannels);
Expand All @@ -121,18 +116,19 @@ public void testCancelledTasks() throws Exception {
channels.add(channel);
int numTasks = randomIntBetween(1, 30);
totalSearches += numTasks;
RestCancellableNodeClient client = new RestCancellableNodeClient(nodeClient, channel);
for (int j = 0; j < numTasks; j++) {
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
}
assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks());
assertEquals(numTasks, RestCancellableNodeClient.getNumTasks(channel));
}
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
for (TestHttpChannel channel : channels) {
channel.awaitClose();
}
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(totalSearches, testClient.searchRequests.get());
assertEquals(totalSearches, testClient.cancelledTasks.size());
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(totalSearches, nodeClient.searchRequests.get());
assertEquals(totalSearches, nodeClient.cancelledTasks.size());
}
}

Expand All @@ -144,8 +140,7 @@ public void testCancelledTasks() throws Exception {
*/
public void testChannelAlreadyClosed() {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int numChannels = randomIntBetween(1, 30);
int totalSearches = 0;
for (int i = 0; i < numChannels; i++) {
Expand All @@ -154,12 +149,13 @@ public void testChannelAlreadyClosed() {
channel.close();
int numTasks = randomIntBetween(1, 5);
totalSearches += numTasks;
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
for (int j = 0; j < numTasks; j++) {
//here the channel will be first registered, then straight-away removed from the map as the close listener is invoked
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
}
}
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(totalSearches, testClient.searchRequests.get());
assertEquals(totalSearches, testClient.cancelledTasks.size());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.rest.action.search.HttpChannelTaskHandler;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.MockSearchService;
import org.elasticsearch.search.SearchHit;
Expand Down Expand Up @@ -511,9 +511,11 @@ private static void clearClusters() throws Exception {
restClient.close();
restClient = null;
}
assertBusy(() -> assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " +
HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0,
HttpChannelTaskHandler.INSTANCE.getNumChannels()));
assertBusy(() -> {
int numChannels = RestCancellableNodeClient.getNumChannels();
assertEquals( numChannels+ " channels still being tracked in " + RestCancellableNodeClient.class.getSimpleName()
+ " while there should be none", 0, numChannels);
});
}

private void afterInternal(boolean afterClass) throws Exception {
Expand Down

0 comments on commit a279a9e

Please sign in to comment.