diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 5d68b752fa463..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -266,7 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { - epochEndpoint.askSync(StopEpochCoordinator) + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 0a231eb44c761..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -40,11 +40,13 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage /** - * Synchronously stop the epoch coordinator. The RpcEndpoint stop() will clear out the message queue - * before terminating the endpoint, but we must be sure no more messages will be processed before we - * can restart the query. The framework unfortunately provides no handle to wait for the queue. + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. */ -private[sql] case object StopEpochCoordinator extends EpochCoordinatorMessage +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage // Init messages /** @@ -123,7 +125,7 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private var stopped: Boolean = false + private var queryWritesStopped: Boolean = false private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -162,7 +164,10 @@ private[continuous] class EpochCoordinator( } override def receive: PartialFunction[Any, Unit] = { - case _ if stopped => throw new IllegalStateException(s"Coordinator $this stopped") + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -182,7 +187,6 @@ private[continuous] class EpochCoordinator( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case _ if stopped => throw new IllegalStateException(s"Coordinator $this stopped") case GetCurrentEpoch => val result = currentDriverEpoch logDebug(s"Epoch $result") @@ -200,8 +204,8 @@ private[continuous] class EpochCoordinator( numWriterPartitions = numPartitions context.reply(()) - case StopEpochCoordinator => - stopped = true + case StopContinuousExecutionWrites => + queryWritesStopped = true context.reply(()) } }