diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index d59746f947c1e..e61d95a1b1bb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.util.control.NonFatal
@@ -124,12 +125,46 @@ object StateStore extends Logging {
   val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval"
   val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
 
+  @GuardedBy("loadedProviders")
   private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
 
-  private val maintenanceTaskExecutor =
-    ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
-  @volatile private var maintenanceTask: ScheduledFuture[_] = null
-  @volatile private var _coordRef: StateStoreCoordinatorRef = null
+  /**
+   * Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
+   * will be called when an exception happens.
+   */
+  class MaintenanceTask(periodMs: Long, task: => Unit, onError: => Unit) {
+    private val executor =
+      ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
+
+    private val runnable = new Runnable {
+      override def run(): Unit = {
+        try {
+          task
+        } catch {
+          case NonFatal(e) =>
+            logWarning("Error running maintenance thread", e)
+            onError
+            throw e
+        }
+      }
+    }
+
+    private val future: ScheduledFuture[_] = executor.scheduleAtFixedRate(
+      runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
+
+    def stop(): Unit = {
+      future.cancel(false)
+      executor.shutdown()
+    }
+
+    def isRunning: Boolean = !future.isDone
+  }
+
+  @GuardedBy("loadedProviders")
+  private var maintenanceTask: MaintenanceTask = null
+
+  @GuardedBy("loadedProviders")
+  private var _coordRef: StateStoreCoordinatorRef = null
 
   /** Get or create a store associated with the id. */
   def get(
@@ -162,7 +197,7 @@ object StateStore extends Logging {
   }
 
   def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
-    maintenanceTask != null
+    maintenanceTask != null && maintenanceTask.isRunning
   }
 
   /** Unload and stop all state store providers */
@@ -170,7 +205,7 @@
     loadedProviders.clear()
     _coordRef = null
     if (maintenanceTask != null) {
-      maintenanceTask.cancel(false)
+      maintenanceTask.stop()
       maintenanceTask = null
     }
     logInfo("StateStore stopped")
@@ -179,14 +214,14 @@
   /** Start the periodic maintenance task if not already started and if Spark active */
   private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized {
     val env = SparkEnv.get
-    if (maintenanceTask == null && env != null) {
+    if (env != null && !isMaintenanceRunning) {
       val periodMs = env.conf.getTimeAsMs(
         MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s")
-      val runnable = new Runnable {
-        override def run(): Unit = { doMaintenance() }
-      }
-      maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate(
-        runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
+      maintenanceTask = new MaintenanceTask(
+        periodMs,
+        task = { doMaintenance() },
+        onError = { loadedProviders.synchronized { loadedProviders.clear() } }
+      )
       logInfo("State Store maintenance task started")
     }
   }
@@ -198,21 +233,20 @@
   private def doMaintenance(): Unit = {
     logDebug("Doing maintenance")
     if (SparkEnv.get == null) {
-      stop()
-    } else {
-      loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
-        try {
-          if (verifyIfStoreInstanceActive(id)) {
-            provider.doMaintenance()
-          } else {
-            unload(id)
-            logInfo(s"Unloaded $provider")
-          }
-        } catch {
-          case NonFatal(e) =>
-            logWarning(s"Error managing $provider, stopping management thread")
-            stop()
+      throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores")
+    }
+    loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
+      try {
+        if (verifyIfStoreInstanceActive(id)) {
+          provider.doMaintenance()
+        } else {
+          unload(id)
+          logInfo(s"Unloaded $provider")
         }
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Error managing $provider, stopping management thread")
+          throw e
       }
     }
   }
@@ -238,7 +272,7 @@ object StateStore extends Logging {
     }
   }
 
-  private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized {
+  private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
     val env = SparkEnv.get
     if (env != null) {
       if (_coordRef == null) {
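
For context, the new MaintenanceTask leans on standard ScheduledExecutorService semantics: when a scheduled run throws, the executor suppresses all subsequent runs and completes the ScheduledFuture, so isRunning (!future.isDone) flips to false without any explicit bookkeeping. Below is a minimal standalone sketch of that pattern, assuming only the JDK scheduler; the SelfCancellingTask and Demo names are illustrative and not part of this patch.

import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}
import scala.util.control.NonFatal

// Sketch only: a periodic task that cancels itself when the body throws.
// Mirrors the MaintenanceTask pattern in the patch above, outside Spark.
class SelfCancellingTask(periodMs: Long)(task: => Unit)(onError: => Unit) {
  private val executor = Executors.newSingleThreadScheduledExecutor()

  private val runnable = new Runnable {
    override def run(): Unit = {
      try task catch {
        case NonFatal(e) =>
          onError
          // Rethrowing makes the ScheduledExecutorService suppress further
          // runs and complete the future, so isRunning becomes false.
          throw e
      }
    }
  }

  private val future: ScheduledFuture[_] =
    executor.scheduleAtFixedRate(runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)

  def stop(): Unit = { future.cancel(false); executor.shutdown() }
  def isRunning: Boolean = !future.isDone
}

object Demo {
  def main(args: Array[String]): Unit = {
    var runs = 0
    // Fail on the third run (~300ms in) to trigger the auto-cancel path.
    val t = new SelfCancellingTask(100)({ runs += 1; if (runs == 3) sys.error("boom") })(
      println("onError: cleaning up"))
    Thread.sleep(600)
    println(s"isRunning after failure: ${t.isRunning}") // false: task auto-cancelled
    t.stop()
  }
}

Running the demo prints "onError: cleaning up" once and then "isRunning after failure: false". This is the same mechanism the patch relies on: onError clears loadedProviders, isMaintenanceRunning reports false once the future is done, and the revised startMaintenanceIfNeeded condition (env != null && !isMaintenanceRunning) lets a fresh MaintenanceTask be scheduled on the next trigger.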