From cf24fad2da59ae338db677349a44bcefaf1adb16 Mon Sep 17 00:00:00 2001
From: Liwei Lin
Date: Mon, 16 Jan 2017 17:48:21 +0800
Subject: [PATCH 1/2] Abort StateStore on error

---
 .../streaming/StatefulAggregate.scala         |  8 ++
 .../state/HDFSBackedStateStoreProvider.scala  |  2 +-
 .../streaming/state/StateStore.scala          |  2 +-
 .../streaming/state/StateStoreConf.scala      |  4 +-
 .../streaming/StreamingAggregationSuite.scala | 76 ++++++++++++++++++-
 5 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
index 0551e4b4a2ef5..d4ccced9ac9b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.TaskContext
 
 /** Used to identify the state store for a given operator. */
 
@@ -150,6 +151,13 @@ case class StateStoreSaveExec(
     val numTotalStateRows = longMetric("numTotalStateRows")
     val numUpdatedStateRows = longMetric("numUpdatedStateRows")
 
+    // Abort the state store in case of error
+    TaskContext.get().addTaskCompletionListener(_ => {
+      if (!store.hasCommitted) {
+        store.abort()
+      }
+    })
+
     outputMode match {
       // Update and output all rows in the StateStore.
       case Some(Complete) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 4f3f8181d1f4e..1279b71c4d6ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -203,7 +203,7 @@ private[state] class HDFSBackedStateStoreProvider(
 
   /**
    * Whether all updates have been committed
    */
-  override private[state] def hasCommitted: Boolean = {
+  override private[streaming] def hasCommitted: Boolean = {
     state == COMMITTED
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 9bc6c0e2b9334..d59746f947c1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -83,7 +83,7 @@ trait StateStore {
 
   /**
    * Whether all updates have been committed
    */
-  private[state] def hasCommitted: Boolean
+  private[streaming] def hasCommitted: Boolean
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index acfaa8e5eb3c4..66457e9f3aacb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import org.apache.spark.sql.internal.SQLConf
 
 /** A class that contains configuration parameters for [[StateStore]]s. */
-private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+private[sql] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
 
   def this() = this(new SQLConf)
 
@@ -29,7 +29,7 @@ private[streaming] class StateStoreConf(@transient private val conf: SQLConf) ex
   val minVersionsToRetain = conf.minBatchesToRetain
 }
 
-private[streaming] object StateStoreConf {
+private[sql] object StateStoreConf {
   val empty = new StateStoreConf()
 
   def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index eca2647dea52b..4aeed6df6b10c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -19,17 +19,26 @@ package org.apache.spark.sql.streaming
 
 import java.util.TimeZone
 
+import scala.collection.mutable
+import scala.reflect.runtime.{universe => ru}
+
+import org.apache.hadoop.conf.Configuration
+import org.mockito.Mockito
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
 import org.scalatest.BeforeAndAfterAll
+import org.scalatest.PrivateMethodTester._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.OutputMode._
+import org.apache.spark.sql.types._
 
 object FailureSinglton {
   var firstTime = true
@@ -335,4 +344,67 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
       CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
     )
   }
+
+  test("abort StateStore in case of error") {
+    quietly {
+      val inputData = MemoryStream[Long]
+      val aggregated =
+        inputData.toDS()
+          .groupBy($"value")
+          .agg(count("*"))
+      var aborted = false
+      testStream(aggregated, Complete)(
+        // This whole `AssertOnQuery` is used to inject a mock state store
+        AssertOnQuery(execution => {
+          // (1) Use reflection to get `StateStore.loadedProviders`
+          val loadedProviders = {
+            val field = ru.typeOf[StateStore.type].decl(ru.TermName("loadedProviders")).asTerm
+            ru.runtimeMirror(StateStore.getClass.getClassLoader)
+              .reflect(StateStore)
+              .reflectField(field)
+              .get
+              .asInstanceOf[mutable.HashMap[StateStoreId, StateStoreProvider]]
+          }
+          // (2) Make a storeId
+          val storeId = {
+            val checkpointLocation =
+              execution invokePrivate PrivateMethod[String]('checkpointFile)("state")
+            StateStoreId(checkpointLocation, 0L, 0)
+          }
+          // (3) Make `mockStore` and `mockProvider`
+          val (mockStore, mockProvider) = {
+            val keySchema = StructType(Seq(
+              StructField("value", LongType, false)))
+            val valueSchema = StructType(Seq(
+              StructField("value", LongType, false), StructField("count", LongType, false)))
+            val storeConf = StateStoreConf.empty
+            val hadoopConf = new Configuration
+            (Mockito.spy(
+              StateStore.get(storeId, keySchema, valueSchema, version = 0, storeConf, hadoopConf)),
+              Mockito.spy(loadedProviders.get(storeId).get))
+          }
+          // (4) Setup `mockStore` and `mockProvider`
+          Mockito.doAnswer(new Answer[Long] {
+            override def answer(invocationOnMock: InvocationOnMock): Long = {
+              sys.error("injected error on commit()")
+            }
+          }).when(mockStore).commit()
+          Mockito.doAnswer(new Answer[Unit] {
+            override def answer(invocationOnMock: InvocationOnMock): Unit = {
+              invocationOnMock.callRealMethod()
+              // Mark the flag for later check
+              aborted = true
+            }
+          }).when(mockStore).abort()
+          Mockito.doReturn(mockStore).when(mockProvider).getStore(version = 0)
+          // (5) Inject `mockProvider`, which later on would inject `mockStore`
+          loadedProviders.put(storeId, mockProvider)
+          true
+        }), // End of AssertOnQuery, i.e. end of injecting `mockStore`
+        AddData(inputData, 1L, 2L, 3L),
+        ExpectFailure[SparkException](),
+        AssertOnQuery { _ => aborted } // Check that `mockStore.abort()` is called upon error
+      )
+    }
+  }
 }
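The heart of patch 1 is the task-completion-listener pattern in StateStoreSaveExec above: the hook is registered before any state is touched, and it guards on hasCommitted so that a successfully committed store is never rolled back. A listener is used instead of a try/catch because the operator hands back an iterator that downstream operators consume after this code has returned; only a completion listener still fires when that later consumption fails. The sketch below restates the idea in isolation. withAbortOnFailure and AbortOnFailureSketch are hypothetical names, not part of the patch, and because hasCommitted is private[streaming], such a helper would have to be compiled under org.apache.spark.sql.execution.streaming, as assumed here.

package org.apache.spark.sql.execution.streaming

import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.streaming.state.StateStore

// Hypothetical helper, not part of the patch: wraps per-task work so that an
// uncommitted store is aborted when the task completes for any reason.
// Assumes it runs inside a task, where TaskContext.get() is non-null.
object AbortOnFailureSketch {
  def withAbortOnFailure[T](store: StateStore)(body: => T): T = {
    // Completion listeners fire on success and failure alike; the
    // hasCommitted guard keeps a committed store from being rolled back.
    TaskContext.get().addTaskCompletionListener { _ =>
      if (!store.hasCommitted) {
        store.abort()
      }
    }
    body
  }
}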
From 0f9e54d9efe4c9d7f446cb2f4dc46741cef776f7 Mon Sep 17 00:00:00 2001
From: Liwei Lin
Date: Wed, 18 Jan 2017 11:36:04 +0800
Subject: [PATCH 2/2] Remove test

---
 .../streaming/state/StateStoreConf.scala      |  4 +-
 .../streaming/StreamingAggregationSuite.scala | 76 +------------------
 2 files changed, 4 insertions(+), 76 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 66457e9f3aacb..acfaa8e5eb3c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import org.apache.spark.sql.internal.SQLConf
 
 /** A class that contains configuration parameters for [[StateStore]]s. */
-private[sql] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
 
   def this() = this(new SQLConf)
 
@@ -29,7 +29,7 @@ private[sql] class StateStoreConf(@transient private val conf: SQLConf) extends
   val minVersionsToRetain = conf.minBatchesToRetain
 }
 
-private[sql] object StateStoreConf {
+private[streaming] object StateStoreConf {
   val empty = new StateStoreConf()
 
   def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 4aeed6df6b10c..eca2647dea52b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -19,26 +19,17 @@ package org.apache.spark.sql.streaming
 
 import java.util.TimeZone
 
-import scala.collection.mutable
-import scala.reflect.runtime.{universe => ru}
-
-import org.apache.hadoop.conf.Configuration
-import org.mockito.Mockito
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
 import org.scalatest.BeforeAndAfterAll
-import org.scalatest.PrivateMethodTester._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.OutputMode._
-import org.apache.spark.sql.types._
 
 object FailureSinglton {
   var firstTime = true
@@ -344,67 +335,4 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
       CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
     )
   }
-
-  test("abort StateStore in case of error") {
-    quietly {
-      val inputData = MemoryStream[Long]
-      val aggregated =
-        inputData.toDS()
-          .groupBy($"value")
-          .agg(count("*"))
-      var aborted = false
-      testStream(aggregated, Complete)(
-        // This whole `AssertOnQuery` is used to inject a mock state store
-        AssertOnQuery(execution => {
-          // (1) Use reflection to get `StateStore.loadedProviders`
-          val loadedProviders = {
-            val field = ru.typeOf[StateStore.type].decl(ru.TermName("loadedProviders")).asTerm
-            ru.runtimeMirror(StateStore.getClass.getClassLoader)
-              .reflect(StateStore)
-              .reflectField(field)
-              .get
-              .asInstanceOf[mutable.HashMap[StateStoreId, StateStoreProvider]]
-          }
-          // (2) Make a storeId
-          val storeId = {
-            val checkpointLocation =
-              execution invokePrivate PrivateMethod[String]('checkpointFile)("state")
-            StateStoreId(checkpointLocation, 0L, 0)
-          }
-          // (3) Make `mockStore` and `mockProvider`
-          val (mockStore, mockProvider) = {
-            val keySchema = StructType(Seq(
-              StructField("value", LongType, false)))
-            val valueSchema = StructType(Seq(
-              StructField("value", LongType, false), StructField("count", LongType, false)))
-            val storeConf = StateStoreConf.empty
-            val hadoopConf = new Configuration
-            (Mockito.spy(
-              StateStore.get(storeId, keySchema, valueSchema, version = 0, storeConf, hadoopConf)),
-              Mockito.spy(loadedProviders.get(storeId).get))
-          }
-          // (4) Setup `mockStore` and `mockProvider`
-          Mockito.doAnswer(new Answer[Long] {
-            override def answer(invocationOnMock: InvocationOnMock): Long = {
-              sys.error("injected error on commit()")
-            }
-          }).when(mockStore).commit()
-          Mockito.doAnswer(new Answer[Unit] {
-            override def answer(invocationOnMock: InvocationOnMock): Unit = {
-              invocationOnMock.callRealMethod()
-              // Mark the flag for later check
-              aborted = true
-            }
-          }).when(mockStore).abort()
-          Mockito.doReturn(mockStore).when(mockProvider).getStore(version = 0)
-          // (5) Inject `mockProvider`, which later on would inject `mockStore`
-          loadedProviders.put(storeId, mockProvider)
-          true
-        }), // End of AssertOnQuery, i.e. end of injecting `mockStore`
-        AddData(inputData, 1L, 2L, 3L),
-        ExpectFailure[SparkException](),
-        AssertOnQuery { _ => aborted } // Check that `mockStore.abort()` is called upon error
-      )
-    }
-  }
 }
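Patch 2 backs out the test and the StateStoreConf visibility widening that existed only to support it; the production-side listener from patch 1 remains. The reusable part of the removed test is its fault-injection idiom: spy on a real object, force commit() to throw, and let abort() run for real while recording that it was called. Below is a self-contained sketch of that idiom, assuming Mockito on the classpath; DemoStore and FaultInjectionSketch are hypothetical stand-ins, not the StateStore API.

import org.mockito.Mockito
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

// Hypothetical stand-in for the HDFS-backed store; only the method shapes matter.
class DemoStore {
  def commit(): Long = 1L
  def abort(): Unit = ()
}

object FaultInjectionSketch extends App {
  var aborted = false
  val mockStore = Mockito.spy(new DemoStore)

  // Force every commit() to throw, as the removed test did to StateStore.commit().
  Mockito.doAnswer(new Answer[Long] {
    override def answer(invocation: InvocationOnMock): Long =
      sys.error("injected error on commit()")
  }).when(mockStore).commit()

  // Let abort() run for real, but record that it was called.
  Mockito.doAnswer(new Answer[Unit] {
    override def answer(invocation: InvocationOnMock): Unit = {
      invocation.callRealMethod()
      aborted = true
    }
  }).when(mockStore).abort()

  try mockStore.commit() catch { case _: RuntimeException => mockStore.abort() }
  assert(aborted, "abort() should have run after the injected commit() failure")
}

The final assert mirrors the AssertOnQuery { _ => aborted } check from the removed test: the observable contract is that a failed commit is followed by an abort.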