[SPARK-23033][SS][Follow Up] Task level retry for continuous processing #20675

Closed
wants to merge 1 commit into from
@@ -235,6 +235,13 @@ class KafkaContinuousDataReader(
KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset)
}

override def setOffset(offset: PartitionOffset): Unit = {
val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
assert(
kafkaOffset.topicPartition == topicPartition)
nextKafkaOffset = kafkaOffset.partitionOffset
}

override def close(): Unit = {
consumer.close()
}
@@ -33,4 +33,16 @@ public interface ContinuousDataReader<T> extends DataReader<T> {
* as a restart checkpoint.
*/
PartitionOffset getOffset();

/**
* Set the start offset for the current reader; only used in task retry. If setOffset keeps the
* default implementation, the current ContinuousDataReader can't support task-level retry.
*
* @param offset the last offset reported before the task retry.
*/
default void setOffset(PartitionOffset offset) {
Contributor

I think it might be better to create a new interface ContinuousDataReaderFactory, and implement this there as something like createDataReaderWithOffset(PartitionOffset offset). That way the intended lifecycle is explicit.

Member Author

Cool, that's clearer.

throw new UnsupportedOperationException(
"Current ContinuousDataReader can't support setOffset, task will restart " +
"with checkpoints in ContinuousExecution.");
}
}
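
For illustration of the factory approach suggested in the review thread above, here is a rough Scala sketch; only the names ContinuousDataReaderFactory and createDataReaderWithOffset come from that comment, and the parent type, package paths, and doc wording are assumptions based on the Spark 2.3-era v2 reader API, not part of this PR.

// Hypothetical sketch only, not part of this PR's diff.
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset

trait ContinuousDataReaderFactory[T] extends DataReaderFactory[T] {
  /**
   * Create a reader positioned at the given offset, e.g. when a retried task must
   * resume from the last offset it reported to the epoch coordinator. Compared to
   * mutating an existing reader via setOffset, this makes the intended lifecycle
   * (recreate the reader at a known offset) explicit.
   */
  def createDataReaderWithOffset(offset: PartitionOffset): DataReader[T]
}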
@@ -48,15 +48,24 @@ class ContinuousDataSourceRDD(
}

override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
// If attempt number isn't 0, this is a task retry, which we don't support.
if (context.attemptNumber() != 0) {
throw new ContinuousTaskRetryException()
}
val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)

val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]]
.readerFactory.createDataReader()
var lastEpoch: Option[Long] = None

val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
// If the attempt number isn't 0, this is a task retry; fetch the last offset and epoch.
if (context.attemptNumber() != 0) {
val lastEpochAndOffset = epochEndpoint.askSync[Option[(Long, PartitionOffset)]](
GetLastEpochAndOffset(context.partitionId()))
// If GetLastEpochAndOffset returns None, the task failed before it reported any offset,
// so just restart from the initial offset and epoch.
if (lastEpochAndOffset.isDefined) {
ContinuousDataSourceRDD.getBaseReader(reader).setOffset(lastEpochAndOffset.get._2)
lastEpoch = Some(lastEpochAndOffset.get._1)
}
}

// This queue contains two types of messages:
// * (null, null) representing an epoch boundary.
@@ -83,14 +92,13 @@ class ContinuousDataSourceRDD(
epochPollExecutor.shutdown()
})

val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
new Iterator[UnsafeRow] {
private val POLL_TIMEOUT_MS = 1000

private var currentEntry: (UnsafeRow, PartitionOffset) = _
private var currentOffset: PartitionOffset = startOffset
private var currentEpoch =
context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
lastEpoch.getOrElse(context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)

override def hasNext(): Boolean = {
while (currentEntry == null) {
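
Since the rendered hunks above interleave the old and new lines, a condensed sketch of the new retry path in compute() may help. The helper name resolveRetryEpoch and its factoring are invented here for illustration; the GetLastEpochAndOffset message, askSync call, and setOffset call mirror the diff, and imports from the surrounding file are assumed.

// Hypothetical helper mirroring the retry branch added to compute(); for illustration only.
private def resolveRetryEpoch(
    context: TaskContext,
    epochEndpoint: RpcEndpointRef,
    reader: DataReader[UnsafeRow]): Option[Long] = {
  if (context.attemptNumber() == 0) {
    // First attempt: start from the checkpointed offset and START_EPOCH_KEY as before.
    None
  } else {
    // Task retry: ask the epoch coordinator for the last (epoch, offset) this partition reported.
    val last = epochEndpoint.askSync[Option[(Long, PartitionOffset)]](
      GetLastEpochAndOffset(context.partitionId()))
    // None means the task failed before reporting anything, so fall back to the initial offset.
    last.map { case (epoch, offset) =>
      ContinuousDataSourceRDD.getBaseReader(reader).setOffset(offset)
      epoch
    }
  }
}

compute() would then seed currentEpoch with the returned value or, if it is None, fall back to ContinuousExecution.START_EPOCH_KEY, which is exactly what the lastEpoch variable in the diff does.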
@@ -152,4 +152,11 @@ class RateStreamContinuousDataReader(

override def getOffset(): PartitionOffset =
RateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime)

override def setOffset(offset: PartitionOffset): Unit = {
val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset]
assert(rateStreamOffset.partition == partitionIndex)
currentValue = rateStreamOffset.currentValue
nextReadTime = rateStreamOffset.currentTimeMs
}
}
@@ -73,6 +73,11 @@ private[sql] case class ReportPartitionOffset(
epoch: Long,
offset: PartitionOffset) extends EpochCoordinatorMessage

/**
* Get the last epoch and offset of a particular partition; only used in task retry.
*/
private[sql] case class GetLastEpochAndOffset(partitionId: Int) extends EpochCoordinatorMessage


/** Helper object used to create reference to [[EpochCoordinator]]. */
private[sql] object EpochCoordinatorRef extends Logging {
@@ -205,5 +210,12 @@ private[continuous] class EpochCoordinator(
case StopContinuousExecutionWrites =>
queryWritesStopped = true
context.reply(())

case GetLastEpochAndOffset(partitionId) =>
val epochAndOffset = partitionOffsets.collect {
case ((e, p), o) if p == partitionId => (e, o)
}.toSeq.sortBy(_._1).lastOption
logDebug(s"Get last epoch and offset of partitionId($partitionId): $epochAndOffset")
context.reply(epochAndOffset)
}
}
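
For illustration, here is how the GetLastEpochAndOffset lookup above behaves on a small sample, assuming partitionOffsets is keyed by (epoch, partitionId) as the pattern match implies; the offset values are simplified to strings.

// Standalone toy example of the lookup logic (String stands in for PartitionOffset).
val partitionOffsets = Map(
  (1L, 0) -> "p0-offset-5",  // epoch 1, partition 0
  (2L, 0) -> "p0-offset-9",  // epoch 2, partition 0
  (1L, 1) -> "p1-offset-3")  // epoch 1, partition 1

def lastEpochAndOffset(partitionId: Int): Option[(Long, String)] =
  partitionOffsets.collect {
    case ((e, p), o) if p == partitionId => (e, o)
  }.toSeq.sortBy(_._1).lastOption

lastEpochAndOffset(0)  // Some((2, "p0-offset-9")): the highest reported epoch wins
lastEpochAndOffset(1)  // Some((1, "p1-offset-3"))
lastEpochAndOffset(2)  // None: nothing reported yet, so the retried task restarts from its initial offset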
13 changes: 13 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -332,6 +332,19 @@ object QueryTest {
None
}

def includesRowsOnlyOnce(
expectedRows: Seq[Row],
sparkAnswer: Seq[Row]): Option[String] = {
val expectedAnswer = prepareAnswer(expectedRows, true)
val actualAnswer = prepareAnswer(sparkAnswer, true)
val diffRow = actualAnswer.diff(expectedAnswer)
if (!expectedAnswer.toSet.subsetOf(actualAnswer.toSet)
|| diffRow.intersect(expectedAnswer).nonEmpty) {
return Some(genError(expectedRows, sparkAnswer, true))
}
None
}

def sameRows(
expectedAnswer: Seq[Row],
sparkAnswer: Seq[Row],
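
A simplified, standalone illustration of what includesRowsOnlyOnce checks, using plain Ints in place of Rows and skipping prepareAnswer's sorting/normalization (assumed to have run already):

// Sketch of the predicate: every expected row appears at least once, and none appears twice.
def includesOnlyOnce[A](expected: Seq[A], actual: Seq[A]): Boolean = {
  // Multiset difference: whatever is left after removing one copy of each expected row.
  val leftover = actual.diff(expected)
  expected.toSet.subsetOf(actual.toSet) &&  // every expected row was produced...
    leftover.intersect(expected).isEmpty    // ...and no expected row was produced twice
}

includesOnlyOnce(Seq(0, 1, 2), Seq(0, 1, 2, 3, 4))  // true: extra rows beyond the expected set are allowed
includesOnlyOnce(Seq(0, 1, 2), Seq(0, 1, 1, 2))     // false: 1 was emitted twice, e.g. replayed after a retry
includesOnlyOnce(Seq(0, 1, 2), Seq(0, 2))           // false: 1 is missing entirely

This is the property the new CheckAnswerRowsContainsOnlyOnce action relies on: a retried task that replays rows trips the duplicate check even though a plain contains-check would still pass.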
@@ -194,6 +194,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
}

case class CheckAnswerRowsContainsOnlyOnce(expectedAnswer: Seq[Row], lastOnly: Boolean = false)
extends StreamAction with StreamMustBeRunning {
override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
}

case class CheckAnswerRowsByFunc(
globalCheckFunction: Seq[Row] => Unit,
lastOnly: Boolean) extends StreamAction with StreamMustBeRunning {
@@ -678,6 +684,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
error => failTest(error)
}

case CheckAnswerRowsContainsOnlyOnce(expectedAnswer, lastOnly) =>
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
QueryTest.includesRowsOnlyOnce(expectedAnswer, sparkAnswer).foreach {
error => failTest(error)
}

case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
try {
@@ -201,7 +201,7 @@ class ContinuousSuite extends ContinuousSuiteBase {
StopStream)
}

test("task failure kills the query") {
test("task restart") {
val df = spark.readStream
.format("rate")
.option("numPartitions", "5")
@@ -219,18 +219,59 @@
spark.sparkContext.addSparkListener(listener)
try {
testStream(df, useV2Sink = true)(
StartStream(Trigger.Continuous(100)),
StartStream(longContinuousTrigger),
AwaitEpoch(0),
Execute(waitForRateSourceTriggers(_, 2)),
IncrementEpoch(),
Execute { _ =>
// Wait until a task is started, then kill its first attempt.
eventually(timeout(streamingTimeout)) {
assert(taskId != -1)
}
spark.sparkContext.killTaskAttempt(taskId)
},
ExpectFailure[SparkException] { e =>
e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException]
})
Execute(waitForRateSourceTriggers(_, 4)),
IncrementEpoch(),
// Check the answer exactly; a plain CheckAnswerRowsContains would still pass
// even if there were duplicated results.
CheckAnswerRowsContainsOnlyOnce(scala.Range(0, 20).map(Row(_))),
Contributor

Checking the exact answer can just be CheckAnswer(0 to 20: _*).

Member Author

Actually I first used CheckAnswer(0 to 19: _*) here, but the test case fails, probably because continuous processing may not stop exactly at Range(0, 20) every time. See the logs below:

== Plan ==
== Parsed Logical Plan ==
WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MemoryStreamWriter@6435422d
+- Project [value#13L]
   +- StreamingDataSourceV2Relation [timestamp#12, value#13L], org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader@5c5d9c45

== Analyzed Logical Plan ==
WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MemoryStreamWriter@6435422d
+- Project [value#13L]
   +- StreamingDataSourceV2Relation [timestamp#12, value#13L], org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader@5c5d9c45

== Optimized Logical Plan ==
WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MemoryStreamWriter@6435422d
+- Project [value#13L]
   +- StreamingDataSourceV2Relation [timestamp#12, value#13L], org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader@5c5d9c45

== Physical Plan ==
WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MemoryStreamWriter@6435422d
+- *(1) Project [value#13L]
   +- *(1) DataSourceV2Scan [timestamp#12, value#13L], org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader@5c5d9c45
         
         
ScalaTestFailureLocation: org.apache.spark.sql.streaming.StreamTest$class at (StreamTest.scala:436)
org.scalatest.exceptions.TestFailedException: 

== Results ==
!== Correct Answer - 20 ==   == Spark Answer - 25 ==
!struct<value:int>           struct<value:bigint>
 [0]                         [0]
 [10]                        [10]
 [11]                        [11]
 [12]                        [12]
 [13]                        [13]
 [14]                        [14]
 [15]                        [15]
 [16]                        [16]
 [17]                        [17]
 [18]                        [18]
 [19]                        [19]
 [1]                         [1]
![2]                         [20]
![3]                         [21]
![4]                         [22]
![5]                         [23]
![6]                         [24]
![7]                         [2]
![8]                         [3]
![9]                         [4]
!                            [5]
!                            [6]
!                            [7]
!                            [8]
!                            [9]
    

== Progress ==
   StartStream(ContinuousTrigger(3600000),org.apache.spark.util.SystemClock@343e225a,Map(),null)
   AssertOnQuery(<condition>, )
   AssertOnQuery(<condition>, )
   AssertOnQuery(<condition>, )
   AssertOnQuery(<condition>, )
   AssertOnQuery(<condition>, )
   AssertOnQuery(<condition>, )
=> CheckAnswer: [0],[1],[2],[3],[4],[5],[6],[7],[8],[9],[10],[11],[12],[13],[14],[15],[16],[17],[18],[19]
   StopStream

Contributor

Ah, right, my bad.

StopStream)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}

test("task restart without last offset") {
val df = spark.readStream
.format("rate")
.option("numPartitions", "5")
.option("rowsPerSecond", "5")
.load()
.select('value)

// Get an arbitrary task from this query to kill. It doesn't matter which one.
var taskId: Long = -1
val listener = new SparkListener() {
override def onTaskStart(start: SparkListenerTaskStart): Unit = {
taskId = start.taskInfo.taskId
}
}
spark.sparkContext.addSparkListener(listener)
try {
testStream(df, useV2Sink = true)(
StartStream(longContinuousTrigger),
Execute { _ =>
// Wait until a task is started, then kill its first attempt.
eventually(timeout(streamingTimeout)) {
assert(taskId != -1)
}
spark.sparkContext.killTaskAttempt(taskId)
},
AwaitEpoch(0),
Execute(waitForRateSourceTriggers(_, 2)),
IncrementEpoch(),
CheckAnswerRowsContainsOnlyOnce(scala.Range(0, 10).map(Row(_))),
StopStream)
} finally {
spark.sparkContext.removeSparkListener(listener)
}