Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel in-flight search tasks due to search backpressure with 429 sta… #6634

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.PluginsService;
Expand Down Expand Up @@ -209,7 +210,9 @@ private void verifyCancellationException(ShardSearchFailure[] failures) {
final Throwable topFailureCause = searchFailure.getCause();
assertTrue(
searchFailure.toString(),
topFailureCause instanceof TransportException || topFailureCause instanceof TaskCancelledException
topFailureCause instanceof TransportException
|| topFailureCause instanceof TaskCancelledException
|| topFailureCause instanceof OpenSearchRejectedExecutionException
);
if (topFailureCause instanceof TransportException) {
assertTrue(topFailureCause.getCause() instanceof TaskCancelledException);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.backpressure.settings.NodeDuressSettings;
import org.opensearch.search.backpressure.settings.SearchBackpressureSettings;
import org.opensearch.search.backpressure.settings.SearchShardTaskSettings;
Expand All @@ -51,6 +53,7 @@
import java.util.function.Supplier;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;

Expand Down Expand Up @@ -129,6 +132,7 @@ public void testSearchTaskCancellationWithHighElapsedTime() throws InterruptedEx
assertNotNull("SearchTask should have been cancelled with TaskCancelledException", caughtException);
MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class));
MatcherAssert.assertThat(caughtException.getMessage(), containsString("elapsed time exceeded"));
MatcherAssert.assertThat(((TaskCancelledException) caughtException).status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
}

public void testSearchShardTaskCancellationWithHighElapsedTime() throws InterruptedException {
Expand All @@ -146,6 +150,7 @@ public void testSearchShardTaskCancellationWithHighElapsedTime() throws Interrup
assertNotNull("SearchShardTask should have been cancelled with TaskCancelledException", caughtException);
MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class));
MatcherAssert.assertThat(caughtException.getMessage(), containsString("elapsed time exceeded"));
MatcherAssert.assertThat(((TaskCancelledException) caughtException).status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
}

public void testSearchTaskCancellationWithHighCpu() throws InterruptedException {
Expand Down Expand Up @@ -177,6 +182,7 @@ public void testSearchTaskCancellationWithHighCpu() throws InterruptedException
assertNotNull("SearchTask should have been cancelled with TaskCancelledException", caughtException);
MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class));
MatcherAssert.assertThat(caughtException.getMessage(), containsString("cpu usage exceeded"));
MatcherAssert.assertThat(((TaskCancelledException) caughtException).status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
}

public void testSearchShardTaskCancellationWithHighCpu() throws InterruptedException {
Expand All @@ -194,6 +200,7 @@ public void testSearchShardTaskCancellationWithHighCpu() throws InterruptedExcep
assertNotNull("SearchShardTask should have been cancelled with TaskCancelledException", caughtException);
MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class));
MatcherAssert.assertThat(caughtException.getMessage(), containsString("cpu usage exceeded"));
MatcherAssert.assertThat(((TaskCancelledException) caughtException).status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
}

public void testSearchTaskCancellationWithHighHeapUsage() throws InterruptedException {
Expand Down Expand Up @@ -250,6 +257,7 @@ public void testSearchTaskCancellationWithHighHeapUsage() throws InterruptedExce
assertNotNull("SearchTask should have been cancelled with TaskCancelledException", caughtException);
MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class));
MatcherAssert.assertThat(caughtException.getMessage(), containsString("heap usage exceeded"));
MatcherAssert.assertThat(((TaskCancelledException) caughtException).status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
}

public void testSearchShardTaskCancellationWithHighHeapUsage() throws InterruptedException {
Expand Down Expand Up @@ -278,6 +286,7 @@ public void testSearchShardTaskCancellationWithHighHeapUsage() throws Interrupte
assertNotNull("SearchShardTask should have been cancelled with TaskCancelledException", caughtException);
MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class));
MatcherAssert.assertThat(caughtException.getMessage(), containsString("heap usage exceeded"));
MatcherAssert.assertThat(((TaskCancelledException) caughtException).status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
}

public void testSearchCancellationWithBackpressureDisabled() throws InterruptedException {
Expand All @@ -294,6 +303,7 @@ public void testSearchCancellationWithBackpressureDisabled() throws InterruptedE

Exception caughtException = listener.getException();
assertNull("SearchShardTask shouldn't have cancelled for monitor_only mode", caughtException);
MatcherAssert.assertThat(((TaskCancelledException) caughtException).status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
}

private static class ExceptionCatchingListener implements ActionListener<TestResponse> {
Expand Down Expand Up @@ -405,7 +415,8 @@ protected void doExecute(Task task, TestRequest request, ActionListener<TestResp
&& (System.nanoTime() - startTime) < TIMEOUT.getNanos());

if (cancellableTask.isCancelled()) {
throw new TaskCancelledException(cancellableTask.getReasonCancelled());
String reason = cancellableTask.getReasonCancelled();
throw new TaskCancelledException(new OpenSearchRejectedExecutionException(reason));
} else {
listener.onResponse(new TestResponse());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.opensearch.rest.RestStatus;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.snapshots.SnapshotInUseDeletionException;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.transport.TcpTransport;

import java.io.IOException;
Expand Down Expand Up @@ -258,7 +259,10 @@ protected Map<String, List<String>> getHeaders() {
*/
public RestStatus status() {
Throwable cause = unwrapCause();
if (cause == this) {
if (cause.getCause() instanceof TaskCancelledException
&& ((TaskCancelledException) cause.getCause()).status() == RestStatus.TOO_MANY_REQUESTS) {
return ((TaskCancelledException) cause.getCause()).status();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid referencing TaskCancelledException in OpenSearchException? This looks like an anti-pattern.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One alternative here is to make it generic and return the underlying cause's status in all cases, up to two levels deep. We need to think about whether that is the right behavior, and also make sure it doesn't cause any regression in existing code (ITs would make sure of that). Something along the lines of the snippet below.

    /**
     * Proposed generic variant of {@code status()}: derives the REST status from the
     * unwrapped cause, looking one level further down the cause chain when possible.
     */
    public RestStatus status() {
        Throwable cause = unwrapCause();
        if (cause == this) {
            // Nothing was unwrapped: this exception is its own cause, so report a generic 500.
            return RestStatus.INTERNAL_SERVER_ERROR;
        } else {
            // Second level of unwrapping on top of unwrapCause(): prefer the status of the
            // cause's own cause over the status of the cause itself.
            // NOTE(review): Throwable.getCause() returns null when no cause is set, so this
            // condition also holds for a cause-less exception and null is passed on —
            // confirm how ExceptionsHelper.status treats null before adopting this.
            if (cause.getCause() != cause) {
                return ExceptionsHelper.status(cause.getCause());
            }
            return ExceptionsHelper.status(cause);
        }
    }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trying this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other, better option is to override this in TaskCancelledException.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PritLadani Can you explore the second suggestion to do the override in TaskCancelledException? The code snippet above doesn't look quite right because it calls an "unwrapCause()" method but then proceeds to do another level of cause unwrapping. I'm really concerned about unintended side effects of that approach, and I also don't have a ton of confidence that ITs would catch every possible regression.

} else if (cause == this) {
return RestStatus.INTERNAL_SERVER_ERROR;
} else {
return ExceptionsHelper.status(cause);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.ShardId;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.transport.Transport;

import java.util.ArrayDeque;
Expand Down Expand Up @@ -370,6 +372,15 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha
: OpenSearchException.guessRootCauses(shardSearchFailures[0].getCause())[0];
logger.debug(() -> new ParameterizedMessage("All shards failed for phase: [{}]", getName()), cause);
onPhaseFailure(currentPhase, "all shards failed", cause);
} else if (getTask().isCancelled()) {
// checking if the task handling the search got cancelled. Adding this check only while starting the next phase to avoid
// slowing down the search operation
String reason = getTask().getReasonCancelled();
onPhaseFailure(
currentPhase,
"SearchTask was cancelled",
new TaskCancelledException(new OpenSearchRejectedExecutionException("cancelled task with reason: " + reason))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This layered exception wrapping does not look right. I think we should introduce a dedicated exception (e.g. TaskBackpressureException or similar) to differentiate between cancellation modes.

);
} else {
Boolean allowPartialResults = request.allowPartialSearchResults();
assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ void doRun() {
}

for (TaskCancellation taskCancellation : getTaskCancellations(cancellableTasks)) {
Class<? extends SearchBackpressureTask> taskType = getTaskType(taskCancellation.getTask());
logger.debug(
"[{} mode] cancelling task [{}] due to high resource consumption [{}]",
"[{} mode] cancelling [{}] task with id [{}] due to high resource consumption [{}]",
mode.getName(),
taskCancellation.getTask().getId(),
taskCancellation.getReasonString()
Expand All @@ -208,8 +209,6 @@ void doRun() {
continue;
}

Class<? extends SearchBackpressureTask> taskType = getTaskType(taskCancellation.getTask());

// Independently remove tokens from both token buckets.
SearchBackpressureState searchBackpressureState = searchBackpressureStates.get(taskType);
boolean rateLimitReached = searchBackpressureState.getRateLimiter().request() == false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@
import org.opensearch.common.collect.HppcMaps;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.tasks.TaskCancelledException;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.tasks.TaskCancellationService.throwTaskCancelledException;

/**
* Dfs phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
* The additional information is used to better compare the scores coming from all the shards, which depend on local factors (e.g. idf)
Expand All @@ -64,7 +65,7 @@ public void execute(SearchContext context) {
@Override
public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
throwTaskCancelledException(context.getTask().getReasonCancelled());
}
TermStatistics ts = super.termStatistics(term, docFreq, totalTermFreq);
if (ts != null) {
Expand All @@ -76,7 +77,7 @@ public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq)
@Override
public CollectionStatistics collectionStatistics(String field) throws IOException {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
throwTaskCancelledException(context.getTask().getReasonCancelled());
}
CollectionStatistics cs = super.collectionStatistics(field);
if (cs != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.lookup.SourceLookup;
import org.opensearch.tasks.TaskCancelledException;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -86,6 +85,7 @@
import java.util.function.Function;

import static java.util.Collections.emptyMap;
import static org.opensearch.tasks.TaskCancellationService.throwTaskCancelledException;

/**
* Fetch phase of a search request, used to fetch the actual top matching documents to be returned to the client, identified
Expand All @@ -109,7 +109,7 @@ public void execute(SearchContext context) {
}

if (context.isCancelled()) {
throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
throwTaskCancelledException(context.getTask().getReasonCancelled());
}

if (context.docIdsToLoadSize() == 0) {
Expand Down Expand Up @@ -141,7 +141,7 @@ public void execute(SearchContext context) {
boolean hasSequentialDocs = hasSequentialDocs(docs);
for (int index = 0; index < context.docIdsToLoadSize(); index++) {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
throwTaskCancelledException(context.getTask().getReasonCancelled());
}
int docId = docs[index].docId;
try {
Expand Down Expand Up @@ -184,7 +184,7 @@ public void execute(SearchContext context) {
}
}
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
throwTaskCancelledException(context.getTask().getReasonCancelled());
}

TotalHits totalHits = context.queryResult().getTotalHits();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
import org.opensearch.search.rescore.RescorePhase;
import org.opensearch.search.sort.SortAndFormats;
import org.opensearch.search.suggest.SuggestPhase;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
Expand All @@ -77,6 +76,7 @@
import static org.opensearch.search.query.QueryCollectorContext.createMinScoreCollectorContext;
import static org.opensearch.search.query.QueryCollectorContext.createMultiCollectorContext;
import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;
import static org.opensearch.tasks.TaskCancellationService.throwTaskCancelledException;

/**
* Query phase of a search request, used to run the query and get back from each shard information about the matching documents
Expand Down Expand Up @@ -112,7 +112,7 @@ public void preProcess(SearchContext context) {
cancellation = context.searcher().addQueryCancellation(() -> {
SearchShardTask task = context.getTask();
if (task != null && task.isCancelled()) {
throw new TaskCancelledException("cancelled task with reason: " + task.getReasonCancelled());
throwTaskCancelledException(context.getTask().getReasonCancelled());
}
});
} else {
Expand Down Expand Up @@ -253,7 +253,7 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q
searcher.addQueryCancellation(() -> {
SearchShardTask task = searchContext.getTask();
if (task != null && task.isCancelled()) {
throw new TaskCancelledException("cancelled task with reason: " + task.getReasonCancelled());
throwTaskCancelledException(task.getReasonCancelled());
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.EmptyTransportResponseHandler;
import org.opensearch.transport.TransportChannel;
Expand All @@ -63,6 +64,8 @@
*/
public class TaskCancellationService {
public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban";
public static final String REASON_PARENT_CANCELLED_HIGH_RESOURCE_CONSUMPTION =
"The parent task was cancelled due to high resource consumption";
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
private final TransportService transportService;
private final TaskManager taskManager;
Expand All @@ -88,7 +91,13 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
logger.trace("cancelling task [{}] and its descendants", taskId);
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3);
Collection<DiscoveryNode> childrenNodes = taskManager.startBanOnChildrenNodes(task.getId(), () -> {
String cancelChildTaskReason = reason;
// if parent task gets cancelled due to high resource consumption, child task should be cancelled saying parent task was
// cancelled
if (reason.contains("usage exceeded")) {
cancelChildTaskReason = REASON_PARENT_CANCELLED_HIGH_RESOURCE_CONSUMPTION;
}
Collection<DiscoveryNode> childrenNodes = taskManager.startBanOnChildrenNodes(task.getId(), cancelChildTaskReason, () -> {
logger.trace("child tasks of parent [{}] are completed", taskId);
groupedListener.onResponse(null);
});
Expand All @@ -97,7 +106,7 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
groupedListener.onResponse(null);
});
StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
setBanOnNodes(cancelChildTaskReason, waitForCompletion, task, childrenNodes, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// If we start unbanning when the last child task completed and that child task executed with a specific user, then unban
// requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context.
Expand Down Expand Up @@ -257,4 +266,15 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
}
}
}

public static void throwTaskCancelledException(String reason) {
if (isRejection(reason)) {
throw new TaskCancelledException(new OpenSearchRejectedExecutionException("cancelled task with reason: " + reason));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we just throw OpenSearchRejectedExecutionException if the task is rejected?

}
throw new TaskCancelledException(reason);
}

private static boolean isRejection(String reason) {
return (reason.contains("usage exceeded") || REASON_PARENT_CANCELLED_HIGH_RESOURCE_CONSUMPTION.equals(reason));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make "usage exceeded" a constant and use that instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

}
}
Loading