Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSQ: Improved worker cancellation. #17046

Merged
merged 8 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.common.guava.FutureUtils;
Expand Down Expand Up @@ -203,6 +204,7 @@ public int getChannelNumber(int rowNumber, int numRows, int numChannels)
private final List<KeyColumn> sortKey = ImmutableList.of(new KeyColumn(KEY, KeyOrder.ASCENDING));

private List<List<Frame>> channelFrames;
private ListeningExecutorService innerExec;
private FrameProcessorExecutor exec;
private List<BlockingQueueFrameChannel> channels;

Expand All @@ -226,7 +228,7 @@ public void setupTrial()
frameReader = FrameReader.create(signature);

exec = new FrameProcessorExecutor(
MoreExecutors.listeningDecorator(
innerExec = MoreExecutors.listeningDecorator(
Execs.singleThreaded(StringUtils.encodeForFormat(getClass().getSimpleName()))
)
);
Expand Down Expand Up @@ -335,8 +337,8 @@ public void setupInvocation() throws IOException
@TearDown(Level.Trial)
public void tearDown() throws Exception
{
exec.getExecutorService().shutdownNow();
if (!exec.getExecutorService().awaitTermination(1, TimeUnit.MINUTES)) {
innerExec.shutdownNow();
if (!innerExec.awaitTermination(1, TimeUnit.MINUTES)) {
throw new ISE("Could not terminate executor after 1 minute");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@
public class ControllerImpl implements Controller
{
private static final Logger log = new Logger(ControllerImpl.class);
private static final String RESULT_READER_CANCELLATION_ID = "result-reader";

private final String queryId;
private final MSQSpec querySpec;
Expand Down Expand Up @@ -2189,6 +2190,34 @@ private static void logKernelStatus(final String queryId, final ControllerQueryK
}
}

/**
 * Builds the {@link FrameProcessorExecutor} used for reading query results in
 * {@link RunQueryUntilDone#readQueryResults()}. The executor wraps a listening decorator around an executor obtained
 * from {@code Execs.singleThreaded}, with a thread name format derived from the query ID.
 *
 * @param queryId query ID to embed in the thread name
 *
 * @return a newly created result-reader executor
 */
private static FrameProcessorExecutor createResultReaderExec(final String queryId)
{
  // Encode the name so any format specifiers in the query ID are treated literally.
  final String threadNameFormat = StringUtils.encodeForFormat("msq-result-reader[" + queryId + "]");
  return new FrameProcessorExecutor(MoreExecutors.listeningDecorator(Execs.singleThreaded(threadNameFormat)));
}

/**
 * Cancel any currently-running work and shut down a result-reader executor, like one created by
 * {@link #createResultReaderExec(String)}.
 *
 * The executor is always shut down, even if cancellation fails. A cancellation failure is rethrown wrapped in a
 * {@link RuntimeException}. If that failure is an {@link InterruptedException}, the calling thread's interrupt flag
 * is restored before rethrowing, so the interruption is not lost behind the wrapper.
 *
 * @param exec executor to cancel and shut down
 */
private static void closeResultReaderExec(final FrameProcessorExecutor exec)
{
  try {
    exec.cancel(RESULT_READER_CANCELLATION_ID);
  }
  catch (Exception e) {
    if (e instanceof InterruptedException) {
      // Wrapping the exception would otherwise hide the interruption from code further up the stack.
      Thread.currentThread().interrupt();
    }
    throw new RuntimeException(e);
  }
  finally {
    exec.shutdownNow();
  }
}

private void stopExternalFetchers()
{
if (workerSketchFetcher != null) {
Expand Down Expand Up @@ -2698,12 +2727,9 @@ private void startQueryResultsReader()
inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> taskIds);
}

final FrameProcessorExecutor resultReaderExec = new FrameProcessorExecutor(
MoreExecutors.listeningDecorator(
Execs.singleThreaded(StringUtils.encodeForFormat("msq-result-reader[" + queryId() + "]")))
);
final FrameProcessorExecutor resultReaderExec = createResultReaderExec(queryId());
resultReaderExec.registerCancellationId(RESULT_READER_CANCELLATION_ID);

final String cancellationId = "results-reader";
ReadableConcatFrameChannel resultsChannel = null;

try {
Expand All @@ -2713,7 +2739,7 @@ private void startQueryResultsReader()
inputChannelFactory,
() -> ArenaMemoryAllocator.createOnHeap(5_000_000),
resultReaderExec,
cancellationId,
RESULT_READER_CANCELLATION_ID,
null,
MultiStageQueryContext.removeNullBytes(querySpec.getQuery().context())
);
Expand Down Expand Up @@ -2747,7 +2773,7 @@ private void startQueryResultsReader()
queryListener
);

queryResultsReaderFuture = resultReaderExec.runFully(resultsReader, cancellationId);
queryResultsReaderFuture = resultReaderExec.runFully(resultsReader, RESULT_READER_CANCELLATION_ID);

// When results are done being read, kick the main thread.
// Important: don't use FutureUtils.futureWithBaggage, because we need queryResultsReaderFuture to resolve
Expand All @@ -2764,23 +2790,13 @@ private void startQueryResultsReader()
e,
() -> CloseableUtils.closeAll(
finalResultsChannel,
() -> resultReaderExec.getExecutorService().shutdownNow()
() -> closeResultReaderExec(resultReaderExec)
)
);
}

// Result reader is set up. Register with the query-wide closer.
closer.register(() -> {
try {
resultReaderExec.cancel(cancellationId);
}
catch (Exception e) {
throw new RuntimeException(e);
}
finally {
resultReaderExec.getExecutorService().shutdownNow();
}
});
closer.register(() -> closeResultReaderExec(resultReaderExec));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
Expand Down Expand Up @@ -56,6 +57,7 @@
import org.apache.druid.frame.processor.manager.ProcessorManagers;
import org.apache.druid.frame.util.DurableStorageUtils;
import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.FileUtils;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
Expand All @@ -67,6 +69,8 @@
import org.apache.druid.msq.indexing.CountingOutputChannelFactory;
import org.apache.druid.msq.indexing.InputChannelFactory;
import org.apache.druid.msq.indexing.InputChannelsImpl;
import org.apache.druid.msq.indexing.error.CanceledFault;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.processor.KeyStatisticsCollectionProcessor;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSliceReader;
Expand Down Expand Up @@ -94,7 +98,6 @@
import org.apache.druid.msq.shuffle.output.DurableStorageOutputChannelFactory;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.utils.CloseableUtils;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;

import javax.annotation.Nullable;
Expand All @@ -104,15 +107,38 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

/**
* Main worker logic for executing a {@link WorkOrder} in a {@link FrameProcessorExecutor}.
*/
public class RunWorkOrder
{
private final String controllerTaskId;
/**
* Lifecycle states of a {@link RunWorkOrder}. The current state is held in the {@code state} field, which starts
* out as {@link #INIT}.
*/
enum State
{
/**
* Initial state. Must be in this state to call {@link #startAsync()}. A call to {@link #stop()} is also accepted
* from this state, and moves directly to {@link #STOPPING}.
*/
INIT,

/**
* State entered upon calling {@link #startAsync()}. Only reachable from {@link #INIT}.
*/
STARTED,

/**
* State entered upon calling {@link #stop()}. Reachable from either {@link #INIT} or {@link #STARTED}.
*/
STOPPING,

/**
* State entered when a call to {@link #stop()} concludes.
*
* NOTE(review): no assignment of this state is visible in the portion of {@link #stop()} shown in this change;
* confirm where the transition happens.
*/
STOPPED
}

private final WorkOrder workOrder;
private final InputChannelFactory inputChannelFactory;
private final CounterTracker counterTracker;
Expand All @@ -125,7 +151,9 @@ public class RunWorkOrder
private final boolean reindex;
private final boolean removeNullBytes;
private final ByteTracker intermediateSuperSorterLocalStorageTracker;
private final AtomicBoolean started = new AtomicBoolean();
private final AtomicReference<State> state = new AtomicReference<>(State.INIT);
private final CountDownLatch stopLatch = new CountDownLatch(1);
private final AtomicReference<Either<Throwable, Object>> resultForListener = new AtomicReference<>();

@MonotonicNonNull
private InputSliceReader inputSliceReader;
Expand All @@ -141,7 +169,6 @@ public class RunWorkOrder
private ListenableFuture<OutputChannels> stageOutputChannelsFuture;

public RunWorkOrder(
final String controllerTaskId,
final WorkOrder workOrder,
final InputChannelFactory inputChannelFactory,
final CounterTracker counterTracker,
Expand All @@ -154,7 +181,6 @@ public RunWorkOrder(
final boolean removeNullBytes
)
{
this.controllerTaskId = controllerTaskId;
this.workOrder = workOrder;
this.inputChannelFactory = inputChannelFactory;
this.counterTracker = counterTracker;
Expand All @@ -180,15 +206,16 @@ public RunWorkOrder(
* Execution proceeds asynchronously after this method returns. The {@link RunWorkOrderListener} passed to the
* constructor of this instance can be used to track progress.
*/
public void start() throws IOException
public void startAsync()
{
if (started.getAndSet(true)) {
throw new ISE("Already started");
if (!state.compareAndSet(State.INIT, State.STARTED)) {
throw new ISE("Cannot start from state[%s]", state);
}

final StageDefinition stageDef = workOrder.getStageDefinition();

try {
exec.registerCancellationId(cancellationId);
makeInputSliceReader();
makeWorkOutputChannelFactory();
makeShuffleOutputChannelFactory();
Expand All @@ -205,16 +232,78 @@ public void start() throws IOException
setUpCompletionCallbacks();
}
catch (Throwable t) {
// If startAsync() has problems, cancel anything that was already kicked off, and close the FrameContext.
stopUnchecked();
}
}

/**
* Stops an execution that was previously initiated through {@link #startAsync()} and closes the {@link FrameContext}.
* May be called to cancel execution. Must also be called after successful execution in order to ensure that resources
* are all properly cleaned up.
*
* Blocks until execution is fully stopped.
*/
public void stop() throws InterruptedException
{
if (state.compareAndSet(State.INIT, State.STOPPING)
|| state.compareAndSet(State.STARTED, State.STOPPING)) {
// Initiate stopping.
Throwable e = null;

try {
exec.cancel(cancellationId);
}
catch (Throwable t2) {
t.addSuppressed(t2);
catch (Throwable e2) {
e = e2;
}

CloseableUtils.closeAndSuppressExceptions(frameContext, t::addSuppressed);
throw t;
try {
frameContext.close();
}
catch (Throwable e2) {
if (e == null) {
e = e2;
} else {
e.addSuppressed(e2);
}
}

try {
// notifyListener will ignore this cancellation error if work has already succeeded.
notifyListener(Either.error(new MSQException(CanceledFault.instance())));
}
catch (Throwable e2) {
if (e == null) {
e = e2;
} else {
e.addSuppressed(e2);
}
}

stopLatch.countDown();

if (e != null) {
Throwables.throwIfInstanceOf(e, InterruptedException.class);
Throwables.throwIfUnchecked(e);
throw new RuntimeException(e);
}
}

stopLatch.await();
}

/**
 * Variant of {@link #stop()} that does not declare {@link InterruptedException}. If {@link #stop()} is interrupted,
 * the current thread's interrupt flag is set again and the exception is rethrown wrapped in a
 * {@link RuntimeException}.
 */
public void stopUnchecked()
{
  try {
    stop();
  }
  catch (InterruptedException interrupted) {
    // Restore the flag so callers can still observe the interruption.
    final Thread self = Thread.currentThread();
    self.interrupt();
    throw new RuntimeException(interrupted);
  }
}

Expand Down Expand Up @@ -459,19 +548,33 @@ public void onSuccess(final List<Object> workerResultAndOutputChannelsResolved)
writeDurableStorageSuccessFile();
}

listener.onSuccess(resultObject);
notifyListener(Either.value(resultObject));
}

@Override
public void onFailure(final Throwable t)
{
listener.onFailure(t);
notifyListener(Either.error(t));
}
},
Execs.directExecutor()
);
}

/**
 * Deliver the final outcome to the {@link RunWorkOrderListener}, at most once. The first result offered is recorded
 * and forwarded; any later call is ignored.
 *
 * @param result error or success value to report
 */
private void notifyListener(final Either<Throwable, Object> result)
{
  // Only the caller that installs the first result gets to notify the listener.
  if (!resultForListener.compareAndSet(null, result)) {
    return;
  }

  if (!result.isError()) {
    listener.onSuccess(result.valueOrThrow());
  } else {
    listener.onFailure(result.error());
  }
}

/**
* Write {@link DurableStorageUtils#SUCCESS_MARKER_FILENAME} for a particular stage, if durable storage is enabled.
*/
Expand Down Expand Up @@ -561,7 +664,7 @@ private DurableStorageOutputChannelFactory makeDurableStorageOutputChannelFactor
)
{
return DurableStorageOutputChannelFactory.createStandardImplementation(
controllerTaskId,
workerContext.queryId(),
workOrder.getWorkerNumber(),
workOrder.getStageNumber(),
workerContext.workerId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import javax.annotation.Nullable;

/**
* Listener for various things that may happen during execution of {@link RunWorkOrder#start()}. Listener methods are
* Listener for various things that may happen during execution of {@link RunWorkOrder#startAsync()}. Listener methods are
* fired in processing threads, so they must be thread-safe, and it is important that they run quickly.
*/
public interface RunWorkOrderListener
Expand Down
Loading
Loading