[SPARK-49326][SS] Classify Error class for Foreach sink user function error

### What changes were proposed in this pull request?

Similar to the classification that micheal-o did for the ForeachBatch sink in PR #45299, any exception can be thrown from the user-provided function for the Foreach sink. We want to classify this class of errors, including errors from both Python (Py4JException) and Scala functions.
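
A minimal, standalone sketch of the wrapping pattern (the real implementation is the `ForeachUserFuncException` case class added in `ForeachWriterTable.scala` below, which extends `SparkException` with the new `FOREACH_USER_FUNCTION_ERROR` error condition; the plain `Exception` subclass and message formatting here are illustrative assumptions only):

```scala
import scala.util.control.NonFatal

// Illustrative stand-in for the real ForeachUserFuncException, which extends
// SparkException with errorClass = "FOREACH_USER_FUNCTION_ERROR".
final case class ForeachUserFuncException(cause: Throwable)
  extends Exception(
    "An error occurred in the user provided function in foreach sink. " +
      s"Reason: ${Option(cause.getMessage).getOrElse("")}",
    cause)

// Sketch of how the sink's write path guards the user-provided function:
// non-fatal, not-yet-classified errors are wrapped; everything else is rethrown.
def callUserFunction[T](record: T)(userProcess: T => Unit): Unit = {
  try {
    userProcess(record)
  } catch {
    // In the real code the guard is `NonFatal(e) && !e.isInstanceOf[SparkThrowable]`,
    // so already-classified Spark errors are not double-wrapped.
    case NonFatal(e) => throw ForeachUserFuncException(e)
  }
}
```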

### Why are the changes needed?

The user-provided function can throw any type of error. Using the new error framework gives better error messages and classification.

### Does this PR introduce _any_ user-facing change?

Yes, a better error message with an error class for Foreach sink user function failures.
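
For example (a rough sketch of the user-visible behavior, modeled on the updated `ForeachWriterSuite` test below; the local setup and the exact message formatting produced by the error framework are assumptions and may differ):

```scala
import org.apache.spark.sql.{ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.streaming.StreamingQueryException

object ForeachErrorDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("foreach-error-demo").getOrCreate()

    val query = spark.readStream.format("rate").load()
      .writeStream
      .foreach(new ForeachWriter[Row] {
        override def open(partitionId: Long, epochId: Long): Boolean = true
        override def process(row: Row): Unit =
          throw new RuntimeException("user function failed") // user code fails
        override def close(errorOrNull: Throwable): Unit = {}
      })
      .start()

    try {
      query.processAllAvailable()
    } catch {
      case e: StreamingQueryException =>
        // The failure is now classified under the new error condition...
        assert(e.getMessage.contains("FOREACH_USER_FUNCTION_ERROR"))
        // ...while the original user exception is still the cause, as before.
        assert(e.getCause.getMessage == "user function failed")
    } finally {
      spark.stop()
    }
  }
}
```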

### How was this patch tested?

Updated existing tests; they cover both Python and Scala.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47819 from jingz-db/classify-foreach-error.

Authored-by: jingz-db <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
jingz-db authored and HeartSaVioR committed Sep 5, 2024
1 parent d75f550 commit a6203cc
Showing 6 changed files with 65 additions and 11 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -1475,6 +1475,12 @@
],
"sqlState" : "39000"
},
"FOREACH_USER_FUNCTION_ERROR" : {
"message" : [
"An error occurred in the user provided function in foreach sink. Reason: <reason>"
],
"sqlState" : "39000"
},
"FOUND_MULTIPLE_DATA_SOURCES" : {
"message" : [
"Detected multiple data sources with the name '<provider>'. Please check the data source isn't simultaneously registered and located in the classpath."
7 changes: 4 additions & 3 deletions python/pyspark/sql/tests/streaming/test_streaming_foreach.py
@@ -220,9 +220,10 @@ def close(self, error):
try:
tester.run_streaming_query_on_writer(ForeachWriter(), 1)
self.fail("bad writer did not fail the query") # this is not expected
except StreamingQueryException:
# TODO: Verify whether original error message is inside the exception
pass
except StreamingQueryException as e:
err_msg = str(e)
self.assertTrue("test error" in err_msg)
self.assertTrue("FOREACH_USER_FUNCTION_ERROR" in err_msg)

self.assertEqual(len(tester.process_events()), 0) # no row was processed
close_events = tester.close_events()
@@ -21,12 +21,15 @@ import java.io.File
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock

import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
import scala.util.control.NonFatal

import org.apache.spark.{JobArtifactSet, SparkEnv, SparkThrowable, TaskContext}
import org.apache.spark.api.python._
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.sources.ForeachUserFuncException
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{NextIterator, Utils}
@@ -53,6 +56,8 @@ class WriterThread(outputIterator: Iterator[Array[Byte]])
} catch {
// Cache exceptions seen while evaluating the Python function on the streamed input. The
// parent thread will throw this crashed exception eventually.
case NonFatal(e) if !e.isInstanceOf[SparkThrowable] =>
_exception = ForeachUserFuncException(e)
case t: Throwable =>
_exception = t
}
@@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream}
import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write}
import org.apache.spark.sql.execution.command.StreamingExplainCommand
import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException
import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchUserFuncException, ForeachUserFuncException}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
import org.apache.spark.sql.streaming._
@@ -346,9 +346,11 @@
getLatestExecutionContext().updateStatusMessage("Stopped")
case e: Throwable =>
val message = if (e.getMessage == null) "" else e.getMessage
val cause = if (e.isInstanceOf[ForeachBatchUserFuncException]) {
val cause = if (e.isInstanceOf[ForeachBatchUserFuncException] ||
e.isInstanceOf[ForeachUserFuncException]) {
// We want to maintain the current way users get the causing exception
// from the StreamingQueryException. Hence the ForeachBatch exception is unwrapped here.
// from the StreamingQueryException.
// Hence the ForeachBatch/Foreach exception is unwrapped here.
e.getCause
} else {
e
@@ -728,6 +730,7 @@ object StreamExecution {
if e2.getCause != null =>
isInterruptionException(e2.getCause, sc)
case fe: ForeachBatchUserFuncException => isInterruptionException(fe.getCause, sc)
case fes: ForeachUserFuncException => isInterruptionException(fes.getCause, sc)
case se: SparkException =>
if (se.getCause == null) {
isCancelledJobGroup(se.getMessage)
@@ -19,6 +19,9 @@ package org.apache.spark.sql.execution.streaming.sources

import java.util

import scala.util.control.NonFatal

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.{ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -146,6 +149,9 @@ class ForeachDataWriter[T](
try {
writer.process(rowConverter(record))
} catch {
case NonFatal(e) if !e.isInstanceOf[SparkThrowable] =>
errorOrNull = e
throw ForeachUserFuncException(e)
case t: Throwable =>
errorOrNull = t
throw t
@@ -172,3 +178,12 @@ class ForeachDataWriter[T](
* An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination.
*/
case object ForeachWriterCommitMessage extends WriterCommitMessage

/**
* Exception that wraps the exception thrown in the user provided function in Foreach sink.
*/
private[sql] case class ForeachUserFuncException(cause: Throwable)
extends SparkException(
errorClass = "FOREACH_USER_FUNCTION_ERROR",
messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")),
cause = cause)
@@ -23,7 +23,7 @@ import scala.collection.mutable

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkException
import org.apache.spark.{ExecutorDeadException, SparkException}
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.{count, timestamp_seconds, window}
@@ -128,12 +128,14 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeAndAfter
testQuietly("foreach with error") {
withTempDir { checkpointDir =>
val input = MemoryStream[Int]

val funcEx = new RuntimeException("ForeachSinkSuite error")
val query = input.toDS().repartition(1).writeStream
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.foreach(new TestForeachWriter() {
override def process(value: Int): Unit = {
super.process(value)
throw new RuntimeException("ForeachSinkSuite error")
throw funcEx
}
}).start()
input.addData(1, 2, 3, 4)
@@ -142,8 +144,13 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeAndAfter
val e = intercept[StreamingQueryException] {
query.processAllAvailable()
}
assert(e.getCause.isInstanceOf[SparkException])
assert(e.getCause.getCause.getMessage === "ForeachSinkSuite error")

val errClass = "FOREACH_USER_FUNCTION_ERROR"

// verify that we classified the exception
assert(e.getMessage.contains(errClass))
assert(e.cause.asInstanceOf[RuntimeException].getMessage == funcEx.getMessage)

assert(query.isActive === false)

val allEvents = ForeachWriterSuite.allEvents()
@@ -157,6 +164,23 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeAndAfter
assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error")
// 'close' shouldn't be called with abort message if close with error has been called
assert(allEvents(0).size == 3)

val sparkEx = ExecutorDeadException("network error")
val e2 = intercept[StreamingQueryException] {
val query2 = input.toDS().repartition(1).writeStream
.foreach(new TestForeachWriter() {
override def process(value: Int): Unit = {
super.process(value)
throw sparkEx
}
}).start()
query2.processAllAvailable()
}

// we didn't wrap the spark exception
assert(!e2.getMessage.contains(errClass))
assert(e2.getCause.getCause.asInstanceOf[ExecutorDeadException].getMessage
== sparkEx.getMessage)
}
}

