Skip to content

Commit

Permalink
remove coordinator context
Browse files Browse the repository at this point in the history
  • Loading branch information
gang_ye committed May 1, 2023
1 parent 005616e commit bb332cc
Show file tree
Hide file tree
Showing 6 changed files with 449 additions and 397 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,26 @@
*/
package org.apache.iceberg.flink.sink.shuffle;

import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.flink.annotation.Internal;
import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.ThrowableCatchingRunnable;
import org.apache.flink.util.function.ThrowingRunnable;
import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -45,30 +54,29 @@
class DataStatisticsCoordinator<K> implements OperatorCoordinator {
private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class);

private static final double EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE = 80;

private final String operatorName;
private final DataStatisticsCoordinatorContext<K> context;
private final DataStatisticsFactory<K> statisticsFactory;
private final ExecutorService coordinatorExecutor;
private final OperatorCoordinator.Context operatorCoordinatorContext;
private final SubtaskGateways subtaskGateways;
private final DataStatisticsCoordinatorProvider.CoordinatorExecutorThreadFactory
coordinatorThreadFactory;
private final GlobalStatisticsAggregatorTracker<K> globalStatisticsAggregatorTracker;

private volatile GlobalStatisticsAggregator<K> inProgressAggregator;
private volatile GlobalStatisticsAggregator<K> lastCompletedAggregator;
private volatile boolean started;

DataStatisticsCoordinator(
String operatorName,
OperatorCoordinator.Context context,
DataStatisticsFactory<K> statisticsFactory) {
this.operatorName = operatorName;
DataStatisticsCoordinatorProvider.CoordinatorExecutorThreadFactory coordinatorThreadFactory =
this.coordinatorThreadFactory =
new DataStatisticsCoordinatorProvider.CoordinatorExecutorThreadFactory(
"DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader());
this.context =
new DataStatisticsCoordinatorContext<>(
Executors.newSingleThreadExecutor(coordinatorThreadFactory),
coordinatorThreadFactory,
context);
this.statisticsFactory = statisticsFactory;
this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory);
this.operatorCoordinatorContext = context;
this.subtaskGateways = new SubtaskGateways(parallelism());
this.globalStatisticsAggregatorTracker =
new GlobalStatisticsAggregatorTracker<>(statisticsFactory, parallelism());
}

@Override
Expand All @@ -80,14 +88,69 @@ public void start() throws Exception {
@Override
public void close() throws Exception {
LOG.info("Closing data statistics coordinator for {}.", operatorName);
context.close();
coordinatorExecutor.shutdown();
try {
if (!coordinatorExecutor.awaitTermination(5, TimeUnit.SECONDS)) {
LOG.warn(
"Fail to shut down data statistics coordinator context gracefully. Shutting down now");
coordinatorExecutor.shutdownNow();
if (!coordinatorExecutor.awaitTermination(5, TimeUnit.SECONDS)) {
LOG.warn("Fail to terminate data statistics coordinator context");
return;
}
}
LOG.info("Data statistics coordinator context closed.");
} catch (InterruptedException e) {
coordinatorExecutor.shutdownNow();
Thread.currentThread().interrupt();
LOG.error("Errors occurred while closing the data statistics coordinator context", e);
}

LOG.info("Data statistics coordinator for {} closed.", operatorName);
}

void callInCoordinatorThread(Callable<Void> callable, String errorMessage) {
// Ensure the task is done by the coordinator executor.
if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) {
try {
final Callable<Void> guardedCallable =
() -> {
try {
return callable.call();
} catch (Throwable t) {
LOG.error("Uncaught Exception in DataStatistics Coordinator Executor", t);
ExceptionUtils.rethrowException(t);
return null;
}
};

coordinatorExecutor.submit(guardedCallable).get();
} catch (InterruptedException | ExecutionException e) {
throw new FlinkRuntimeException(errorMessage, e);
}
} else {
try {
callable.call();
} catch (Throwable t) {
LOG.error("Uncaught Exception in DataStatistics coordinator executor", t);
throw new FlinkRuntimeException(errorMessage, t);
}
}
}

public void runInCoordinatorThread(Runnable runnable) {
this.coordinatorExecutor.execute(
new ThrowableCatchingRunnable(
(throwable) -> {
this.coordinatorThreadFactory.uncaughtException(Thread.currentThread(), throwable);
},
runnable));
}

private void runInCoordinatorThread(
ThrowingRunnable<Throwable> action, String actionName, Object... actionNameFormatParameters) {
ensureStarted();
context.runInCoordinatorThread(
runInCoordinatorThread(
() -> {
try {
action.run();
Expand All @@ -99,8 +162,7 @@ private void runInCoordinatorThread(
operatorName,
actionString,
t);

this.context.failJob(t);
operatorCoordinatorContext.failJob(t);
}
});
}
Expand All @@ -109,73 +171,33 @@ private void ensureStarted() {
Preconditions.checkState(started, "The coordinator has not started yet.");
}

private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<K> event) {
long checkpointId = event.checkpointId();

if (lastCompletedAggregator != null && lastCompletedAggregator.checkpointId() >= checkpointId) {
LOG.debug(
"Data statistics aggregation for checkpoint {} has completed. Ignore the event from subtask {} for checkpoint {}",
lastCompletedAggregator.checkpointId(),
subtask,
checkpointId);
return;
}

if (inProgressAggregator == null) {
inProgressAggregator = new GlobalStatisticsAggregator<>(checkpointId, statisticsFactory);
}
private int parallelism() {
return operatorCoordinatorContext.currentParallelism();
}

if (inProgressAggregator.checkpointId() < checkpointId) {
if ((double) inProgressAggregator.aggregatedSubtasksCount() / context.parallelism() * 100
>= EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE) {
lastCompletedAggregator = inProgressAggregator;
LOG.info(
"Received data statistics from {} operators out of total {} for checkpoint {}. "
+ "It's more than the expected percentage {}. Sending the aggregate data"
+ " statistics {} to subtasks.",
inProgressAggregator.aggregatedSubtasksCount(),
context.parallelism(),
inProgressAggregator.checkpointId(),
EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE,
lastCompletedAggregator);
inProgressAggregator = new GlobalStatisticsAggregator<>(checkpointId, statisticsFactory);
inProgressAggregator.mergeDataStatistic(subtask, event);
context.sendDataStatisticsToSubtasks(
inProgressAggregator.checkpointId(), lastCompletedAggregator.dataStatistics());
return;
} else {
LOG.info(
"Received data statistics from {} operators out of total {} for checkpoint {}. "
+ "It's less than the expected percentage {}. Dropping the incomplete aggregate "
+ "data statistics {} and starting collecting data statistics from new checkpoint {}",
inProgressAggregator.aggregatedSubtasksCount(),
context.parallelism(),
inProgressAggregator.checkpointId(),
EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE,
inProgressAggregator,
checkpointId);
inProgressAggregator = new GlobalStatisticsAggregator<>(checkpointId, statisticsFactory);
}
} else if (inProgressAggregator.checkpointId() > checkpointId) {
LOG.debug(
"Expect data statistics for checkpoint {}, but receive event from older checkpoint {}. Ignore it.",
inProgressAggregator.checkpointId(),
checkpointId);
return;
private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<K> event) {
if (globalStatisticsAggregatorTracker.receiveDataStatisticEventAndCheckCompletion(
subtask, event)) {
GlobalStatisticsAggregator<K> lastCompletedAggregator =
globalStatisticsAggregatorTracker.lastCompletedAggregator();
sendDataStatisticsToSubtasks(
lastCompletedAggregator.checkpointId(), lastCompletedAggregator.dataStatistics());
}
}

inProgressAggregator.mergeDataStatistic(subtask, event);

if (inProgressAggregator.aggregatedSubtasksCount() == context.parallelism()) {
lastCompletedAggregator = inProgressAggregator;
LOG.info(
"Received data statistics from all {} operators for checkpoint {}. Sending the aggregated data statistics {} to subtasks.",
context.parallelism(),
inProgressAggregator.checkpointId(),
lastCompletedAggregator.dataStatistics());
inProgressAggregator = null;
context.sendDataStatisticsToSubtasks(checkpointId, lastCompletedAggregator.dataStatistics());
}
private void sendDataStatisticsToSubtasks(
long checkpointId, DataStatistics<K> globalDataStatistics) {
callInCoordinatorThread(
() -> {
DataStatisticsEvent<K> dataStatisticsEvent =
new DataStatisticsEvent<>(checkpointId, globalDataStatistics);
int parallelism = parallelism();
for (int i = 0; i < parallelism; ++i) {
subtaskGateways.getOnlyGatewayAndCheckReady(i).sendEvent(dataStatisticsEvent);
}
return null;
},
String.format("Failed to send global data statistics for checkpoint %d", checkpointId));
}

@Override
Expand All @@ -192,7 +214,7 @@ public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEven
Preconditions.checkArgument(event instanceof DataStatisticsEvent);
handleDataStatisticRequest(subtask, ((DataStatisticsEvent<K>) event));
},
"handling operator event %s from data statistics operator subtask %d (#%d)",
"handling operator event %s from subtask %d (#%d)",
event.getClass(),
subtask,
attemptNumber);
Expand All @@ -208,7 +230,8 @@ public void checkpointCoordinator(long checkpointId, CompletableFuture<byte[]> r
checkpointId);
try {
byte[] serializedDataDistributionWeight =
InstantiationUtil.serializeObject(lastCompletedAggregator);
InstantiationUtil.serializeObject(
globalStatisticsAggregatorTracker.lastCompletedAggregator());
resultFuture.complete(serializedDataDistributionWeight);
} catch (Throwable e) {
ExceptionUtils.rethrowIfFatalErrorOrOOM(e);
Expand Down Expand Up @@ -241,9 +264,9 @@ public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData

LOG.info(
"Restoring data statistic coordinator {} from checkpoint {}.", operatorName, checkpointId);
lastCompletedAggregator =
globalStatisticsAggregatorTracker.setLastCompletedAggregator(
InstantiationUtil.deserializeObject(
checkpointData, GlobalStatisticsAggregator.class.getClassLoader());
checkpointData, GlobalStatisticsAggregator.class.getClassLoader()));
}

@Override
Expand All @@ -255,7 +278,9 @@ public void subtaskReset(int subtask, long checkpointId) {
subtask,
checkpointId,
operatorName);
context.subtaskReset(subtask);
Preconditions.checkState(
this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
subtaskGateways.reset(subtask);
},
"handling subtask %d recovery to checkpoint %d",
subtask,
Expand All @@ -271,7 +296,9 @@ public void executionAttemptFailed(int subtask, int attemptNumber, @Nullable Thr
subtask,
attemptNumber,
operatorName);
context.attemptFailed(subtask, attemptNumber);
Preconditions.checkState(
this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
subtaskGateways.unregisterSubtaskGateway(subtask, attemptNumber);
},
"handling subtask %d (#%d) failure",
subtask,
Expand All @@ -284,26 +311,58 @@ public void executionAttemptReady(int subtask, int attemptNumber, SubtaskGateway
Preconditions.checkArgument(attemptNumber == gateway.getExecution().getAttemptNumber());
runInCoordinatorThread(
() -> {
context.attemptReady(gateway);
Preconditions.checkState(
this.coordinatorThreadFactory.isCurrentThreadCoordinatorThread());
subtaskGateways.registerSubtaskGateway(gateway);
},
"making event gateway to subtask %d (#%d) available",
subtask,
attemptNumber);
}

// ---------------------------------------------------
@VisibleForTesting
GlobalStatisticsAggregator<K> completeAggregatedDataStatistics() {
return lastCompletedAggregator;
GlobalStatisticsAggregatorTracker<K> globalStatisticsAggregatorTracker() {
return globalStatisticsAggregatorTracker;
}

@VisibleForTesting
GlobalStatisticsAggregator<K> incompleteAggregatedDataStatistics() {
return inProgressAggregator;
}
private static class SubtaskGateways {
private final Map<Integer, SubtaskGateway>[] gateways;

@VisibleForTesting
DataStatisticsCoordinatorContext<K> context() {
return context;
private SubtaskGateways(int parallelism) {
gateways = new Map[parallelism];

for (int i = 0; i < parallelism; ++i) {
gateways[i] = Maps.newHashMap();
}
}

private void registerSubtaskGateway(OperatorCoordinator.SubtaskGateway gateway) {
int subtaskIndex = gateway.getSubtask();
int attemptNumber = gateway.getExecution().getAttemptNumber();
Preconditions.checkState(
!gateways[subtaskIndex].containsKey(attemptNumber),
"Already have a subtask gateway for %d (#%d).",
subtaskIndex,
attemptNumber);
LOG.debug("Register gateway for subtask {} attempt {}", subtaskIndex, attemptNumber);
gateways[subtaskIndex].put(attemptNumber, gateway);
}

private void unregisterSubtaskGateway(int subtaskIndex, int attemptNumber) {
LOG.debug("Unregister gateway for subtask {} attempt {}", subtaskIndex, attemptNumber);
gateways[subtaskIndex].remove(attemptNumber);
}

private OperatorCoordinator.SubtaskGateway getOnlyGatewayAndCheckReady(int subtaskIndex) {
Preconditions.checkState(
gateways[subtaskIndex].size() > 0,
"Subtask %d is not ready yet to receive events.",
subtaskIndex);
return Iterables.getOnlyElement(gateways[subtaskIndex].values());
}

private void reset(int subtaskIndex) {
gateways[subtaskIndex].clear();
}
}
}
Loading

0 comments on commit bb332cc

Please sign in to comment.