From 21f574e2a3ad3c8e68b92776d2a141d7fcb90502 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Mon, 26 Feb 2018 15:27:10 +0800
Subject: [PATCH] [SPARK-23033][SS][Follow Up] Task level retry for continuous
 processing

---
 .../sql/kafka010/KafkaContinuousReader.scala  |  7 +++
 .../streaming/ContinuousDataReader.java       | 12 +++++
 .../ContinuousDataSourceRDDIter.scala         | 22 +++--
 .../ContinuousRateStreamSource.scala          |  7 +++
 .../continuous/EpochCoordinator.scala         | 12 +++++
 .../org/apache/spark/sql/QueryTest.scala      | 13 +++++
 .../spark/sql/streaming/StreamTest.scala      | 12 +++++
 .../continuous/ContinuousSuite.scala          | 51 +++++++++++++++++--
 8 files changed, 124 insertions(+), 12 deletions(-)

diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index ecd1170321f3f..ac02c17739ab0 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -235,6 +235,13 @@ class KafkaContinuousDataReader(
     KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset)
   }
 
+  override def setOffset(offset: PartitionOffset): Unit = {
+    val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
+    assert(
+      kafkaOffset.topicPartition == topicPartition)
+    nextKafkaOffset = kafkaOffset.partitionOffset
+  }
+
   override def close(): Unit = {
     consumer.close()
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java
index 47d26440841fd..8bb922fd26e76 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java
@@ -33,4 +33,16 @@ public interface ContinuousDataReader<T> extends DataReader<T> {
    * as a restart checkpoint.
    */
   PartitionOffset getOffset();
+
+  /**
+   * Set the start offset for the current record; only used in task retry. If setOffset keeps
+   * its default implementation, the current ContinuousDataReader cannot support task-level retry.
+   *
+   * @param offset last offset before task retry.
+   */
+  default void setOffset(PartitionOffset offset) {
+    throw new UnsupportedOperationException(
+      "Current ContinuousDataReader can't support setOffset, task will restart " +
+        "with checkpoints in ContinuousExecution.");
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index cf02c0dda25d7..bad0d0e69eeb8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -48,15 +48,24 @@ class ContinuousDataSourceRDD(
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
-    // If attempt number isn't 0, this is a task retry, which we don't support.
-    if (context.attemptNumber() != 0) {
-      throw new ContinuousTaskRetryException()
-    }
+    val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
+    val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
 
     val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]]
       .readerFactory.createDataReader()
+    var lastEpoch: Option[Long] = None
 
-    val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
+    // If attempt number isn't 0, this is a task retry, so get the last offset and epoch.
+    if (context.attemptNumber() != 0) {
+      val lastEpochAndOffset = epochEndpoint.askSync[Option[(Long, PartitionOffset)]](
+        GetLastEpochAndOffset(context.partitionId()))
+      // If GetLastEpochAndOffset returns None, the task failed before reporting any offset;
+      // just restart from the initial offset and epoch.
+      if (lastEpochAndOffset.isDefined) {
+        ContinuousDataSourceRDD.getBaseReader(reader).setOffset(lastEpochAndOffset.get._2)
+        lastEpoch = Some(lastEpochAndOffset.get._1)
+      }
+    }
 
     // This queue contains two types of messages:
     // * (null, null) representing an epoch boundary.
@@ -83,14 +92,13 @@ class ContinuousDataSourceRDD(
       epochPollExecutor.shutdown()
     })
 
-    val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
     new Iterator[UnsafeRow] {
      private val POLL_TIMEOUT_MS = 1000
 
      private var currentEntry: (UnsafeRow, PartitionOffset) = _
      private var currentOffset: PartitionOffset = startOffset
      private var currentEpoch =
-        context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+        lastEpoch.getOrElse(context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
 
      override def hasNext(): Boolean = {
        while (currentEntry == null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index b63d8d3e20650..5c14091175688 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -152,4 +152,11 @@ class RateStreamContinuousDataReader(
 
   override def getOffset(): PartitionOffset =
     RateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime)
+
+  override def setOffset(offset: PartitionOffset): Unit = {
+    val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset]
+    assert(rateStreamOffset.partition == partitionIndex)
+    currentValue = rateStreamOffset.currentValue
+    nextReadTime = rateStreamOffset.currentTimeMs
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index cc6808065c0cd..1374efbb63756 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -73,6 +73,11 @@ private[sql] case class ReportPartitionOffset(
     epoch: Long,
     offset: PartitionOffset) extends EpochCoordinatorMessage
 
+/**
+ * Get the last epoch and offset of a particular partition; only used in task retry.
+ */ +private[sql] case class GetLastEpochAndOffset(partitionId: Int) extends EpochCoordinatorMessage + /** Helper object used to create reference to [[EpochCoordinator]]. */ private[sql] object EpochCoordinatorRef extends Logging { @@ -205,5 +210,12 @@ private[continuous] class EpochCoordinator( case StopContinuousExecutionWrites => queryWritesStopped = true context.reply(()) + + case GetLastEpochAndOffset(partitionId) => + val epochAndOffset = partitionOffsets.collect { + case ((e, p), o) if p == partitionId => (e, o) + }.toSeq.sortBy(_._1).lastOption + logDebug(s"Get last epoch and offset of partitionId($partitionId): $epochAndOffset") + context.reply(epochAndOffset) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9fb8be423614b..cd42c609c7829 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -332,6 +332,19 @@ object QueryTest { None } + def includesRowsOnlyOnce( + expectedRows: Seq[Row], + sparkAnswer: Seq[Row]): Option[String] = { + val expectedAnswer = prepareAnswer(expectedRows, true) + val actualAnswer = prepareAnswer(sparkAnswer, true) + val diffRow = actualAnswer.diff(expectedAnswer) + if (!expectedAnswer.toSet.subsetOf(actualAnswer.toSet) + || diffRow.intersect(expectedAnswer).nonEmpty) { + return Some(genError(expectedRows, sparkAnswer, true)) + } + None + } + def sameRows( expectedAnswer: Seq[Row], sparkAnswer: Seq[Row], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 08f722ecb10e5..22f1149c9660b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -194,6 +194,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } + case class CheckAnswerRowsContainsOnlyOnce(expectedAnswer: Seq[Row], lastOnly: Boolean = false) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + } + case class CheckAnswerRowsByFunc( globalCheckFunction: Seq[Row] => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { @@ -678,6 +684,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be error => failTest(error) } + case CheckAnswerRowsContainsOnlyOnce(expectedAnswer, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + QueryTest.includesRowsOnlyOnce(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } + case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..d0fdbaa545950 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -201,7 +201,7 @@ class ContinuousSuite extends ContinuousSuiteBase { StopStream) } - 
-  test("task failure kills the query") {
+  test("task restart") {
     val df = spark.readStream
       .format("rate")
       .option("numPartitions", "5")
@@ -219,8 +219,10 @@ class ContinuousSuite extends ContinuousSuiteBase {
     spark.sparkContext.addSparkListener(listener)
     try {
       testStream(df, useV2Sink = true)(
-        StartStream(Trigger.Continuous(100)),
+        StartStream(longContinuousTrigger),
+        AwaitEpoch(0),
         Execute(waitForRateSourceTriggers(_, 2)),
+        IncrementEpoch(),
         Execute { _ =>
           // Wait until a task is started, then kill its first attempt.
           eventually(timeout(streamingTimeout)) {
@@ -228,9 +230,48 @@ class ContinuousSuite extends ContinuousSuiteBase {
           }
           spark.sparkContext.killTaskAttempt(taskId)
         },
-        ExpectFailure[SparkException] { e =>
-          e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException]
-        })
+        Execute(waitForRateSourceTriggers(_, 4)),
+        IncrementEpoch(),
+        // Check the answer exactly: CheckAnswerRowsContains would also pass if the result
+        // contained duplicates.
+        CheckAnswerRowsContainsOnlyOnce(scala.Range(0, 20).map(Row(_))),
+        StopStream)
+    } finally {
+      spark.sparkContext.removeSparkListener(listener)
+    }
+  }
+
+  test("task restart without last offset") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "5")
+      .load()
+      .select('value)
+
+    // Get an arbitrary task from this query to kill. It doesn't matter which one.
+    var taskId: Long = -1
+    val listener = new SparkListener() {
+      override def onTaskStart(start: SparkListenerTaskStart): Unit = {
+        taskId = start.taskInfo.taskId
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+    try {
+      testStream(df, useV2Sink = true)(
+        StartStream(longContinuousTrigger),
+        Execute { _ =>
+          // Wait until a task is started, then kill its first attempt.
+          eventually(timeout(streamingTimeout)) {
+            assert(taskId != -1)
+          }
+          spark.sparkContext.killTaskAttempt(taskId)
+        },
+        AwaitEpoch(0),
+        Execute(waitForRateSourceTriggers(_, 2)),
+        IncrementEpoch(),
+        CheckAnswerRowsContainsOnlyOnce(scala.Range(0, 10).map(Row(_))),
+        StopStream)
     } finally {
       spark.sparkContext.removeSparkListener(listener)
     }
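
Editor's note: the sketch below is not part of the patch. It is a minimal, hypothetical illustration of how a third-party source could implement the setOffset contract introduced above so its tasks resume after a retry instead of failing the query. CounterPartitionOffset and CounterContinuousDataReader are invented names; only ContinuousDataReader, PartitionOffset, and the retry flow (the EpochCoordinator answering GetLastEpochAndOffset, then setOffset being called on the recreated reader) come from the patch. A real source would also need its DataReaderFactory and ContinuousReader, omitted here.

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}

// Hypothetical per-partition offset for a simple counter source.
case class CounterPartitionOffset(partition: Int, nextValue: Long) extends PartitionOffset

// Hypothetical reader: emits an increasing counter for one partition.
class CounterContinuousDataReader(partition: Int, startValue: Long)
  extends ContinuousDataReader[Row] {

  // Value the next call to get() will return; a task retry rewinds this via setOffset.
  private var currentValue: Long = startValue

  override def next(): Boolean = true

  override def get(): Row = {
    val row = Row(partition, currentValue)
    currentValue += 1
    row
  }

  // Reported through ReportPartitionOffset, so the EpochCoordinator can answer
  // GetLastEpochAndOffset for this partition when the task is retried.
  override def getOffset(): PartitionOffset = CounterPartitionOffset(partition, currentValue)

  // Called only when attemptNumber != 0, with the last offset the coordinator recorded;
  // mirrors KafkaContinuousDataReader.setOffset and RateStreamContinuousDataReader.setOffset.
  override def setOffset(offset: PartitionOffset): Unit = {
    val counterOffset = offset.asInstanceOf[CounterPartitionOffset]
    assert(counterOffset.partition == partition)
    currentValue = counterOffset.nextValue
  }

  override def close(): Unit = {}
}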