diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index b42aae1311f43..469e247addf46 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1475,6 +1475,12 @@ ], "sqlState" : "39000" }, + "FOREACH_USER_FUNCTION_ERROR" : { + "message" : [ + "An error occurred in the user provided function in foreach sink. Reason: " + ], + "sqlState" : "39000" + }, "FOUND_MULTIPLE_DATA_SOURCES" : { "message" : [ "Detected multiple data sources with the name ''. Please check the data source isn't simultaneously registered and located in the classpath." diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py index 5041fefff1909..b29338e7f59e7 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py @@ -220,9 +220,10 @@ def close(self, error): try: tester.run_streaming_query_on_writer(ForeachWriter(), 1) self.fail("bad writer did not fail the query") # this is not expected - except StreamingQueryException: - # TODO: Verify whether original error message is inside the exception - pass + except StreamingQueryException as e: + err_msg = str(e) + self.assertTrue("test error" in err_msg) + self.assertTrue("FOREACH_USER_FUNCTION_ERROR" in err_msg) self.assertEqual(len(tester.process_events()), 0) # no row was processed close_events = tester.close_events() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index 67b264436fea9..ed7ff6a753487 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -21,12 +21,15 @@ import java.io.File import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock -import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} +import scala.util.control.NonFatal + +import org.apache.spark.{JobArtifactSet, SparkEnv, SparkThrowable, TaskContext} import org.apache.spark.api.python._ import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.sources.ForeachUserFuncException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, Utils} @@ -53,6 +56,8 @@ class WriterThread(outputIterator: Iterator[Array[Byte]]) } catch { // Cache exceptions seen while evaluating the Python function on the streamed input. The // parent thread will throw this crashed exception eventually. + case NonFatal(e) if !e.isInstanceOf[SparkThrowable] => + _exception = ForeachUserFuncException(e) case t: Throwable => _exception = t } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 6fd58e13366e0..d5aed9aea1820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write} import org.apache.spark.sql.execution.command.StreamingExplainCommand -import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException +import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchUserFuncException, ForeachUserFuncException} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend import org.apache.spark.sql.streaming._ @@ -346,9 +346,11 @@ abstract class StreamExecution( getLatestExecutionContext().updateStatusMessage("Stopped") case e: Throwable => val message = if (e.getMessage == null) "" else e.getMessage - val cause = if (e.isInstanceOf[ForeachBatchUserFuncException]) { + val cause = if (e.isInstanceOf[ForeachBatchUserFuncException] || + e.isInstanceOf[ForeachUserFuncException]) { // We want to maintain the current way users get the causing exception - // from the StreamingQueryException. Hence the ForeachBatch exception is unwrapped here. + // from the StreamingQueryException. + // Hence the ForeachBatch/Foreach exception is unwrapped here. e.getCause } else { e @@ -728,6 +730,7 @@ object StreamExecution { if e2.getCause != null => isInterruptionException(e2.getCause, sc) case fe: ForeachBatchUserFuncException => isInterruptionException(fe.getCause, sc) + case fes: ForeachUserFuncException => isInterruptionException(fes.getCause, sc) case se: SparkException => if (se.getCause == null) { isCancelledJobGroup(se.getMessage) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index bbbe28ec7ab11..c0956a62e59fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.execution.streaming.sources import java.util +import scala.util.control.NonFatal + +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -146,6 +149,9 @@ class ForeachDataWriter[T]( try { writer.process(rowConverter(record)) } catch { + case NonFatal(e) if !e.isInstanceOf[SparkThrowable] => + errorOrNull = e + throw ForeachUserFuncException(e) case t: Throwable => errorOrNull = t throw t @@ -172,3 +178,12 @@ class ForeachDataWriter[T]( * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. */ case object ForeachWriterCommitMessage extends WriterCommitMessage + +/** + * Exception that wraps the exception thrown in the user provided function in Foreach sink. + */ +private[sql] case class ForeachUserFuncException(cause: Throwable) + extends SparkException( + errorClass = "FOREACH_USER_FUNCTION_ERROR", + messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")), + cause = cause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 324717d92c972..4cf82cebeb812 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkException +import org.apache.spark.{ExecutorDeadException, SparkException} import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, timestamp_seconds, window} @@ -128,12 +128,14 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA testQuietly("foreach with error") { withTempDir { checkpointDir => val input = MemoryStream[Int] + + val funcEx = new RuntimeException("ForeachSinkSuite error") val query = input.toDS().repartition(1).writeStream .option("checkpointLocation", checkpointDir.getCanonicalPath) .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) - throw new RuntimeException("ForeachSinkSuite error") + throw funcEx } }).start() input.addData(1, 2, 3, 4) @@ -142,8 +144,13 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA val e = intercept[StreamingQueryException] { query.processAllAvailable() } - assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "ForeachSinkSuite error") + + val errClass = "FOREACH_USER_FUNCTION_ERROR" + + // verify that we classified the exception + assert(e.getMessage.contains(errClass)) + assert(e.cause.asInstanceOf[RuntimeException].getMessage == funcEx.getMessage) + assert(query.isActive === false) val allEvents = ForeachWriterSuite.allEvents() @@ -157,6 +164,23 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") // 'close' shouldn't be called with abort message if close with error has been called assert(allEvents(0).size == 3) + + val sparkEx = ExecutorDeadException("network error") + val e2 = intercept[StreamingQueryException] { + val query2 = input.toDS().repartition(1).writeStream + .foreach(new TestForeachWriter() { + override def process(value: Int): Unit = { + super.process(value) + throw sparkEx + } + }).start() + query2.processAllAvailable() + } + + // we didn't wrap the spark exception + assert(!e2.getMessage.contains(errClass)) + assert(e2.getCause.getCause.asInstanceOf[ExecutorDeadException].getMessage + == sparkEx.getMessage) } }