Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose the logic to cancel task when the rest channel is closed #51423

Merged
merged 2 commits into from
Jan 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,84 @@
* under the License.
*/

package org.elasticsearch.rest.action.search;
package org.elasticsearch.rest.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.FilterClient;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;

/**
* This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from,
* so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed.
* A {@linkplain Client} that cancels tasks executed locally when the provided {@link HttpChannel}
* is closed before completion.
*/
public final class HttpChannelTaskHandler {
public class RestCancellableNodeClient extends FilterClient {
private static final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();

private final NodeClient client;
private final HttpChannel httpChannel;

public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler();
//package private for testing
final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
public RestCancellableNodeClient(NodeClient client, HttpChannel httpChannel) {
super(client);
this.client = client;
this.httpChannel = httpChannel;
}

/**
* Returns the number of channels tracked globally.
*/
public static int getNumChannels() {
return httpChannels.size();
}

private HttpChannelTaskHandler() {
/**
* Returns the number of tasks tracked globally.
*/
static int getNumTasks() {
return httpChannels.values().stream()
.mapToInt(CloseListener::getNumTasks)
.sum();
}

<Response extends ActionResponse> void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request,
ActionType<Response> actionType, ActionListener<Response> listener) {
/**
* Returns the number of tasks tracked by the provided {@link HttpChannel}.
*/
static int getNumTasks(HttpChannel channel) {
CloseListener listener = httpChannels.get(channel);
return listener == null ? 0 : listener.getNumTasks();
}

CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client));
@Override
public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action, Request request, ActionListener<Response> listener) {
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener());
TaskHolder taskHolder = new TaskHolder();
Task task = client.executeLocally(actionType, request,
Task task = client.executeLocally(action, request,
new ActionListener<>() {
@Override
public void onResponse(Response searchResponse) {
public void onResponse(Response response) {
try {
closeListener.unregisterTask(taskHolder);
} finally {
listener.onResponse(searchResponse);
listener.onResponse(response);
}
}

Expand All @@ -77,32 +107,35 @@ public void onFailure(Exception e) {
}
}
});
closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId()));
final TaskId taskId = new TaskId(client.getLocalNodeId(), task.getId());
closeListener.registerTask(taskHolder, taskId);
closeListener.maybeRegisterChannel(httpChannel);
}

public int getNumChannels() {
return httpChannels.size();
private void cancelTask(TaskId taskId) {
CancelTasksRequest req = new CancelTasksRequest()
.setTaskId(taskId)
.setReason("channel closed");
// force the origin to execute the cancellation as a system user
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
}

final class CloseListener implements ActionListener<Void> {
private final Client client;
private class CloseListener implements ActionListener<Void> {
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
private final Set<TaskId> taskIds = new HashSet<>();
private final Set<TaskId> tasks = new HashSet<>();

CloseListener(Client client) {
this.client = client;
CloseListener() {
}

int getNumTasks() {
return taskIds.size();
synchronized int getNumTasks() {
return tasks.size();
}

void maybeRegisterChannel(HttpChannel httpChannel) {
if (channel.compareAndSet(null, httpChannel)) {
//In case the channel is already closed when we register the listener, the listener will be immediately executed which will
//remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it
//with the channel. This guarantees that the close listener is already in the map when the it gets registered to its
//with the channel. This guarantees that the close listener is already in the map when it gets registered to its
//corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed.
httpChannel.addCloseListener(this);
}
Expand All @@ -111,34 +144,31 @@ void maybeRegisterChannel(HttpChannel httpChannel) {
synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
taskHolder.taskId = taskId;
if (taskHolder.completed == false) {
this.taskIds.add(taskId);
this.tasks.add(taskId);
}
}

synchronized void unregisterTask(TaskHolder taskHolder) {
if (taskHolder.taskId != null) {
this.taskIds.remove(taskHolder.taskId);
this.tasks.remove(taskHolder.taskId);
}
taskHolder.completed = true;
}

@Override
public synchronized void onResponse(Void aVoid) {
//When the channel gets closed it won't be reused: we can remove it from the map and forget about it.
CloseListener closeListener = httpChannels.remove(channel.get());
public void onResponse(Void aVoid) {
final HttpChannel httpChannel = channel.get();
assert httpChannel != null : "channel not registered";
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
CloseListener closeListener = httpChannels.remove(httpChannel);
assert closeListener != null : "channel not found in the map of tracked channels";
for (TaskId taskId : taskIds) {
ThreadContext threadContext = client.threadPool().getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any existing context information
threadContext.markAsSystemContext();
ContextPreservingActionListener<CancelTasksResponse> contextPreservingListener = new ContextPreservingActionListener<>(
threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {}));
CancelTasksRequest cancelTasksRequest = new CancelTasksRequest();
cancelTasksRequest.setTaskId(taskId);
//We don't wait for cancel tasks to come back. Task cancellation is just best effort.
client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener);
}
final List<TaskId> toCancel;
synchronized (this) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need to protect httpChannels so we could move the synchronized keyword two lines below?

toCancel = new ArrayList<>(tasks);
tasks.clear();
}
for (TaskId taskId : toCancel) {
cancelTask(taskId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.Booleans;
Expand All @@ -32,6 +31,7 @@
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestActions;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestStatusToXContentListener;
import org.elasticsearch.search.Scroll;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -100,8 +100,8 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
parseSearchRequest(searchRequest, request, parser, setSize));

return channel -> {
RestStatusToXContentListener<SearchResponse> listener = new RestStatusToXContentListener<>(channel);
HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener);
RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/

package org.elasticsearch.rest.action.search;
package org.elasticsearch.rest.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
Expand Down Expand Up @@ -45,7 +45,6 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
Expand All @@ -56,13 +55,13 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

public class HttpChannelTaskHandlerTests extends ESTestCase {
public class RestCancellableNodeClientTests extends ESTestCase {

private ThreadPool threadPool;

@Before
public void createThreadPool() {
threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName());
threadPool = new TestThreadPool(RestCancellableNodeClientTests.class.getName());
}

@After
Expand All @@ -77,8 +76,7 @@ public void stopThreadPool() {
*/
public void testCompletedTasks() throws Exception {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int totalSearches = 0;
List<Future<?>> futures = new ArrayList<>();
int numChannels = randomIntBetween(1, 30);
Expand All @@ -88,19 +86,17 @@ public void testCompletedTasks() throws Exception {
totalSearches += numTasks;
for (int j = 0; j < numTasks; j++) {
PlainListenableActionFuture<SearchResponse> actionFuture = PlainListenableActionFuture.newListenableFuture();
threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(),
SearchAction.INSTANCE, actionFuture));
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
threadPool.generic().submit(() -> client.execute(SearchAction.INSTANCE, new SearchRequest(), actionFuture));
futures.add(actionFuture);
}
}
for (Future<?> future : futures) {
future.get();
}
//no channels get closed in this test, hence we expect as many channels as we created in the map
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
for (Map.Entry<HttpChannel, HttpChannelTaskHandler.CloseListener> entry : httpChannelTaskHandler.httpChannels.entrySet()) {
assertEquals(0, entry.getValue().getNumTasks());
}
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(0, RestCancellableNodeClient.getNumTasks());
assertEquals(totalSearches, testClient.searchRequests.get());
}
}
Expand All @@ -110,9 +106,8 @@ public void testCompletedTasks() throws Exception {
* removed and all of its corresponding tasks get cancelled.
*/
public void testCancelledTasks() throws Exception {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
try (TestClient nodeClient = new TestClient(Settings.EMPTY, threadPool, true)) {
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int numChannels = randomIntBetween(1, 30);
int totalSearches = 0;
List<TestHttpChannel> channels = new ArrayList<>(numChannels);
Expand All @@ -121,18 +116,19 @@ public void testCancelledTasks() throws Exception {
channels.add(channel);
int numTasks = randomIntBetween(1, 30);
totalSearches += numTasks;
RestCancellableNodeClient client = new RestCancellableNodeClient(nodeClient, channel);
for (int j = 0; j < numTasks; j++) {
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
}
assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks());
assertEquals(numTasks, RestCancellableNodeClient.getNumTasks(channel));
}
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
for (TestHttpChannel channel : channels) {
channel.awaitClose();
}
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(totalSearches, testClient.searchRequests.get());
assertEquals(totalSearches, testClient.cancelledTasks.size());
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(totalSearches, nodeClient.searchRequests.get());
assertEquals(totalSearches, nodeClient.cancelledTasks.size());
}
}

Expand All @@ -144,8 +140,7 @@ public void testCancelledTasks() throws Exception {
*/
public void testChannelAlreadyClosed() {
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int numChannels = randomIntBetween(1, 30);
int totalSearches = 0;
for (int i = 0; i < numChannels; i++) {
Expand All @@ -154,12 +149,13 @@ public void testChannelAlreadyClosed() {
channel.close();
int numTasks = randomIntBetween(1, 5);
totalSearches += numTasks;
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
for (int j = 0; j < numTasks; j++) {
//here the channel will be first registered, then straight-away removed from the map as the close listener is invoked
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
}
}
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(totalSearches, testClient.searchRequests.get());
assertEquals(totalSearches, testClient.cancelledTasks.size());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.rest.action.search.HttpChannelTaskHandler;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.MockSearchService;
import org.elasticsearch.search.SearchHit;
Expand Down Expand Up @@ -511,9 +511,11 @@ private static void clearClusters() throws Exception {
restClient.close();
restClient = null;
}
assertBusy(() -> assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " +
HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0,
HttpChannelTaskHandler.INSTANCE.getNumChannels()));
assertBusy(() -> {
int numChannels = RestCancellableNodeClient.getNumChannels();
assertEquals( numChannels+ " channels still being tracked in " + RestCancellableNodeClient.class.getSimpleName()
+ " while there should be none", 0, numChannels);
});
}

private void afterInternal(boolean afterClass) throws Exception {
Expand Down