Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23033][SS] Don't use task level retry for continuous processing #20225

Closed
wants to merge 15 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -808,16 +808,14 @@ class KafkaSourceSuiteBase extends KafkaSourceTest {
val query = kafka
.writeStream
.format("memory")
.outputMode("append")
.queryName("kafkaColumnTypes")
.trigger(defaultTrigger)
.start()
var rows: Array[Row] = Array()
eventually(timeout(streamingTimeout)) {
rows = spark.table("kafkaColumnTypes").collect()
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
assert(spark.table("kafkaColumnTypes").count == 1,
s"Unexpected results: ${spark.table("kafkaColumnTypes").collectAsList()}")
}
val row = rows(0)
val row = spark.table("kafkaColumnTypes").head()
assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row")
assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row")
assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class ContinuousDataSourceRDD(
}

override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
// If attempt number isn't 0, this is a task retry, which we don't support.
if (context.attemptNumber() != 0) {
throw new ContinuousTaskRetryException()
}

val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader()

val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.util.function.UnaryOperator
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}

import org.apache.spark.SparkEnv
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.continuous

import org.apache.spark.SparkException

/**
* An exception thrown when a continuous processing task runs with a nonzero attempt ID.
*/
class ContinuousTaskRetryException
extends SparkException("Continuous execution does not support task retry", null)
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
currentStream.awaitInitialization(streamingTimeout.toMillis)
currentStream match {
case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
s.lastExecution.executedPlan // will fail if lastExecution is null
}
s.lastExecution.executedPlan // will fail if lastExecution is null
}
case _ =>
}
} catch {
Expand Down Expand Up @@ -645,7 +645,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
}

case CheckAnswerRowsContains(expectedAnswer, lastOnly) =>
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
val sparkAnswer = currentStream match {
case null => fetchStreamAnswer(lastStream, lastOnly)
case s => fetchStreamAnswer(s, lastOnly)
}
QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach {
error => failTest(error)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,18 @@

package org.apache.spark.sql.streaming.continuous

import java.io.{File, InterruptedIOException, IOException, UncheckedIOException}
import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit}
import java.util.UUID

import scala.reflect.ClassTag
import scala.util.control.ControlThrowable

import com.google.common.util.concurrent.UncheckedExecutionException
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.{SparkContext, SparkEnv, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.test.TestSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class ContinuousSuiteBase extends StreamTest {
// We need more than the default local[2] to be able to schedule all partitions simultaneously.
Expand Down Expand Up @@ -219,6 +201,41 @@ class ContinuousSuite extends ContinuousSuiteBase {
StopStream)
}

test("task failure kills the query") {
val df = spark.readStream
.format("rate")
.option("numPartitions", "5")
.option("rowsPerSecond", "5")
.load()
.select('value)

// Get an arbitrary task from this query to kill. It doesn't matter which one.
var taskId: Long = -1
val listener = new SparkListener() {
override def onTaskStart(start: SparkListenerTaskStart): Unit = {
taskId = start.taskInfo.taskId
}
}
spark.sparkContext.addSparkListener(listener)
try {
testStream(df, useV2Sink = true)(
StartStream(Trigger.Continuous(100)),
Execute(waitForRateSourceTriggers(_, 2)),
Execute { _ =>
// Wait until a task is started, then kill its first attempt.
eventually(timeout(streamingTimeout)) {
assert(taskId != -1)
}
spark.sparkContext.killTaskAttempt(taskId)
},
ExpectFailure[SparkException] { e =>
e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException]
})
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}

test("query without test harness") {
val df = spark.readStream
.format("rate")
Expand Down Expand Up @@ -258,13 +275,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
AwaitEpoch(0),
Execute(waitForRateSourceTriggers(_, 201)),
IncrementEpoch(),
Execute { query =>
val data = query.sink.asInstanceOf[MemorySinkV2].allData
val vals = data.map(_.getLong(0)).toSet
assert(scala.Range(0, 25000).forall { i =>
vals.contains(i)
})
})
StopStream,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the earlier comment about the overloaded failure mode this PR exposed.

CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))
)
}

test("automatic epoch advancement") {
Expand All @@ -280,6 +293,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
AwaitEpoch(0),
Execute(waitForRateSourceTriggers(_, 201)),
IncrementEpoch(),
StopStream,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these needed?

CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))))
}

Expand Down Expand Up @@ -311,6 +325,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
StopStream,
StartStream(Trigger.Continuous(2012)),
AwaitEpoch(50),
StopStream,
CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))))
}
}