diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 0551e4b4a2ef5..d4ccced9ac9b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType +import org.apache.spark.TaskContext /** Used to identify the state store for a given operator. */ @@ -150,6 +151,13 @@ case class StateStoreSaveExec( val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") + // Abort the state store in case of error + TaskContext.get().addTaskCompletionListener(_ => { + if (!store.hasCommitted) { + store.abort() + } + }) + outputMode match { // Update and output all rows in the StateStore. case Some(Complete) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 4f3f8181d1f4e..1279b71c4d6ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -203,7 +203,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Whether all updates have been committed */ - override private[state] def hasCommitted: Boolean = { + override private[streaming] def hasCommitted: Boolean = { state == COMMITTED } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 9bc6c0e2b9334..d59746f947c1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -83,7 +83,7 @@ trait StateStore { /** * Whether all updates have been committed */ - private[state] def hasCommitted: Boolean + private[streaming] def hasCommitted: Boolean }