Skip to content

Commit

Permalink
Make recovery APIs cancellable
Browse files Browse the repository at this point in the history
  • Loading branch information
DaveCTurner committed Feb 18, 2021
1 parent 146f7be commit 1ea1426
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.http;

import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.indices.recovery.RecoveryAction;
import org.elasticsearch.action.admin.indices.recovery.TransportRecoveryAction;
import org.elasticsearch.action.admin.indices.recovery.TransportRecoveryActionHelper;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.transport.TransportService;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.Semaphore;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.not;

public class IndicesRecoveryRestCancellationIT extends HttpSmokeTestCase {

public void testIndicesRecoveryRestCancellation() throws Exception {
runTest(new Request(HttpGet.METHOD_NAME, "/_recovery"));
}

public void testCatRecoveryRestCancellation() throws Exception {
runTest(new Request(HttpGet.METHOD_NAME, "/_cat/recovery"));
}

private void runTest(Request request) throws Exception {

createIndex("test");
ensureGreen("test");

final List<Semaphore> operationBlocks = new ArrayList<>();
for (final TransportRecoveryAction transportRecoveryAction : internalCluster().getInstances(TransportRecoveryAction.class)) {
final Semaphore operationBlock = new Semaphore(1);
operationBlocks.add(operationBlock);
TransportRecoveryActionHelper.setOnShardOperation(transportRecoveryAction, () -> {
try {
operationBlock.acquire();
} catch (InterruptedException e) {
throw new AssertionError(e);
}
operationBlock.release();
});
}
assertThat(operationBlocks, not(empty()));

final List<Releasable> releasables = new ArrayList<>();
try {
for (final Semaphore operationBlock : operationBlocks) {
operationBlock.acquire();
releasables.add(operationBlock::release);
}

final PlainActionFuture<Void> future = new PlainActionFuture<>();
logger.info("--> sending request");
final Cancellable cancellable = getRestClient().performRequestAsync(request, new ResponseListener() {
@Override
public void onSuccess(Response response) {
future.onResponse(null);
}

@Override
public void onFailure(Exception exception) {
future.onFailure(exception);
}
});

logger.info("--> waiting for task to start");
assertBusy(() -> {
final List<TaskInfo> tasks = client().admin().cluster().prepareListTasks().get().getTasks();
assertTrue(tasks.toString(), tasks.stream().anyMatch(t -> t.getAction().startsWith(RecoveryAction.NAME)));
});

logger.info("--> waiting for at least one task to hit a block");
assertBusy(() -> assertTrue(operationBlocks.stream().anyMatch(Semaphore::hasQueuedThreads)));

logger.info("--> cancelling request");
cancellable.cancel();
expectThrows(CancellationException.class, future::actionGet);

logger.info("--> checking that all tasks are marked as cancelled");
assertBusy(() -> {
boolean foundTask = false;
for (TransportService transportService : internalCluster().getInstances(TransportService.class)) {
for (CancellableTask cancellableTask : transportService.getTaskManager().getCancellableTasks().values()) {
if (cancellableTask.getAction().startsWith(RecoveryAction.NAME)) {
foundTask = true;
assertTrue("task " + cancellableTask.getId() + " not cancelled", cancellableTask.isCancelled());
}
}
}
assertTrue("found no cancellable tasks", foundTask);
});
} finally {
Releasables.close(releasables);
}

logger.info("--> checking that all tasks have finished");
assertBusy(() -> {
final List<TaskInfo> tasks = client().admin().cluster().prepareListTasks().get().getTasks();
assertTrue(tasks.toString(), tasks.stream().noneMatch(t -> t.getAction().startsWith(RecoveryAction.NAME)));
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Map;

/**
* Request for recovery information
Expand Down Expand Up @@ -90,4 +94,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(detailed);
out.writeBoolean(activeOnly);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, "", parentTaskId, headers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.ShardsIterator;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
Expand Down Expand Up @@ -88,6 +90,8 @@ protected RecoveryRequest readRequestFrom(StreamInput in) throws IOException {

@Override
protected RecoveryState shardOperation(RecoveryRequest request, ShardRouting shardRouting, Task task) {
assert task instanceof CancellableTask;
runOnShardOperation();
IndexService indexService = indicesService.indexServiceSafe(shardRouting.shardId().getIndex());
IndexShard indexShard = indexService.getShard(shardRouting.shardId().id());
return indexShard.recoveryState();
Expand All @@ -107,4 +111,19 @@ protected ClusterBlockException checkGlobalBlock(ClusterState state, RecoveryReq
protected ClusterBlockException checkRequestBlock(ClusterState state, RecoveryRequest request, String[] concreteIndices) {
return state.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_READ, concreteIndices);
}

@Nullable // unless running tests that inject extra behaviour
private volatile Runnable onShardOperation;

private void runOnShardOperation() {
final Runnable onShardOperation = this.onShardOperation;
if (onShardOperation != null) {
onShardOperation.run();
}
}

// exposed for tests: inject some extra behaviour that runs when shardOperation() is called
void setOnShardOperation(@Nullable Runnable onShardOperation) {
this.onShardOperation = onShardOperation;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestToXContentListener;

import java.io.IOException;
Expand Down Expand Up @@ -50,7 +51,8 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
recoveryRequest.detailed(request.paramAsBoolean("detailed", false));
recoveryRequest.activeOnly(request.paramAsBoolean("active_only", false));
recoveryRequest.indicesOptions(IndicesOptions.fromRequest(request, recoveryRequest.indicesOptions()));
return channel -> client.admin().indices().recoveries(recoveryRequest, new RestToXContentListener<>(channel));
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel())
.admin().indices().recoveries(recoveryRequest, new RestToXContentListener<>(channel));
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestResponseListener;

import java.util.Comparator;
Expand Down Expand Up @@ -63,7 +64,8 @@ public RestChannelConsumer doCatRequest(final RestRequest request, final NodeCli
recoveryRequest.activeOnly(request.paramAsBoolean("active_only", false));
recoveryRequest.indicesOptions(IndicesOptions.fromRequest(request, recoveryRequest.indicesOptions()));

return channel -> client.admin().indices().recoveries(recoveryRequest, new RestResponseListener<RecoveryResponse>(channel) {
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel())
.admin().indices().recoveries(recoveryRequest, new RestResponseListener<RecoveryResponse>(channel) {
@Override
public RestResponse buildResponse(final RecoveryResponse response) throws Exception {
return RestTable.buildResponse(buildRecoveryTable(request, response), channel);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.action.admin.indices.recovery;

/**
* Helper methods for {@link TransportRecoveryAction}.
*/
public class TransportRecoveryActionHelper {

/**
* Helper method for tests to call {@link TransportRecoveryAction#setOnShardOperation}.
*/
public static void setOnShardOperation(TransportRecoveryAction transportRecoveryAction, Runnable setOnShardOperation) {
transportRecoveryAction.setOnShardOperation(setOnShardOperation);
}
}

0 comments on commit 1ea1426

Please sign in to comment.