Skip to content

Commit

Permalink
Support task resource tracking in OpenSearch (#3982) (#4087)
Browse files Browse the repository at this point in the history
* [Backport 2.x] Support task resource tracking in OpenSearch

* Reopens changes from #2639 (reverted in #3046) to add a framework for task resource tracking. Currently, SearchTask and SearchShardTask support resource tracking but it can be extended to any other task.

* Fixed a race-condition when Task is unregistered before its threads are stopped

* Improved error handling and simplified task resource tracking completion listener

* Avoid registering listeners on already completed tasks

Signed-off-by: Ketan Verma <[email protected]>
  • Loading branch information
ketanv3 authored Aug 2, 2022
1 parent ed0af68 commit e23a87b
Show file tree
Hide file tree
Showing 30 changed files with 1,543 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,9 @@ public void onTaskUnregistered(Task task) {}

@Override
public void waitForTaskCompletion(Task task) {}

@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}
});
}
// Need to run the task in a separate thread because node client's .execute() is blocked by our task listener
Expand Down Expand Up @@ -651,6 +654,9 @@ public void waitForTaskCompletion(Task task) {
waitForWaitingToStart.countDown();
}

@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}

@Override
public void onTaskRegistered(Task task) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskInfo;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

Expand All @@ -65,8 +66,15 @@ public static long waitForCompletionTimeout(TimeValue timeout) {

private static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = timeValueSeconds(30);

private final TaskResourceTrackingService taskResourceTrackingService;

@Inject
public TransportListTasksAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) {
public TransportListTasksAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
TaskResourceTrackingService taskResourceTrackingService
) {
super(
ListTasksAction.NAME,
clusterService,
Expand All @@ -77,6 +85,7 @@ public TransportListTasksAction(ClusterService clusterService, TransportService
TaskInfo::new,
ThreadPool.Names.MANAGEMENT
);
this.taskResourceTrackingService = taskResourceTrackingService;
}

@Override
Expand Down Expand Up @@ -106,6 +115,8 @@ protected void processTasks(ListTasksRequest request, Consumer<Task> operation)
}
taskManager.waitForTaskCompletion(task, timeoutNanos);
});
} else {
operation = operation.andThen(taskResourceTrackingService::refreshResourceStats);
}
super.processTasks(request, operation);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public SearchShardTask(long id, String type, String action, String description,
super(id, type, action, description, parentTaskId, headers);
}

@Override
public boolean supportsResourceTracking() {
return true;
}

@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public final String getDescription() {
return descriptionSupplier.get();
}

@Override
public boolean supportsResourceTracking() {
return true;
}

/**
* Attach a {@link SearchProgressListener} to this task.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.action.ActionResponse;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.tasks.TaskId;
Expand Down Expand Up @@ -93,31 +94,39 @@ public final Task execute(Request request, ActionListener<Response> listener) {
*/
final Releasable unregisterChildNode = registerChildNode(request.getParentTask());
final Task task;

try {
task = taskManager.register("transport", actionName, request);
} catch (TaskCancelledException e) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);

ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);
}
}
}

@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
}
}
}
});
});
} finally {
storedContext.close();
}

return task;
}

Expand All @@ -134,25 +143,30 @@ public final Task execute(Request request, TaskListener<Response> listener) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
}
}
}

@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
}
}
}
});
});
} finally {
storedContext.close();
}
return task;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
import org.opensearch.script.ScriptMetadata;
import org.opensearch.snapshots.SnapshotsInfoService;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.tasks.TaskResultsService;

import java.util.ArrayList;
Expand Down Expand Up @@ -396,6 +397,7 @@ protected void configure() {
bind(NodeMappingRefreshAction.class).asEagerSingleton();
bind(MappingUpdatedAction.class).asEagerSingleton();
bind(TaskResultsService.class).asEagerSingleton();
bind(TaskResourceTrackingService.class).asEagerSingleton();
bind(AllocationDeciders.class).toInstance(allocationDeciders);
bind(ShardsAllocator.class).toInstance(shardsAllocator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.index.ShardIndexingPressureMemoryManager;
import org.opensearch.index.ShardIndexingPressureSettings;
import org.opensearch.index.ShardIndexingPressureStore;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction;
import org.opensearch.action.admin.indices.close.TransportCloseIndexAction;
Expand Down Expand Up @@ -573,7 +574,8 @@ public void apply(Settings value, Settings current, Settings previous) {
ShardIndexingPressureMemoryManager.THROUGHPUT_DEGRADATION_LIMITS,
ShardIndexingPressureMemoryManager.SUCCESSFUL_REQUEST_ELAPSED_TIMEOUT,
ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS,
IndexingPressure.MAX_INDEXING_BYTES
IndexingPressure.MAX_INDEXING_BYTES,
TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.node.Node;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.TaskAwareRunnable;

import java.util.List;
import java.util.Optional;
Expand All @@ -55,6 +57,7 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -177,6 +180,31 @@ public static OpenSearchThreadPoolExecutor newFixed(
);
}

public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
String name,
int size,
int initialQueueCapacity,
int minQueueSize,
int maxQueueSize,
int frameSize,
TimeValue targetedResponseTime,
ThreadFactory threadFactory,
ThreadContext contextHolder
) {
return newAutoQueueFixed(
name,
size,
initialQueueCapacity,
minQueueSize,
maxQueueSize,
frameSize,
targetedResponseTime,
threadFactory,
contextHolder,
null
);
}

/**
* Return a new executor that will automatically adjust the queue size based on queue throughput.
*
Expand All @@ -185,6 +213,7 @@ public static OpenSearchThreadPoolExecutor newFixed(
* @param minQueueSize minimum queue size that the queue can be adjusted to
* @param maxQueueSize maximum queue size that the queue can be adjusted to
* @param frameSize number of tasks during which stats are collected before adjusting queue size
* @param runnableTaskListener callback listener for a TaskAwareRunnable
*/
public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
String name,
Expand All @@ -195,17 +224,30 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
int frameSize,
TimeValue targetedResponseTime,
ThreadFactory threadFactory,
ThreadContext contextHolder
ThreadContext contextHolder,
AtomicReference<RunnableTaskExecutionListener> runnableTaskListener
) {
if (initialQueueCapacity <= 0) {
throw new IllegalArgumentException(
"initial queue capacity for [" + name + "] executor must be positive, got: " + initialQueueCapacity
);
}

ResizableBlockingQueue<Runnable> queue = new ResizableBlockingQueue<>(
ConcurrentCollections.<Runnable>newBlockingQueue(),
initialQueueCapacity
);

Function<Runnable, WrappedRunnable> runnableWrapper;
if (runnableTaskListener != null) {
runnableWrapper = (runnable) -> {
TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable, runnableTaskListener);
return new TimedRunnable(taskAwareRunnable);
};
} else {
runnableWrapper = TimedRunnable::new;
}

return new QueueResizingOpenSearchThreadPoolExecutor(
name,
size,
Expand All @@ -215,7 +257,7 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
queue,
minQueueSize,
maxQueueSize,
TimedRunnable::new,
runnableWrapper,
frameSize,
targetedResponseTime,
threadFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@

import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_SIZE;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

/**
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
Expand Down Expand Up @@ -135,16 +136,23 @@ public StoredContext stashContext() {
* This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user.
* Otherwise when context is stash, it should be empty.
*/

ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT;

if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) {
ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders(
threadContextStruct = threadContextStruct.putHeaders(
MapBuilder.<String, String>newMapBuilder()
.put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID))
.immutableMap()
);
threadLocal.set(threadContextStruct);
} else {
threadLocal.set(DEFAULT_CONTEXT);
}

if (context.transientHeaders.containsKey(TASK_ID)) {
threadContextStruct = threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID));
}

threadLocal.set(threadContextStruct);

return () -> {
// If the node and thus the threadLocal get closed while this task
// is still executing, we don't want this runnable to fail with an
Expand Down
Loading

0 comments on commit e23a87b

Please sign in to comment.