From 8d1cb76066e45bf24952124d3edc4357303067e5 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 16 Oct 2024 14:57:31 +0200 Subject: [PATCH 001/108] [SPARK-49987][SQL] Fix the error prompt when `seedExpression` is non-foldable in `randstr` ### What changes were proposed in this pull request? The pr aims to - fix the `error prompt` when `seedExpression` is `non-foldable` in `randstr`. - use `toSQLId` to set the parameter value `inputName` for `randstr ` and `uniform` of `NON_FOLDABLE_INPUT`. ### Why are the changes needed? - Let me take an example ```scala val df = Seq(1.1).toDF("a") df.createOrReplaceTempView("t") sql("SELECT randstr(1, a) from t").show(false) ``` - Before image ```shell [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "randstr(1, a)" due to data type mismatch: the input seedExpression should be a foldable INT or SMALLINT expression; however, got "a". SQLSTATE: 42K09; line 1 pos 7; 'Project [unresolvedalias(randstr(1, a#5, false))] +- SubqueryAlias t +- View (`t`, [a#5]) +- Project [value#1 AS a#5] +- LocalRelation [value#1] ``` - After ```shell [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "randstr(1, a)" due to data type mismatch: the input seed should be a foldable INT or SMALLINT expression; however, got "a". SQLSTATE: 42K09; line 1 pos 7; 'Project [unresolvedalias(randstr(1, a#5, false))] +- SubqueryAlias t +- View (`t`, [a#5]) +- Project [value#1 AS a#5] +- LocalRelation [value#1] ``` - The `parameter` name (`seedExpression`) in the error message does not match the `parameter` name (`seed`) seen in docs by the end-user. image ### Does this PR introduce _any_ user-facing change? Yes, When `seed` is `non-foldable `, the end-user will get a consistent experience in the error prompt. ### How was this patch tested? Update existed UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48490 from panbingkun/SPARK-49987. 
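For reference, a minimal spark-shell sketch of both paths discussed above; the view name `t` and column `a` follow the example in this description, the literal length/seed values (5 and 0) are arbitrary, and the exact error text may vary by build:

```scala
val df = Seq(1.1).toDF("a")
df.createOrReplaceTempView("t")

// Non-foldable seed: the second argument references a column, so analysis fails
// with DATATYPE_MISMATCH.NON_FOLDABLE_INPUT, now reporting the input name as `seed`.
// sql("SELECT randstr(1, a) FROM t").show(false)

// Foldable seed: a literal INT is accepted and a length-5 random string is returned.
sql("SELECT randstr(5, 0) AS s FROM t").show(false)
```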
Authored-by: panbingkun Signed-off-by: Max Gekk --- .../sql/catalyst/expressions/randomExpressions.scala | 8 ++++---- .../resources/sql-tests/analyzer-results/random.sql.out | 8 ++++---- .../src/test/resources/sql-tests/results/random.sql.out | 8 ++++---- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 3cec83facd01d..16bdaa1f7f708 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} +import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} @@ -263,7 +263,7 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression, result = DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> name, + "inputName" -> toSQLId(name), "inputType" -> requiredType, "inputExpr" -> toSQLExpr(expr))) } else expr.dataType match { @@ -374,14 +374,14 @@ case class RandStr( var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess def requiredType = "INT or SMALLINT" Seq((length, "length", 0), - (seedExpression, "seedExpression", 1)).foreach { + (seedExpression, "seed", 1)).foreach { case (expr: Expression, name: String, index: Int) => if (result == TypeCheckResult.TypeCheckSuccess) { if (!expr.foldable) { result = DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> name, + "inputName" -> toSQLId(name), "inputType" -> requiredType, "inputExpr" -> toSQLExpr(expr))) } else expr.dataType match { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 133cd6a60a4fb..31919381c99b6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -188,7 +188,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seed", + "inputName" : "`seed`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(10, 20, col)\"" }, @@ -211,7 +211,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "min", + "inputName" : "`min`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(col, 10, 0)\"" }, @@ -436,7 +436,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", 
"messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "length", + "inputName" : "`length`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, @@ -459,7 +459,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seedExpression", + "inputName" : "`seed`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 0b4e5e078ee15..01638abdcec6e 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -240,7 +240,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seed", + "inputName" : "`seed`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(10, 20, col)\"" }, @@ -265,7 +265,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "min", + "inputName" : "`min`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(col, 10, 0)\"" }, @@ -520,7 +520,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "length", + "inputName" : "`length`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, @@ -545,7 +545,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seedExpression", + "inputName" : "`seed`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 47691e1ccd40f..39c839ae5a518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -478,7 +478,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "length", + "inputName" -> "`length`", "inputType" -> "INT or SMALLINT", "inputExpr" -> "\"a\"", "sqlExpr" -> "\"randstr(a, 10)\""), @@ -530,7 +530,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "min", + "inputName" -> "`min`", "inputType" -> "integer or floating-point", "inputExpr" -> "\"a\"", "sqlExpr" -> "\"uniform(a, 10)\""), From a3b91247b32083805fdd50e9f7f46e9a91b8fd8d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 16 Oct 2024 07:22:59 -0700 Subject: [PATCH 002/108] [SPARK-49981][CORE][TESTS] Fix `AsyncRDDActionsSuite.FutureAction result, timeout` test case to be robust ### What changes were proposed in this pull request? This PR aims to fix `AsyncRDDActionsSuite.FutureAction result, timeout` test case to be robust. ### Why are the changes needed? To reduce the flakiness in GitHub Action CI. Previously, the sleep time is identical to the timeout time. 
It causes a flakiness in some environments like GitHub Action. - https://github.com/apache/spark/actions/runs/11298639789/job/31428018075 ``` AsyncRDDActionsSuite: ... - FutureAction result, timeout *** FAILED *** Expected exception java.util.concurrent.TimeoutException to be thrown, but no exception was thrown (AsyncRDDActionsSuite.scala:206) ``` ### Does this PR introduce _any_ user-facing change? No, this is a test-only change. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48485 from dongjoon-hyun/SPARK-49981. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 4239180ba6c37..fb2bb83cb7fc4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -201,10 +201,10 @@ class AsyncRDDActionsSuite extends SparkFunSuite with TimeLimits { test("FutureAction result, timeout") { val f = sc.parallelize(1 to 100, 4) - .mapPartitions(itr => { Thread.sleep(20); itr }) + .mapPartitions(itr => { Thread.sleep(200); itr }) .countAsync() intercept[TimeoutException] { - ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) + ThreadUtils.awaitResult(f, Duration(2, "milliseconds")) } } From bcfe62b9988f9b00c23de0b71acc1c6170edee9e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 16 Oct 2024 07:24:33 -0700 Subject: [PATCH 003/108] [SPARK-49983][CORE][TESTS] Fix `BarrierTaskContextSuite.successively sync with allGather and barrier` test case to be robust ### What changes were proposed in this pull request? This PR aims to fix `BarrierTaskContextSuite.successively sync with allGather and barrier` test case to be robust. ### Why are the changes needed? The test case asserts the duration of partitions. However, this is flaky because we don't know when a partition is triggered before `barrier` sync. https://github.com/apache/spark/blob/0e75d19a736aa18fe77414991ebb7e3577a43af8/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala#L116-L118 Although we added `TestUtils.waitUntilExecutorsUp` at Apache Spark 3.0.0 like the following, - #28658 let's say a partition starts slowly than `38ms` and all partitions sleep `1s` exactly. Then, the test case fails like the following. - https://github.com/apache/spark/actions/runs/11298639789/job/31428018075 ``` BarrierTaskContextSuite: ... - successively sync with allGather and barrier *** FAILED *** 1038 was not less than or equal to 1000 (BarrierTaskContextSuite.scala:118) ``` According to the failure history here (SPARK-49983) and SPARK-31730, the slowness seems to be less than `200ms` when it happens. So, this PR aims to reduce the flakiness by capping the sleep up to 500ms while keeping the `1s` validation. There is no test coverage change because this test case focuses on the `successively sync with allGather and battier`. ### Does this PR introduce _any_ user-facing change? No, this is a test-only test case. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48487 from dongjoon-hyun/SPARK-49983. 
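Both SPARK-49981 and SPARK-49983 apply the same rule of thumb: keep a wide margin between the simulated work time and the bound the test asserts on. A framework-free sketch of that pattern, with illustrative durations only:

```scala
import java.util.concurrent.TimeoutException
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// The simulated work (200 ms) far exceeds the await timeout (2 ms), so the
// TimeoutException is deterministic rather than dependent on scheduler timing.
val slowResult = Future { Thread.sleep(200); 42 }
try {
  Await.result(slowResult, 2.milliseconds)
} catch {
  case _: TimeoutException => println("timed out as expected")
}
```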
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/scheduler/BarrierTaskContextSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 849832c57edaa..f00fb0d2cfa3f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -101,7 +101,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) + Thread.sleep(Random.nextInt(500)) context.barrier() val time1 = System.currentTimeMillis() // Sleep for a random time before global sync. From 60200ae195a124003cf77d4ab3872f1652b6b9c7 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Wed, 16 Oct 2024 18:48:51 +0200 Subject: [PATCH 004/108] [SPARK-49957][SQL] Scala API for string validation functions ### What changes were proposed in this pull request? Adding the Scala API for the 4 new string validation expressions: - is_valid_utf8 - make_valid_utf8 - validate_utf8 - try_validate_utf8 ### Why are the changes needed? Offer a complete Scala API for the new expressions in Spark 4.0. ### Does this PR introduce _any_ user-facing change? Yes, adding Scala API for the 4 new Spark expressions. ### How was this patch tested? New tests for the Scala API. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48454 from uros-db/api-validation. Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- python/pyspark/sql/tests/test_functions.py | 4 +- .../org/apache/spark/sql/functions.scala | 38 +++++++++++++++++++ .../spark/sql/StringFunctionsSuite.scala | 38 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index a51156e895c62..f6c1278c0dc7a 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -83,7 +83,9 @@ def test_function_parity(self): missing_in_py = jvm_fn_set.difference(py_fn_set) # Functions that we expect to be missing in python until they are added to pyspark - expected_missing_in_py = set() + expected_missing_in_py = set( + ["is_valid_utf8", "make_valid_utf8", "validate_utf8", "try_validate_utf8"] + ) self.assertEqual( expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 4838bc5298bb3..4a9a20efd3a56 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -3911,6 +3911,44 @@ object functions { def encode(value: Column, charset: String): Column = Column.fn("encode", value, lit(charset)) + /** + * Returns true if the input is a valid UTF-8 string, otherwise returns false. + * + * @group string_funcs + * @since 4.0.0 + */ + def is_valid_utf8(str: Column): Column = + Column.fn("is_valid_utf8", str) + + /** + * Returns a new string in which all invalid UTF-8 byte sequences, if any, are replaced by the + * Unicode replacement character (U+FFFD). 
+ * + * @group string_funcs + * @since 4.0.0 + */ + def make_valid_utf8(str: Column): Column = + Column.fn("make_valid_utf8", str) + + /** + * Returns the input value if it corresponds to a valid UTF-8 string, or emits a + * SparkIllegalArgumentException exception otherwise. + * + * @group string_funcs + * @since 4.0.0 + */ + def validate_utf8(str: Column): Column = + Column.fn("validate_utf8", str) + + /** + * Returns the input value if it corresponds to a valid UTF-8 string, or NULL otherwise. + * + * @group string_funcs + * @since 4.0.0 + */ + def try_validate_utf8(str: Column): Column = + Column.fn("try_validate_utf8", str) + /** * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places with * HALF_EVEN round mode, and returns the result as a string column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ec240d71b851f..c94f57a11426a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -352,6 +352,44 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { // scalastyle:on } + test("UTF-8 string is valid") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(is_valid_utf8($"a")), Row(true)) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(is_valid_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(is_valid_utf8($"a")), Row(false)) + // scalastyle:on + } + + test("UTF-8 string make valid") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(make_valid_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(make_valid_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(make_valid_utf8($"a")), Row("\uFFFD")) + // scalastyle:on + } + + test("UTF-8 string validate") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(validate_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(validate_utf8($"b")), Row(null)) + checkError( + exception = intercept[SparkIllegalArgumentException] { + Seq(Array[Byte](-1)).toDF("a").select(validate_utf8($"a")).collect() + }, + condition = "INVALID_UTF8_STRING", + parameters = Map("str" -> "\\xFF") + ) + // scalastyle:on + } + + test("UTF-8 string try validate") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(try_validate_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(try_validate_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(try_validate_utf8($"a")), Row(null)) + // scalastyle:on + } + test("string translate") { val df = Seq(("translate", "")).toDF("a", "b") checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae")) From f860af67db34c9ae68076a867d4d61caf574cbb8 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Thu, 17 Oct 2024 01:23:16 +0800 Subject: [PATCH 005/108] [SPARK-48155][FOLLOWUP][SQL] AQEPropagateEmptyRelation for left anti join should check if remain child is just BroadcastQueryStageExec ### What changes were proposed in this pull request? As title. ### Why are the changes needed? We encountered BroadcastNestedLoopJoin LeftAnti BuildLeft, and it's right is empty. It is left child of left outer BroadcastHashJoin. 
The case is more complicated, part of the Initial Plan is as follows ``` :- Project (214) : +- BroadcastHashJoin LeftOuter BuildRight (213) : :- BroadcastNestedLoopJoin LeftAnti BuildLeft (211) : : :- BroadcastExchange (187) : : : +- Project (186) : : : +- Filter (185) : : : +- Scan parquet (31) : : +- LocalLimit (210) : : +- Project (209) : : +- BroadcastHashJoin Inner BuildLeft (208) : : :- BroadcastExchange (194) : : : +- Project (193) : : : +- BroadcastHashJoin LeftOuter BuildRight (192) : : : :- Project (189) : : : : +- Filter (188) : : : : +- Scan parquet (37) : : : +- BroadcastExchange (191) ``` After AQEPropagateEmptyRelation, report an error "HashJoin should not take LeftOuter as the JoinType with building left side" ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48300 from zml1206/SPARK-48155-followup. Authored-by: zml1206 Signed-off-by: Wenchen Fan --- .../optimizer/PropagateEmptyRelation.scala | 3 +- .../adaptive/AdaptiveQueryExecSuite.scala | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 832af340c3397..d23d43acc217b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -111,7 +111,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. 
case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p) case LeftSemi if isRightEmpty | isFalseCondition => empty(p) - case LeftAnti if isRightEmpty | isFalseCondition => p.left + case LeftAnti if (isRightEmpty | isFalseCondition) && canExecuteWithoutJoin(p.left) => + p.left case FullOuter if isLeftEmpty && isRightEmpty => empty(p) case LeftOuter | FullOuter if isRightEmpty && canExecuteWithoutJoin(p.left) => Project(p.left.output ++ nullValueProjectList(p.right), p.left) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index c5e64c96b2c8a..4bf993f82495b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -2829,6 +2829,38 @@ class AdaptiveQueryExecSuite assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) assert(findTopLevelUnion(adaptivePlan).size == 0) } + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF().createOrReplaceTempView("t1") + spark.range(100).createOrReplaceTempView("t2") + spark.range(2).createOrReplaceTempView("t3") + spark.range(2).createOrReplaceTempView("t4") + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT tt2.value + |FROM ( + | SELECT value + | FROM t1 + | WHERE NOT EXISTS ( + | SELECT 1 + | FROM ( + | SELECT t2.id + | FROM t2 + | JOIN t3 ON t2.id = t3.id + | AND t2.id > 100 + | ) tt + | WHERE t1.value = tt.id + | ) + | AND t1.value = 1 + |) tt2 + | LEFT JOIN t4 ON tt2.value = t4.id + |""".stripMargin + ) + assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) + } + } } test("SPARK-39915: Dataset.repartition(N) may not create N partitions") { From 31a411773a3e97adb833289f8c695b37802cfedb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 16 Oct 2024 11:27:35 -0700 Subject: [PATCH 006/108] [SPARK-49057][SQL][TESTS][FOLLOWUP] Handle `_LEGACY_ERROR_TEMP_2235` error case ### What changes were proposed in this pull request? This PR aims to fix a flaky test by handling `_LEGACY_ERROR_TEMP_2235`(multiple failures exception) in addition to the single exception. ### Why are the changes needed? After merging - #47533 The following failures were reported multiple times in the PR and today. - https://github.com/apache/spark/actions/runs/11358629880/job/31593568476 - https://github.com/apache/spark/actions/runs/11367718498/job/31621128680 - https://github.com/apache/spark/actions/runs/11360602982/job/31598792247 ``` [info] - SPARK-47148: AQE should avoid to submit shuffle job on cancellation *** FAILED *** (6 seconds, 92 milliseconds) [info] "Multiple failures in stage materialization." did not contain "coalesce test error" (AdaptiveQueryExecSuite.scala:939) ``` The root cause is that `AdaptiveSparkPlanExec.cleanUpAndThrowException` throws two types of exceptions. When there are multiple errors, `_LEGACY_ERROR_TEMP_2235` is thrown. We need to handle this too in the test case. 
https://github.com/apache/spark/blob/bcfe62b9988f9b00c23de0b71acc1c6170edee9e/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala#L843-L850 https://github.com/apache/spark/blob/bcfe62b9988f9b00c23de0b71acc1c6170edee9e/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala#L1916-L1921 ### Does this PR introduce _any_ user-facing change? No, this is a test-only change. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48498 from dongjoon-hyun/SPARK-49057. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 4bf993f82495b..8e9ba6c8e21d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -936,7 +936,8 @@ class AdaptiveQueryExecSuite val error = intercept[SparkException] { joined.collect() } - assert(error.getMessage() contains "coalesce test error") + assert((Seq(error) ++ Option(error.getCause) ++ error.getSuppressed()).exists( + e => e.getMessage() != null && e.getMessage().contains("coalesce test error"))) val adaptivePlan = joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec] From f5e6b05e486efb3d67fd06166ca8f103efb750dc Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 16 Oct 2024 23:24:39 +0200 Subject: [PATCH 007/108] [SPARK-49643][SQL] Merge _LEGACY_ERROR_TEMP_2042 into ARITHMETIC_OVERFLOW ### What changes were proposed in this pull request? Merging related legacy error to its proper class. ### Why are the changes needed? We want to get remove legacy errors, as they are not properly migrated to the new system of errors. Also, [PR](https://github.com/apache/spark/pull/48206/files#diff-0ffd087e0d4e1618761a42c91b8712fd469e758f4789ca2fafdefff753fe81d5) started getting to big, so this is an effort to split the change needed. ### Does this PR introduce _any_ user-facing change? Yes, legacy error is now merged into ARITHMETIC_OVERFLOW. ### How was this patch tested? Existing tests check that the error message stayed the same. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48496 from mihailom-db/error2042. Authored-by: Mihailo Milosevic Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 5 ----- .../sql/catalyst/expressions/intervalExpressions.scala | 2 +- .../apache/spark/sql/errors/QueryExecutionErrors.scala | 10 ---------- .../expressions/IntervalExpressionsSuite.scala | 2 +- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 502558c21faa9..fdc00549cc088 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -6864,11 +6864,6 @@ " is not implemented." ] }, - "_LEGACY_ERROR_TEMP_2042" : { - "message" : [ - ". If necessary set to false to bypass this error." 
- ] - }, "_LEGACY_ERROR_TEMP_2045" : { "message" : [ "Unsupported table change: " diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 13676733a9bad..d18630f542020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -336,7 +336,7 @@ case class MakeInterval( val iu = IntervalUtils.getClass.getName.stripSuffix("$") val secFrac = sec.getOrElse("0") val failOnErrorBranch = if (failOnError) { - "throw QueryExecutionErrors.arithmeticOverflowError(e);" + """throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", null);""" } else { s"${ev.isNull} = true;" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index ebcc98a3af27a..edc1b909292df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -599,16 +599,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("methodName" -> methodName)) } - def arithmeticOverflowError(e: ArithmeticException): SparkArithmeticException = { - new SparkArithmeticException( - errorClass = "_LEGACY_ERROR_TEMP_2042", - messageParameters = Map( - "message" -> e.getMessage, - "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = Array.empty, - summary = "") - } - def binaryArithmeticCauseOverflowError( eval1: Short, symbol: String, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 7caf23490a0ce..78bc77b9dc2ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -266,7 +266,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val intervalExpr = MakeInterval(Literal(years), Literal(months), Literal(weeks), Literal(days), Literal(hours), Literal(minutes), Literal(Decimal(secFrac, Decimal.MAX_LONG_DIGITS, 6))) - checkExceptionInExpression[ArithmeticException](intervalExpr, EmptyRow, "") + checkExceptionInExpression[ArithmeticException](intervalExpr, EmptyRow, "ARITHMETIC_OVERFLOW") } withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { From e92bf3746ebf52028e2bc2168583bf9e1f463434 Mon Sep 17 00:00:00 2001 From: Tinglong Liao Date: Thu, 17 Oct 2024 08:13:07 +0900 Subject: [PATCH 008/108] [SPARK-49978][R] Move sparkR deprecation warning to package attach time ### What changes were proposed in this pull request? Previously, the output deprecation warning happens in the `spark.session` function, in this PR, we move it to the `.onAttach` function so it will be triggered whenever library is attached ### Why are the changes needed? 
I believe having the warning message on attach time have the following benefits: - **Have a more prompt warning.** If the deprecation is for the whole package instead of just the `sparkR.session` function, it is more intuitive for the warning to show up on attach time instead of waiting til later time - **Do not rely on the assumption of "every sparkR user will run sparkR.session method".** This asumption may not hold true all the time. For example, some hosted spark platform like Databricks already configure the spark session in the background and therefore will not show the error message. So making this change should make sure a broader reach for this warning notification - **Less intrusive warning**. Previous warning show up every time `sparkR.session` is called, but the new warning message will only show up once even if user run multiple `library`/`require` commands ### Does this PR introduce _any_ user-facing change? **Yes** 1. No more waring message in sparkR.session method 2. Warning message on library attach (when calling `library`/`require` function) image 3. Able to surpress warning by setting `SPARKR_SUPPRESS_DEPRECATION_WARNING` image ### How was this patch tested? Just a simple migration change, will rely on existing pre/post-merge check, and this existing test Also did manual testing(see previous section for screenshot) ### Was this patch authored or co-authored using generative AI tooling? No Closes #48482 from tinglongliao-db/sparkR-deprecation-migration. Authored-by: Tinglong Liao Signed-off-by: Hyukjin Kwon --- R/pkg/DESCRIPTION | 1 + R/pkg/R/sparkR.R | 6 ------ R/pkg/R/zzz.R | 30 ++++++++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) create mode 100644 R/pkg/R/zzz.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index f7dd261c10fd2..49000c62d1063 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -57,6 +57,7 @@ Collate: 'types.R' 'utils.R' 'window.R' + 'zzz.R' RoxygenNote: 7.1.2 VignetteBuilder: knitr NeedsCompilation: no diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 29c05b0db7c2d..1b5faad376eaa 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -403,12 +403,6 @@ sparkR.session <- function( sparkPackages = "", enableHiveSupport = TRUE, ...) { - - if (Sys.getenv("SPARKR_SUPPRESS_DEPRECATION_WARNING") == "") { - warning( - "SparkR is deprecated from Apache Spark 4.0.0 and will be removed in a future version.") - } - sparkConfigMap <- convertNamedListToEnv(sparkConfig) namedParams <- list(...) if (length(namedParams) > 0) { diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 0000000000000..947bd543b75e0 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# zzz.R - package startup message + +.onAttach <- function(...) 
{ + if (Sys.getenv("SPARKR_SUPPRESS_DEPRECATION_WARNING") == "") { + packageStartupMessage( + paste0( + "Warning: ", + "SparkR is deprecated in Apache Spark 4.0.0 and will be removed in a future release. ", + "To continue using Spark in R, we recommend using sparklyr instead: ", + "https://spark.posit.co/get-started/" + ) + ) + } +} From 224d3ba1a2cde664fb94a96a4af1defac9ea401c Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 17 Oct 2024 10:37:05 +0900 Subject: [PATCH 009/108] [SPARK-49986][INFRA] Restore `scipy` installation in dockerfile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Restore `scipy` installation in dockerfile ### Why are the changes needed? https://docs.scipy.org/doc/scipy-1.13.1/building/index.html#system-level-dependencies > If you want to use the system Python and pip, you will need: C, C++, and Fortran compilers (typically gcc, g++, and gfortran). ... `scipy` actually depends on `gfortran`, but `apt-get remove --purge -y 'gfortran-11'` broke this dependency. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manually check with the first commit https://github.com/apache/spark/pull/48489/commits/5be0dfa2431653c00c430424867dcc3918078226: move `apt-get remove --purge -y 'gfortran-11'` ahead of `scipy` installation, then the installation fails with ``` #18 394.3 Collecting scipy #18 394.4 Downloading scipy-1.13.1.tar.gz (57.2 MB) #18 395.2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.2/57.2 MB 76.7 MB/s eta 0:00:00 #18 401.3 Installing build dependencies: started #18 410.5 Installing build dependencies: finished with status 'done' #18 410.5 Getting requirements to build wheel: started #18 410.7 Getting requirements to build wheel: finished with status 'done' #18 410.7 Installing backend dependencies: started #18 411.8 Installing backend dependencies: finished with status 'done' #18 411.8 Preparing metadata (pyproject.toml): started #18 414.9 Preparing metadata (pyproject.toml): finished with status 'error' #18 414.9 error: subprocess-exited-with-error #18 414.9 #18 414.9 × Preparing metadata (pyproject.toml) did not run successfully. 
#18 414.9 │ exit code: 1 #18 414.9 ╰─> [42 lines of output] #18 414.9 + meson setup /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d/.mesonpy-xqfvs4ek -Dbuildtype=release -Db_ndebug=if-release -Db_vscrt=md --native-file=/tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d/.mesonpy-xqfvs4ek/meson-python-native-file.ini #18 414.9 The Meson build system #18 414.9 Version: 1.5.2 #18 414.9 Source dir: /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d #18 414.9 Build dir: /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d/.mesonpy-xqfvs4ek #18 414.9 Build type: native build #18 414.9 Project name: scipy #18 414.9 Project version: 1.13.1 #18 414.9 C compiler for the host machine: cc (gcc 11.4.0 "cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0") #18 414.9 C linker for the host machine: cc ld.bfd 2.38 #18 414.9 C++ compiler for the host machine: c++ (gcc 11.4.0 "c++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0") #18 414.9 C++ linker for the host machine: c++ ld.bfd 2.38 #18 414.9 Cython compiler for the host machine: cython (cython 3.0.11) #18 414.9 Host machine cpu family: x86_64 #18 414.9 Host machine cpu: x86_64 #18 414.9 Program python found: YES (/usr/local/bin/pypy3) #18 414.9 Run-time dependency python found: YES 3.9 #18 414.9 Program cython found: YES (/tmp/pip-build-env-v_vnvt3h/overlay/bin/cython) #18 414.9 Compiler for C supports arguments -Wno-unused-but-set-variable: YES #18 414.9 Compiler for C supports arguments -Wno-unused-function: YES #18 414.9 Compiler for C supports arguments -Wno-conversion: YES #18 414.9 Compiler for C supports arguments -Wno-misleading-indentation: YES #18 414.9 Library m found: YES #18 414.9 #18 414.9 ../meson.build:78:0: ERROR: Unknown compiler(s): [['gfortran'], ['flang'], ['nvfortran'], ['pgfortran'], ['ifort'], ['ifx'], ['g95']] #18 414.9 The following exception(s) were encountered: #18 414.9 Running `gfortran --version` gave "[Errno 2] No such file or directory: 'gfortran'" #18 414.9 Running `gfortran -V` gave "[Errno 2] No such file or directory: 'gfortran'" #18 414.9 Running `flang --version` gave "[Errno 2] No such file or directory: 'flang'" #18 414.9 Running `flang -V` gave "[Errno 2] No such file or directory: 'flang'" #18 414.9 Running `nvfortran --version` gave "[Errno 2] No such file or directory: 'nvfortran'" #18 414.9 Running `nvfortran -V` gave "[Errno 2] No such file or directory: 'nvfortran'" #18 414.9 Running `pgfortran --version` gave "[Errno 2] No such file or directory: 'pgfortran'" #18 414.9 Running `pgfortran -V` gave "[Errno 2] No such file or directory: 'pgfortran'" #18 414.9 Running `ifort --version` gave "[Errno 2] No such file or directory: 'ifort'" #18 414.9 Running `ifort -V` gave "[Errno 2] No such file or directory: 'ifort'" #18 414.9 Running `ifx --version` gave "[Errno 2] No such file or directory: 'ifx'" #18 414.9 Running `ifx -V` gave "[Errno 2] No such file or directory: 'ifx'" #18 414.9 Running `g95 --version` gave "[Errno 2] No such file or directory: 'g95'" #18 414.9 Running `g95 -V` gave "[Errno 2] No such file or directory: 'g95'" #18 414.9 #18 414.9 A full log can be found at /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984[4155](https://github.com/zhengruifeng/spark/actions/runs/11357130578/job/31589506939#step:7:4161)33ae9079d/.mesonpy-xqfvs4ek/meson-logs/meson-log.txt #18 414.9 [end of output] ``` see https://github.com/zhengruifeng/spark/actions/runs/11357130578/job/31589506939 ### Was this patch 
authored or co-authored using generative AI tooling? no Closes #48489 from zhengruifeng/infra_scipy. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- dev/infra/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 10a39497c8ed9..1edeed775880b 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -152,6 +152,6 @@ RUN python3.13 -m pip install lxml numpy>=2.1 && \ python3.13 -m pip cache purge # Remove unused installation packages to free up disk space -RUN apt-get remove --purge -y 'gfortran-11' 'humanity-icon-theme' 'nodejs-doc' || true +RUN apt-get remove --purge -y 'humanity-icon-theme' 'nodejs-doc' RUN apt-get autoremove --purge -y RUN apt-get clean From baa5f408a0985d703b4a1e4c5490c77b239180c4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 17 Oct 2024 11:29:00 +0900 Subject: [PATCH 010/108] [SPARK-49945][PS][CONNECT] Add alias for `distributed_id` ### What changes were proposed in this pull request? 1, make `registerInternalExpression` support alias; 2, add alias `distributed_id` for `MonotonicallyIncreasingID` (rename `distributed_index` to `distributed_id` to be more consistent with existing `distributed_sequence_id`); 3, remove `distributedIndex` from `PythonSQLUtils` ### Why are the changes needed? make PS on Connect more consistent with Classic: ```py In [9]: ps.set_option("compute.default_index_type", "distributed") In [10]: spark_frame = ps.range(10).to_spark() In [11]: InternalFrame.attach_default_index(spark_frame).explain(True) ``` before: ![image](https://github.com/user-attachments/assets/6ce1fb5f-a3c6-42d5-a21e-3925207cb4d0) ``` == Parsed Logical Plan == 'Project ['monotonically_increasing_id() AS __index_level_0__#27, 'id] +- 'Project ['id] +- Project [__index_level_0__#19L, id#16L, monotonically_increasing_id() AS __natural_order__#22L] +- Project [monotonically_increasing_id() AS __index_level_0__#19L, id#16L] +- Range (0, 10, step=1, splits=Some(12)) ... ``` after: ![image](https://github.com/user-attachments/assets/00d3a8a1-251c-4cee-851e-c10f294d5248) ``` == Parsed Logical Plan == 'Project ['distributed_id() AS __index_level_0__#65, *] +- 'Project ['id] +- Project [__index_level_0__#45L, id#42L, monotonically_increasing_id() AS __natural_order__#48L] +- Project [distributed_id() AS __index_level_0__#45L, id#42L] +- Range (0, 10, step=1, splits=Some(12)) ... ``` ### Does this PR introduce _any_ user-facing change? spark ui ### How was this patch tested? existing test and manually check ### Was this patch authored or co-authored using generative AI tooling? no Closes #48439 from zhengruifeng/distributed_index. 
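A rough Scala analogue of what `attach_distributed_column` produces, for illustration only: `distributed_id` is an internal registry entry with no public Scala API, so `monotonically_increasing_id` stands in for it here, and `__index_level_0__` is the pandas-on-Spark default index column name seen in the plans above.

```scala
import org.apache.spark.sql.functions.{col, monotonically_increasing_id}

// Prepend a monotonically increasing id as the default "distributed" index column
// and keep every original column; explain(true) shows how the expression is named.
val sdf = spark.range(10).toDF("id")
val withIndex = sdf.select(monotonically_increasing_id().alias("__index_level_0__"), col("*"))
withIndex.explain(true)
```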
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/internal.py | 9 +-------- python/pyspark/pandas/spark/functions.py | 4 ++++ .../catalyst/analysis/FunctionRegistry.scala | 18 +++++++++++++++--- .../spark/sql/api/python/PythonSQLUtils.scala | 6 ------ 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 6063641e22e3b..90c361547b814 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -909,14 +909,7 @@ def attach_sequence_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDa @staticmethod def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDataFrame: - scols = [scol_for(sdf, column) for column in sdf.columns] - # Does not add an alias to avoid having some changes in protobuf definition for now. - # The alias is more for query strings in DataFrame.explain, and they are cosmetic changes. - if is_remote(): - return sdf.select(F.monotonically_increasing_id().alias(column_name), *scols) - jvm = sdf.sparkSession._jvm - jcol = jvm.PythonSQLUtils.distributedIndex() - return sdf.select(PySparkColumn(jcol).alias(column_name), *scols) + return sdf.select(SF.distributed_id().alias(column_name), "*") @staticmethod def attach_distributed_sequence_column( diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index bdd11559df3b6..53146a163b1ef 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -79,6 +79,10 @@ def null_index(col: Column) -> Column: return _invoke_internal_function_over_columns("null_index", col) +def distributed_id() -> Column: + return _invoke_internal_function_over_columns("distributed_id") + + def distributed_sequence_id() -> Column: return _invoke_internal_function_over_columns("distributed_sequence_id") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d03d8114e9976..abe61619a2331 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -895,9 +895,20 @@ object FunctionRegistry { /** Registry for internal functions used by Connect and the Column API. 
*/ private[sql] val internal: SimpleFunctionRegistry = new SimpleFunctionRegistry - private def registerInternalExpression[T <: Expression : ClassTag](name: String): Unit = { - val (info, builder) = FunctionRegistryBase.build(name, None) - internal.internalRegisterFunction(FunctionIdentifier(name), info, builder) + private def registerInternalExpression[T <: Expression : ClassTag]( + name: String, + setAlias: Boolean = false): Unit = { + val (info, builder) = FunctionRegistryBase.build[T](name, None) + val newBuilder = if (setAlias) { + (expressions: Seq[Expression]) => { + val expr = builder(expressions) + expr.setTagValue(FUNC_ALIAS, name) + expr + } + } else { + builder + } + internal.internalRegisterFunction(FunctionIdentifier(name), info, newBuilder) } registerInternalExpression[Product]("product") @@ -911,6 +922,7 @@ object FunctionRegistry { registerInternalExpression[Days]("days") registerInternalExpression[Hours]("hours") registerInternalExpression[UnwrapUDT]("unwrap_udt") + registerInternalExpression[MonotonicallyIncreasingID]("distributed_id", setAlias = true) registerInternalExpression[DistributedSequenceID]("distributed_sequence_id") registerInternalExpression[PandasProduct]("pandas_product") registerInternalExpression[PandasStddev]("pandas_stddev") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 08395ef4c347c..a66a6e54a7c8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -154,12 +154,6 @@ private[sql] object PythonSQLUtils extends Logging { def namedArgumentExpression(name: String, e: Column): Column = NamedArgumentExpression(name, e) - def distributedIndex(): Column = { - val expr = MonotonicallyIncreasingID() - expr.setTagValue(FunctionRegistry.FUNC_ALIAS, "distributed_index") - expr - } - @scala.annotation.varargs def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*) From 9af705d27cae1ce9918f0467ecff6da10b311ab6 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 17 Oct 2024 11:32:53 +0900 Subject: [PATCH 011/108] [SPARK-49951][SQL] Assign proper error condition for _LEGACY_ERROR_TEMP_(1099|3085) ### What changes were proposed in this pull request? This PR proposes to assign proper error condition & sqlstate for _LEGACY_ERROR_TEMP_(1099|3085) ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48449 from itholic/SPARK-49951. 
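A spark-shell sketch of the user-visible behaviour covered by the updated tests; the schema and input mirror `CsvFunctionsSuite`, and `from_avro`, `from_json`, and `schema_of_xml` behave analogously:

```scala
import org.apache.spark.sql.functions.from_csv
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val schema = StructType(StructField("a", DoubleType) :: Nil)
val df = Seq("---").toDF("value")

// PERMISSIVE and FAILFAST remain the only accepted modes; with PERMISSIVE the
// malformed record simply parses to a null field.
df.select(from_csv($"value", schema, Map("mode" -> "PERMISSIVE"))).show()

// Any other mode, e.g. DROPMALFORMED, now fails analysis with the new
// PARSE_MODE_UNSUPPORTED condition instead of _LEGACY_ERROR_TEMP_1099.
// df.select(from_csv($"value", schema, Map("mode" -> "DROPMALFORMED"))).show()
```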
Authored-by: Haejoon Lee Signed-off-by: Haejoon Lee --- .../resources/error/error-conditions.json | 16 ++++++--------- .../spark/sql/avro/AvroDataToCatalyst.scala | 20 +++++++------------ .../spark/sql/avro/AvroFunctionsSuite.scala | 11 ++++++++++ .../sql/errors/QueryCompilationErrors.scala | 10 ++++------ .../expressions/CsvExpressionsSuite.scala | 17 ++++++++++------ .../apache/spark/sql/CsvFunctionsSuite.scala | 8 +++----- .../apache/spark/sql/JsonFunctionsSuite.scala | 8 +++----- .../execution/datasources/xml/XmlSuite.scala | 8 +++----- 8 files changed, 48 insertions(+), 50 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index fdc00549cc088..3e4848658f14a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3827,6 +3827,12 @@ ], "sqlState" : "42617" }, + "PARSE_MODE_UNSUPPORTED" : { + "message" : [ + "The function doesn't support the mode. Acceptable modes are PERMISSIVE and FAILFAST." + ], + "sqlState" : "42601" + }, "PARSE_SYNTAX_ERROR" : { "message" : [ "Syntax error at or near ." @@ -6045,11 +6051,6 @@ "DataType '' is not supported by ." ] }, - "_LEGACY_ERROR_TEMP_1099" : { - "message" : [ - "() doesn't support the mode. Acceptable modes are and ." - ] - }, "_LEGACY_ERROR_TEMP_1103" : { "message" : [ "Unsupported component type in arrays." @@ -8096,11 +8097,6 @@ "No handler for UDF/UDAF/UDTF '': " ] }, - "_LEGACY_ERROR_TEMP_3085" : { - "message" : [ - "from_avro() doesn't support the mode. Acceptable modes are and ." - ] - }, "_LEGACY_ERROR_TEMP_3086" : { "message" : [ "Cannot persist into Hive metastore as table property keys may not start with 'spark.sql.': " diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 0b85b208242cb..9c8b2d0375588 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -24,10 +24,10 @@ import org.apache.avro.generic.GenericDatumReader import org.apache.avro.io.{BinaryDecoder, DecoderFactory} import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ private[sql] case class AvroDataToCatalyst( @@ -80,12 +80,9 @@ private[sql] case class AvroDataToCatalyst( @transient private lazy val parseMode: ParseMode = { val mode = avroOptions.parseMode if (mode != PermissiveMode && mode != FailFastMode) { - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3085", - messageParameters = Map( - "name" -> mode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + throw QueryCompilationErrors.parseModeUnsupportedError( + prettyName, mode + ) } mode } @@ -123,12 +120,9 @@ private[sql] case class AvroDataToCatalyst( s"Current parse Mode: ${FailFastMode.name}. 
To process malformed records as null " + "result, try setting the option 'mode' as 'PERMISSIVE'.", e) case _ => - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3085", - messageParameters = Map( - "name" -> parseMode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + throw QueryCompilationErrors.parseModeUnsupportedError( + prettyName, parseMode + ) } } } diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index a7f7abadcf485..096cdfe0b9ee4 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -106,6 +106,17 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { functions.from_avro( $"avro", avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)), expected) + + checkError( + exception = intercept[AnalysisException] { + avroStructDF.select( + functions.from_avro( + $"avro", avroTypeStruct, Map("mode" -> "DROPMALFORMED").asJava)).collect() + }, + condition = "PARSE_MODE_UNSUPPORTED", + parameters = Map( + "funcName" -> "`from_avro`", + "mode" -> "DROPMALFORMED")) } test("roundtrip in to_avro and from_avro - array with null") { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 9dc15c4a1b78d..431983214c482 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InputParameter, Join, LogicalPlan, SerdeInfo, Window} import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ParseMode} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, UnboundFunction} @@ -1341,12 +1341,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def parseModeUnsupportedError(funcName: String, mode: ParseMode): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1099", + errorClass = "PARSE_MODE_UNSUPPORTED", messageParameters = Map( - "funcName" -> funcName, - "mode" -> mode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + "funcName" -> toSQLId(funcName), + "mode" -> mode.name)) } def nonFoldableArgumentError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index a89cb58c3e03b..249975f9c0d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -149,12 +149,17 @@ class CsvExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper with P test("unsupported mode") { val csvData = "---" val schema = StructType(StructField("a", DoubleType) :: Nil) - val exception = intercept[TestFailedException] { - checkEvaluation( - CsvToStructs(schema, Map("mode" -> DropMalformedMode.name), Literal(csvData), UTC_OPT), - InternalRow(null)) - }.getCause - assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) + + checkError( + exception = intercept[TestFailedException] { + checkEvaluation( + CsvToStructs(schema, Map("mode" -> DropMalformedMode.name), Literal(csvData), UTC_OPT), + InternalRow(null)) + }.getCause.asInstanceOf[AnalysisException], + condition = "PARSE_MODE_UNSUPPORTED", + parameters = Map( + "funcName" -> "`from_csv`", + "mode" -> "DROPMALFORMED")) } test("infer schema of CSV strings") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index e6907b8656482..970ed5843b3c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -352,12 +352,10 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_csv($"value", schema, Map("mode" -> "DROPMALFORMED"))).collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "from_csv", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> "FAILFAST")) + "funcName" -> "`from_csv`", + "mode" -> "DROPMALFORMED")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 7b19ad988d308..84408d8e2495d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -861,12 +861,10 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_json($"value", schema, Map("mode" -> "DROPMALFORMED"))).collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "from_json", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> "FAILFAST")) + "funcName" -> "`from_json`", + "mode" -> "DROPMALFORMED")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 91f21c4a2ed34..059e4aadef2bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -1315,12 +1315,10 @@ class XmlSuite spark.sql(s"""SELECT schema_of_xml('1', map('mode', 'DROPMALFORMED'))""") .collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "schema_of_xml", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> FailFastMode.name) + "funcName" -> "`schema_of_xml`", + "mode" -> "DROPMALFORMED") ) } From 948aeba93e1a5898ec4f8e71ff4eb89e7514c43f Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 17 Oct 2024 11:33:01 +0900 Subject: [PATCH 012/108] [SPARK-49947][SQL][TESTS] 
Upgrade `MsSql` docker image version ### What changes were proposed in this pull request? The pr aims to upgrade the `MsSql` docker image version from `2022-CU14-ubuntu-22.04` to `2022-CU15-ubuntu-22.04`. ### Why are the changes needed? This will help Apache Spark test the latest `MsSql`. https://hub.docker.com/r/microsoft/mssql-server image ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48444 from panbingkun/SPARK-49947. Authored-by: panbingkun Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala | 2 +- .../apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala | 4 ++-- .../spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 4 ++-- .../apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala index 9d3c7d1eca328..6bd33356cab3d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc class MsSQLServerDatabaseOnDocker extends DatabaseOnDocker { override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04") + "mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04") override val env = Map( "SA_PASSWORD" -> "Sapass123", "ACCEPT_EULA" -> "Y" diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 90cd68e6e1d24..62f088ebc2b6d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -31,10 +31,10 @@ import org.apache.spark.sql.types.{BinaryType, DecimalType} import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.MsSqlServerIntegrationSuite" * }}} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index aaaaa28558342..d884ad4c62466 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -27,10 +27,10 @@ import 
org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MsSqlServerIntegrationSuite" * }}} */ diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala index 9fb3bc4fba945..724c394a4f052 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.MsSqlServerNamespaceSuite" * }}} */ From 070f2bdfb968c8080de1c6614c1def978df823d4 Mon Sep 17 00:00:00 2001 From: Changgyoo Park Date: Thu, 17 Oct 2024 11:35:14 +0900 Subject: [PATCH 013/108] [SPARK-49876][CONNECT] Get rid of global locks from Spark Connect Service ### What changes were proposed in this pull request? Get rid of global locks from Spark Connect Service. - ServerSideListenerHolder: AtomicReference replaces the global lock. - SparkConnectStreamingQueryCache: two global locks are replaced with ConcurrentHashMap and a mutex-protected per-tag data structure, i.e., global locks -> a per-tag lock. ### Why are the changes needed? Spark Connect Service doesn't limit the number of threads, susceptible to priority inversion because of heavy use of global locks. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests + modified an existing test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48350 from changgyoopark-db/SPARK-49876-REMOVE-LOCKS. 
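For illustration only (not part of this patch): a minimal sketch of the lock-free holder pattern described above for `ServerSideListenerHolder`, using stand-in types (`ListenerHolder`, `Listener`) rather than the actual Spark Connect classes. The idea is that a single `AtomicReference` replaces a `synchronized` block around an `Option` field, and `getAndSet(null)` guarantees clean-up runs at most once even under concurrent calls; the real change additionally uses acquire/release accessors and clears the query-started event cache.

```scala
import java.util.concurrent.atomic.AtomicReference

// Stand-in for the real listener type registered with the streaming query manager.
trait Listener { def shutdown(): Unit }

class ListenerHolder {
  // One AtomicReference instead of a lock-protected Option[Listener] var.
  private val current = new AtomicReference[Listener]()

  def isRegistered: Boolean = current.get() != null

  def init(listener: Listener): Unit = current.set(listener)

  // getAndSet(null) ensures only one caller observes the non-null listener,
  // so shutdown() runs exactly once even if clean-up is invoked concurrently.
  def cleanUp(): Unit = {
    val listener = current.getAndSet(null)
    if (listener != null) listener.shutdown()
  }
}
```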
Authored-by: Changgyoo Park Signed-off-by: Hyukjin Kwon --- .../SparkConnectListenerBusListener.scala | 22 +- .../SparkConnectStreamingQueryCache.scala | 239 ++++++++++-------- ...SparkConnectListenerBusListenerSuite.scala | 3 +- ...SparkConnectStreamingQueryCacheSuite.scala | 14 +- 4 files changed, 160 insertions(+), 118 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala index 7a0c067ab430b..445f40d25edcd 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} +import java.util.concurrent.atomic.AtomicReference import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -41,7 +42,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { // The server side listener that is responsible to stream streaming query events back to client. // There is only one listener per sessionHolder, but each listener is responsible for all events // of all streaming queries in the SparkSession. - var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None + var streamingQueryServerSideListener: AtomicReference[SparkConnectListenerBusListener] = + new AtomicReference() // The cache for QueryStartedEvent, key is query runId and value is the actual QueryStartedEvent. // Events for corresponding query will be sent back to client with // the WriteStreamOperationStart response, so that the client can handle the event before @@ -50,10 +52,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { val streamingQueryStartedEventCache : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap() - val lock = new Object() - - def isServerSideListenerRegistered: Boolean = lock.synchronized { - streamingQueryServerSideListener.isDefined + def isServerSideListenerRegistered: Boolean = { + streamingQueryServerSideListener.getAcquire() != null } /** @@ -65,10 +65,10 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { * @param responseObserver * the responseObserver created from the first long running executeThread. */ - def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized { + def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { val serverListener = new SparkConnectListenerBusListener(this, responseObserver) sessionHolder.session.streams.addListener(serverListener) - streamingQueryServerSideListener = Some(serverListener) + streamingQueryServerSideListener.setRelease(serverListener) } /** @@ -77,13 +77,13 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { * exception. It removes the listener from the session, clears the cache. Also it sends back the * final ResultComplete response. 
*/ - def cleanUp(): Unit = lock.synchronized { - streamingQueryServerSideListener.foreach { listener => + def cleanUp(): Unit = { + var listener = streamingQueryServerSideListener.getAndSet(null) + if (listener != null) { sessionHolder.session.streams.removeListener(listener) listener.sendResultComplete() + streamingQueryStartedEventCache.clear() } - streamingQueryStartedEventCache.clear() - streamingQueryServerSideListener = None } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala index 48492bac62344..3da2548b456e8 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.connect.service -import java.util.concurrent.Executors -import java.util.concurrent.ScheduledExecutorService -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.AtomicReference -import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} @@ -61,36 +58,34 @@ private[connect] class SparkConnectStreamingQueryCache( sessionHolder: SessionHolder, query: StreamingQuery, tags: Set[String], - operationId: String): Unit = queryCacheLock.synchronized { - taggedQueriesLock.synchronized { - val value = QueryCacheValue( - userId = sessionHolder.userId, - sessionId = sessionHolder.sessionId, - session = sessionHolder.session, - query = query, - operationId = operationId, - expiresAtMs = None) - - val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) - tags.foreach { tag => - taggedQueries - .getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]) - .addOne(queryKey) - } - - queryCache.put(queryKey, value) match { - case Some(existing) => // Query is being replace. Not really expected. + operationId: String): Unit = { + val value = QueryCacheValue( + userId = sessionHolder.userId, + sessionId = sessionHolder.sessionId, + session = sessionHolder.session, + query = query, + operationId = operationId, + expiresAtMs = None) + + val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) + tags.foreach { tag => addTaggedQuery(tag, queryKey) } + + queryCache.compute( + queryKey, + (key, existing) => { + if (existing != null) { // The query is being replaced: allowed, though not expected. logWarning(log"Replacing existing query in the cache (unexpected). " + log"Query Id: ${MDC(QUERY_ID, query.id)}.Existing value ${MDC(OLD_VALUE, existing)}, " + log"new value ${MDC(NEW_VALUE, value)}.") - case None => + } else { logInfo( log"Adding new query to the cache. Query Id ${MDC(QUERY_ID, query.id)}, " + log"value ${MDC(QUERY_CACHE_VALUE, value)}.") - } + } + value + }) - schedulePeriodicChecks() // Starts the scheduler thread if it hasn't started. - } + schedulePeriodicChecks() // Start the scheduler thread if it has not been started. 
} /** @@ -104,44 +99,35 @@ private[connect] class SparkConnectStreamingQueryCache( runId: String, tags: Set[String], session: SparkSession): Option[QueryCacheValue] = { - taggedQueriesLock.synchronized { - val key = QueryCacheKey(queryId, runId) - val result = getCachedQuery(QueryCacheKey(queryId, runId), session) - tags.foreach { tag => - taggedQueries.getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]).addOne(key) - } - result - } + val queryKey = QueryCacheKey(queryId, runId) + val result = getCachedQuery(QueryCacheKey(queryId, runId), session) + tags.foreach { tag => addTaggedQuery(tag, queryKey) } + result } /** * Similar with [[getCachedQuery]] but it gets queries tagged previously. */ def getTaggedQuery(tag: String, session: SparkSession): Seq[QueryCacheValue] = { - taggedQueriesLock.synchronized { - taggedQueries - .get(tag) - .map { k => - k.flatMap(getCachedQuery(_, session)).toSeq - } - .getOrElse(Seq.empty[QueryCacheValue]) - } + val queryKeySet = Option(taggedQueries.get(tag)) + queryKeySet + .map(_.flatMap(k => getCachedQuery(k, session))) + .getOrElse(Seq.empty[QueryCacheValue]) } private def getCachedQuery( key: QueryCacheKey, session: SparkSession): Option[QueryCacheValue] = { - queryCacheLock.synchronized { - queryCache.get(key).flatMap { v => - if (v.session == session) { - v.expiresAtMs.foreach { _ => - // Extend the expiry time as the client is accessing it. - val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis - queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) - } - Some(v) - } else None // Should be rare, may be client is trying access from a different session. - } + val value = Option(queryCache.get(key)) + value.flatMap { v => + if (v.session == session) { + v.expiresAtMs.foreach { _ => + // Extend the expiry time as the client is accessing it. + val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis + queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) + } + Some(v) + } else None // Should be rare, may be client is trying access from a different session. } } @@ -154,7 +140,7 @@ private[connect] class SparkConnectStreamingQueryCache( sessionHolder: SessionHolder, blocking: Boolean = true): Seq[String] = { val operationIds = new mutable.ArrayBuffer[String]() - for ((k, v) <- queryCache) { + queryCache.forEach((k, v) => { if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) { if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) { logInfo( @@ -178,29 +164,27 @@ private[connect] class SparkConnectStreamingQueryCache( } } } - } + }) operationIds.toSeq } // Visible for testing private[service] def getCachedValue(queryId: String, runId: String): Option[QueryCacheValue] = - queryCache.get(QueryCacheKey(queryId, runId)) + Option(queryCache.get(QueryCacheKey(queryId, runId))) // Visible for testing. 
- private[service] def shutdown(): Unit = queryCacheLock.synchronized { + private[service] def shutdown(): Unit = { val executor = scheduledExecutor.getAndSet(null) if (executor != null) { ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } } - @GuardedBy("queryCacheLock") - private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue] - private val queryCacheLock = new Object + private val queryCache: ConcurrentMap[QueryCacheKey, QueryCacheValue] = + new ConcurrentHashMap[QueryCacheKey, QueryCacheValue] - @GuardedBy("queryCacheLock") - private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]] - private val taggedQueriesLock = new Object + private[service] val taggedQueries: ConcurrentMap[String, QueryCacheKeySet] = + new ConcurrentHashMap[String, QueryCacheKeySet] private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = new AtomicReference[ScheduledExecutorService]() @@ -228,62 +212,109 @@ private[connect] class SparkConnectStreamingQueryCache( } } + private def addTaggedQuery(tag: String, queryKey: QueryCacheKey): Unit = { + taggedQueries.compute( + tag, + (k, v) => { + if (v == null || !v.addKey(queryKey)) { + // Create a new QueryCacheKeySet if the entry is absent or being removed. + var keys = mutable.HashSet.empty[QueryCacheKey] + keys.add(queryKey) + new QueryCacheKeySet(keys = keys) + } else { + v + } + }) + } + /** * Periodic maintenance task to do the following: * - Update status of query if it is inactive. Sets an expiry time for such queries * - Drop expired queries from the cache. */ - private def periodicMaintenance(): Unit = taggedQueriesLock.synchronized { + private def periodicMaintenance(): Unit = { + val nowMs = clock.getTimeMillis() - queryCacheLock.synchronized { - val nowMs = clock.getTimeMillis() + queryCache.forEach((k, v) => { + val id = k.queryId + val runId = k.runId + v.expiresAtMs match { - for ((k, v) <- queryCache) { - val id = k.queryId - val runId = k.runId - v.expiresAtMs match { + case Some(ts) if nowMs >= ts => // Expired. Drop references. + logInfo( + log"Removing references for id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + + log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") + queryCache.remove(k) - case Some(ts) if nowMs >= ts => // Expired. Drop references. - logInfo( - log"Removing references for id: ${MDC(QUERY_ID, id)} " + - log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") - queryCache.remove(k) + case Some(_) => // Inactive query waiting for expiration. Do nothing. + logInfo( + log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + + log"session ${MDC(SESSION_ID, v.sessionId)}") + + case None => // Active query, check if it is stopped. Enable timeout if it is stopped. + val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty - case Some(_) => // Inactive query waiting for expiration. Do nothing. + if (!isActive) { logInfo( - log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + + log"Marking query id: ${MDC(QUERY_ID, id)} " + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)}") - - case None => // Active query, check if it is stopped. Enable timeout if it is stopped. 
- val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty - - if (!isActive) { - logInfo( - log"Marking query id: ${MDC(QUERY_ID, id)} " + - log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") - val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis - queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) - // To consider: Clean up any runner registered for this query with the session holder - // for this session. Useful in case listener events are delayed (such delays are - // seen in practice, especially when users have heavy processing inside listeners). - // Currently such workers would be cleaned up when the connect session expires. - } - } + log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") + val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis + queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) + // To consider: Clean up any runner registered for this query with the session holder + // for this session. Useful in case listener events are delayed (such delays are + // seen in practice, especially when users have heavy processing inside listeners). + // Currently such workers would be cleaned up when the connect session expires. + } } + }) - taggedQueries.toArray.foreach { case (key, value) => - value.zipWithIndex.toArray.foreach { case (queryKey, i) => - if (queryCache.contains(queryKey)) { - value.remove(i) - } + // Removes any tagged queries that do not correspond to cached queries. + taggedQueries.forEach((key, value) => { + if (value.filter(k => queryCache.containsKey(k))) { + taggedQueries.remove(key, value) + } + }) + } + + case class QueryCacheKeySet(keys: mutable.HashSet[QueryCacheKey]) { + + /** Tries to add the key if the set is not empty, otherwise returns false. */ + def addKey(key: QueryCacheKey): Boolean = { + keys.synchronized { + if (keys.isEmpty) { + // The entry is about to be removed. + return false } + keys.add(key) + true + } + } - if (value.isEmpty) { - taggedQueries.remove(key) + /** Removes the key and returns true if the set is empty. */ + def removeKey(key: QueryCacheKey): Boolean = { + keys.synchronized { + if (keys.remove(key)) { + return keys.isEmpty } + false + } + } + + /** Removes entries that do not satisfy the predicate. */ + def filter(pred: QueryCacheKey => Boolean): Boolean = { + keys.synchronized { + keys.filterInPlace(k => pred(k)) + keys.isEmpty + } + } + + /** Iterates over entries, apply the function individually, and then flatten the result. 
*/ + def flatMap[T](function: QueryCacheKey => Option[T]): Seq[T] = { + keys.synchronized { + keys.flatMap(k => function(k)).toSeq } } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala index d856ffaabc316..2404dea21d91e 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala @@ -202,7 +202,8 @@ class SparkConnectListenerBusListenerSuite val listenerHolder = sessionHolder.streamingServersideListenerHolder eventually(timeout(5.seconds), interval(500.milliseconds)) { assert( - sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.isEmpty) + sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.get() == + null) assert(spark.streams.listListeners().size === listenerCntBeforeThrow) assert(listenerHolder.streamingQueryStartedEventCache.isEmpty) } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala index 512a0a80c4a91..729a995f46145 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala @@ -48,6 +48,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug val queryId = UUID.randomUUID().toString val runId = UUID.randomUUID().toString + val tag = "test_tag" val mockSession = mock[SparkSession] val mockQuery = mock[StreamingQuery] val mockStreamingQueryManager = mock[StreamingQueryManager] @@ -67,13 +68,16 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug // Register the query. - sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set.empty[String], "") + sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set(tag), "") sessionMgr.getCachedValue(queryId, runId) match { case Some(v) => assert(v.sessionId == sessionHolder.sessionId) assert(v.expiresAtMs.isEmpty, "No expiry time should be set for active query") + val taggedQueries = sessionMgr.getTaggedQuery(tag, mockSession) + assert(taggedQueries.contains(v)) + case None => assert(false, "Query should be found") } @@ -127,6 +131,9 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug assert(sessionMgr.getCachedValue(queryId, runId).map(_.query).contains(mockQuery)) assert( sessionMgr.getCachedValue(queryId, restartedRunId).map(_.query).contains(restartedQuery)) + eventually(timeout(1.minute)) { + assert(sessionMgr.taggedQueries.containsKey(tag)) + } // Advance time by 1 minute and verify the first query is dropped from the cache. 
clock.advance(1.minute.toMillis) @@ -144,8 +151,11 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug clock.advance(1.minute.toMillis) eventually(timeout(1.minute)) { assert(sessionMgr.getCachedValue(queryId, restartedRunId).isEmpty) + assert(sessionMgr.getTaggedQuery(tag, mockSession).isEmpty) + } + eventually(timeout(1.minute)) { + assert(!sessionMgr.taggedQueries.containsKey(tag)) } - sessionMgr.shutdown() } } From e374b94a9c8b217156ce24137efbd404a38e4f21 Mon Sep 17 00:00:00 2001 From: Ziqi Liu Date: Thu, 17 Oct 2024 12:24:59 +0800 Subject: [PATCH 014/108] [SPARK-49979][SQL] Fix AQE hanging issue when collecting twice on a failed plan ### What changes were proposed in this pull request? Record failure/error status in query stage. And abort immediately upon seeing failed query stage when creating new query stages. ### Why are the changes needed? AQE has a potential hanging issue when we collect twice from a failed AQE plan, no new query stage will be created, and no stage will be submitted either. We will be waiting for a finish event forever, which will never come because that query stage has already failed in the previous run. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? New UT. ### Was this patch authored or co-authored using generative AI tooling? NO Closes #48484 from liuzqt/SPARK-49979. Authored-by: Ziqi Liu Signed-off-by: Wenchen Fan --- .../adaptive/AdaptiveSparkPlanExec.scala | 12 +++++++++++ .../execution/adaptive/QueryStageExec.scala | 9 ++++++++ .../adaptive/AdaptiveQueryExecSuite.scala | 21 +++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index ffab67b7cae24..77efc4793359f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -340,6 +340,7 @@ case class AdaptiveSparkPlanExec( }(AdaptiveSparkPlanExec.executionContext) } catch { case e: Throwable => + stage.error.set(Some(e)) cleanUpAndThrowException(Seq(e), Some(stage.id)) } } @@ -355,6 +356,7 @@ case class AdaptiveSparkPlanExec( case StageSuccess(stage, res) => stage.resultOption.set(Some(res)) case StageFailure(stage, ex) => + stage.error.set(Some(ex)) errors.append(ex) } @@ -600,6 +602,7 @@ case class AdaptiveSparkPlanExec( newStages = Seq(newStage)) case q: QueryStageExec => + assertStageNotFailed(q) CreateStageResult(newPlan = q, allChildStagesMaterialized = q.isMaterialized, newStages = Seq.empty) @@ -815,6 +818,15 @@ case class AdaptiveSparkPlanExec( } } + private def assertStageNotFailed(stage: QueryStageExec): Unit = { + if (stage.hasFailed) { + throw stage.error.get().get match { + case fatal: SparkFatalException => fatal.throwable + case other => other + } + } + } + /** * Cancel all running stages with best effort and throw an Exception containing all stage * materialization errors and stage cancellation errors. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index 51595e20ae5f8..2391fe740118d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -93,6 +93,13 @@ abstract class QueryStageExec extends LeafExecNode { private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption final def isMaterialized: Boolean = resultOption.get().isDefined + @transient + @volatile + protected var _error = new AtomicReference[Option[Throwable]](None) + + def error: AtomicReference[Option[Throwable]] = _error + final def hasFailed: Boolean = _error.get().isDefined + override def output: Seq[Attribute] = plan.output override def outputPartitioning: Partitioning = plan.outputPartitioning override def outputOrdering: Seq[SortOrder] = plan.outputOrdering @@ -203,6 +210,7 @@ case class ShuffleQueryStageExec( ReusedExchangeExec(newOutput, shuffle), _canonicalized) reuse._resultOption = this._resultOption + reuse._error = this._error reuse } @@ -249,6 +257,7 @@ case class BroadcastQueryStageExec( ReusedExchangeExec(newOutput, broadcast), _canonicalized) reuse._resultOption = this._resultOption + reuse._error = this._error reuse } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 8e9ba6c8e21d8..1df045764d8b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -3065,6 +3065,27 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-49979: AQE hang forever when collecting twice on a failed AQE plan") { + val func: Long => Boolean = (i : Long) => { + throw new Exception("SPARK-49979") + } + withUserDefinedFunction("func" -> true) { + spark.udf.register("func", func) + val df1 = spark.range(1024).select($"id".as("key1")) + val df2 = spark.range(2048).select($"id".as("key2")) + .withColumn("group_key", $"key2" % 1024) + val df = df1.filter(expr("func(key1)")).hint("MERGE").join(df2, $"key1" === $"key2") + .groupBy($"group_key").agg("key1" -> "count") + intercept[Throwable] { + df.collect() + } + // second collect should not hang forever + intercept[Throwable] { + df.collect() + } + } + } } /** From 175d56310fae187247dc240ed6694ea667201cf2 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Thu, 17 Oct 2024 12:27:02 +0800 Subject: [PATCH 015/108] [SPARK-49977][SQL] Use stack-based iterative computation to avoid creating many Scala List objects for deep expression trees ### What changes were proposed in this pull request? In some use cases with deep expression trees, the driver's heap shows many `scala.collection.immutable.$colon$colon` objects from the heap. The objects are allocated due to deep recursion in the `gatherCommutative` method which uses `flatmap` recursively. Each invocation of `flatmap` creates a new temporary Scala collection. 
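For illustration only (not part of this patch): a minimal, self-contained sketch of the recursive-to-iterative rewrite this PR applies, using a stand-in `Node` tree instead of Catalyst's `Expression` hierarchy. The recursive shape allocates an intermediate collection at every `flatMap` level (and can also overflow the stack on very deep trees); the iterative shape uses one explicit stack and one result buffer, at the cost of a traversal order that may differ from the recursive version.

```scala
sealed trait Node
case class Branch(children: Seq[Node]) extends Node // stands in for a commutative expression
case class Leaf(value: Int) extends Node

object GatherSketch {
  // Recursive version: every Branch level builds a temporary collection via flatMap.
  def gatherRecursive(n: Node): Seq[Node] = n match {
    case Branch(children) => children.flatMap(gatherRecursive)
    case other            => other :: Nil
  }

  // Iterative version: a single explicit stack and a result buffer, no per-level temporaries.
  def gatherIterative(root: Node): Seq[Node] = {
    val result = scala.collection.mutable.Buffer[Node]()
    val stack  = scala.collection.mutable.Stack[Node](root)
    while (stack.nonEmpty) {
      stack.pop() match {
        case Branch(children) => stack.pushAll(children)
        case other            => result += other
      }
    }
    result.toSeq
  }
}
```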
Our claim is based on the following stack trace (>1K lines) of a thread in the driver below, truncated here for brevity: ``` "HiveServer2-Background-Pool: Thread-9867" #9867 daemon prio=5 os_prio=0 tid=0x00007f35080bf000 nid=0x33e7 runnable [0x00007f3393372000] java.lang.Thread.State: RUNNABLE at scala.collection.immutable.List$Appender$1.apply(List.scala:350) at scala.collection.immutable.List$Appender$1.apply(List.scala:341) at scala.collection.immutable.List.flatMap(List.scala:431) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression.gatherCommutative(Expression.scala:1479) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression.$anonfun$gatherCommutative$1(Expression.scala:1479) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression$$Lambda$5280/143713747.apply(Unknown Source) at scala.collection.immutable.List.flatMap(List.scala:366) .... at org.apache.spark.sql.catalyst.expressions.CommutativeExpression.gatherCommutative(Expression.scala:1479) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression.$anonfun$gatherCommutative$1(Expression.scala:1479) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression$$Lambda$5280/143713747.apply(Unknown Source) at scala.collection.immutable.List.flatMap(List.scala:366) .... ``` This PR fixes the issue by using a stack-based iterative computation, completely avoiding the creation of temporary Scala objects. ### Why are the changes needed? Reduce heap usage of the driver ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests, refactor ### Was this patch authored or co-authored using generative AI tooling? No Closes #48481 from utkarsh39/SPARK-49977. Lead-authored-by: Utkarsh Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/Expression.scala | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6a57ba2aaa569..bb32e518ec39a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -1347,9 +1347,21 @@ trait CommutativeExpression extends Expression { /** Collects adjacent commutative operations. */ private def gatherCommutative( e: Expression, - f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = e match { - case c: CommutativeExpression if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f)) - case other => other.canonicalized :: Nil + f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = { + val resultBuffer = scala.collection.mutable.Buffer[Expression]() + val stack = scala.collection.mutable.Stack[Expression](e) + + // [SPARK-49977]: Use iterative approach to avoid creating many temporary List objects + // for deep expression trees through recursion. 
+ while (stack.nonEmpty) { + stack.pop() match { + case c: CommutativeExpression if f.isDefinedAt(c) => + stack.pushAll(f(c)) + case other => + resultBuffer += other.canonicalized + } + } + resultBuffer.toSeq } /** From c8bedb49e9b95076ec193d4a8e80c55563bda7b3 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 17 Oct 2024 13:55:49 +0900 Subject: [PATCH 016/108] [SPARK-49848][PYTHON][CONNECT] API compatibility check for Catalog ### What changes were proposed in this pull request? This PR proposes to add API compatibility check for Catalog This PR also has some refactored to make it easier to add future tests into `ConnectCompatibilityTests` ### Why are the changes needed? To ensure the compatibility between classic and connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UT ### Was this patch authored or co-authored using generative AI tooling? No Closes #48364 from itholic/compat_catalog. Lead-authored-by: Haejoon Lee Co-authored-by: Haejoon Lee Signed-off-by: Haejoon Lee --- python/pyspark/sql/catalog.py | 8 +- python/pyspark/sql/connect/catalog.py | 4 +- .../sql/tests/test_connect_compatibility.py | 299 +++++++++++------- 3 files changed, 192 insertions(+), 119 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 143ce0c89f49c..8c35aafa7066c 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from pyspark.sql._typing import UserDefinedFunctionLike - from pyspark.sql.types import DataType + from pyspark.sql._typing import DataTypeOrString class CatalogMetadata(NamedTuple): @@ -759,7 +759,7 @@ def createExternalTable( source: Optional[str] = None, schema: Optional[StructType] = None, **options: str, - ) -> DataFrame: + ) -> "DataFrame": """Creates a table based on the dataset in a data source. It returns the DataFrame associated with the external table. @@ -791,7 +791,7 @@ def createTable( schema: Optional[StructType] = None, description: Optional[str] = None, **options: str, - ) -> DataFrame: + ) -> "DataFrame": """Creates a table based on the dataset in a data source. .. versionadded:: 2.2.0 @@ -942,7 +942,7 @@ def dropGlobalTempView(self, viewName: str) -> bool: return self._jcatalog.dropGlobalTempView(viewName) def registerFunction( - self, name: str, f: Callable[..., Any], returnType: Optional["DataType"] = None + self, name: str, f: Callable[..., Any], returnType: Optional["DataTypeOrString"] = None ) -> "UserDefinedFunctionLike": """An alias for :func:`spark.udf.register`. See :meth:`pyspark.sql.UDFRegistration.register`. 
diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py index 016c13e454c8d..85bf3caaf94d7 100644 --- a/python/pyspark/sql/connect/catalog.py +++ b/python/pyspark/sql/connect/catalog.py @@ -234,7 +234,7 @@ def createExternalTable( source: Optional[str] = None, schema: Optional[StructType] = None, **options: str, - ) -> DataFrame: + ) -> "DataFrame": warnings.warn( "createExternalTable is deprecated since Spark 4.0, please use createTable instead.", FutureWarning, @@ -251,7 +251,7 @@ def createTable( schema: Optional[StructType] = None, description: Optional[str] = None, **options: str, - ) -> DataFrame: + ) -> "DataFrame": if schema is not None and not isinstance(schema, StructType): raise PySparkTypeError( errorClass="NOT_STRUCT", diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index dfa0fa63b2dd5..c125c905604e9 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -17,17 +17,20 @@ import unittest import inspect +import functools from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame from pyspark.sql.classic.column import Column as ClassicColumn from pyspark.sql.session import SparkSession as ClassicSparkSession +from pyspark.sql.catalog import Catalog as ClassicCatalog if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame from pyspark.sql.connect.column import Column as ConnectColumn from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + from pyspark.sql.connect.catalog import Catalog as ConnectCatalog class ConnectCompatibilityTestsMixin: @@ -35,8 +38,9 @@ def get_public_methods(self, cls): """Get public methods of a class.""" return { name: method - for name, method in inspect.getmembers(cls, predicate=inspect.isfunction) - if not name.startswith("_") + for name, method in inspect.getmembers(cls) + if (inspect.isfunction(method) or isinstance(method, functools._lru_cache_wrapper)) + and not name.startswith("_") } def get_public_properties(self, cls): @@ -44,134 +48,203 @@ def get_public_properties(self, cls): return { name: member for name, member in inspect.getmembers(cls) - if isinstance(member, property) and not name.startswith("_") + if (isinstance(member, property) or isinstance(member, functools.cached_property)) + and not name.startswith("_") } - def test_signature_comparison_between_classic_and_connect(self): - def compare_method_signatures(classic_cls, connect_cls, cls_name): - """Compare method signatures between classic and connect classes.""" - classic_methods = self.get_public_methods(classic_cls) - connect_methods = self.get_public_methods(connect_cls) - - common_methods = set(classic_methods.keys()) & set(connect_methods.keys()) - - for method in common_methods: - classic_signature = inspect.signature(classic_methods[method]) - connect_signature = inspect.signature(connect_methods[method]) - - # createDataFrame cannot be the same since RDD is not supported from Spark Connect - if not method == "createDataFrame": - self.assertEqual( - classic_signature, - connect_signature, - f"Signature mismatch in {cls_name} method '{method}'\n" - f"Classic: {classic_signature}\n" - f"Connect: {connect_signature}", - ) - - # DataFrame API signature comparison - 
compare_method_signatures(ClassicDataFrame, ConnectDataFrame, "DataFrame") - - # Column API signature comparison - compare_method_signatures(ClassicColumn, ConnectColumn, "Column") - - # SparkSession API signature comparison - compare_method_signatures(ClassicSparkSession, ConnectSparkSession, "SparkSession") - - def test_property_comparison_between_classic_and_connect(self): - def compare_property_lists(classic_cls, connect_cls, cls_name, expected_missing_properties): - """Compare properties between classic and connect classes.""" - classic_properties = self.get_public_properties(classic_cls) - connect_properties = self.get_public_properties(connect_cls) - - # Identify missing properties - classic_only_properties = set(classic_properties.keys()) - set( - connect_properties.keys() - ) - - # Compare the actual missing properties with the expected ones - self.assertEqual( - classic_only_properties, - expected_missing_properties, - f"{cls_name}: Unexpected missing properties in Connect: {classic_only_properties}", - ) - - # Expected missing properties for DataFrame - expected_missing_properties_for_dataframe = {"sql_ctx", "isStreaming"} - - # DataFrame properties comparison - compare_property_lists( - ClassicDataFrame, - ConnectDataFrame, - "DataFrame", - expected_missing_properties_for_dataframe, + def compare_method_signatures(self, classic_cls, connect_cls, cls_name): + """Compare method signatures between classic and connect classes.""" + classic_methods = self.get_public_methods(classic_cls) + connect_methods = self.get_public_methods(connect_cls) + + common_methods = set(classic_methods.keys()) & set(connect_methods.keys()) + + for method in common_methods: + classic_signature = inspect.signature(classic_methods[method]) + connect_signature = inspect.signature(connect_methods[method]) + + if not method == "createDataFrame": + self.assertEqual( + classic_signature, + connect_signature, + f"Signature mismatch in {cls_name} method '{method}'\n" + f"Classic: {classic_signature}\n" + f"Connect: {connect_signature}", + ) + + def compare_property_lists( + self, + classic_cls, + connect_cls, + cls_name, + expected_missing_connect_properties, + expected_missing_classic_properties, + ): + """Compare properties between classic and connect classes.""" + classic_properties = self.get_public_properties(classic_cls) + connect_properties = self.get_public_properties(connect_cls) + + # Identify missing properties + classic_only_properties = set(classic_properties.keys()) - set(connect_properties.keys()) + connect_only_properties = set(connect_properties.keys()) - set(classic_properties.keys()) + + # Compare the actual missing properties with the expected ones + self.assertEqual( + classic_only_properties, + expected_missing_connect_properties, + f"{cls_name}: Unexpected missing properties in Connect: {classic_only_properties}", ) - # Expected missing properties for Column (if any, replace with actual values) - expected_missing_properties_for_column = set() - - # Column properties comparison - compare_property_lists( - ClassicColumn, ConnectColumn, "Column", expected_missing_properties_for_column + # Reverse compatibility check + self.assertEqual( + connect_only_properties, + expected_missing_classic_properties, + f"{cls_name}: Unexpected missing properties in Classic: {connect_only_properties}", ) - # Expected missing properties for SparkSession - expected_missing_properties_for_spark_session = {"sparkContext", "version"} - - # SparkSession properties comparison - compare_property_lists( - 
ClassicSparkSession, - ConnectSparkSession, - "SparkSession", - expected_missing_properties_for_spark_session, + def check_missing_methods( + self, + classic_cls, + connect_cls, + cls_name, + expected_missing_connect_methods, + expected_missing_classic_methods, + ): + """Check for expected missing methods between classic and connect classes.""" + classic_methods = self.get_public_methods(classic_cls) + connect_methods = self.get_public_methods(connect_cls) + + # Identify missing methods + classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys()) + connect_only_methods = set(connect_methods.keys()) - set(classic_methods.keys()) + + # Compare the actual missing methods with the expected ones + self.assertEqual( + classic_only_methods, + expected_missing_connect_methods, + f"{cls_name}: Unexpected missing methods in Connect: {classic_only_methods}", ) - def test_missing_methods(self): - def check_missing_methods(classic_cls, connect_cls, cls_name, expected_missing_methods): - """Check for expected missing methods between classic and connect classes.""" - classic_methods = self.get_public_methods(classic_cls) - connect_methods = self.get_public_methods(connect_cls) - - # Identify missing methods - classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys()) - - # Compare the actual missing methods with the expected ones - self.assertEqual( - classic_only_methods, - expected_missing_methods, - f"{cls_name}: Unexpected missing methods in Connect: {classic_only_methods}", - ) - - # Expected missing methods for DataFrame - expected_missing_methods_for_dataframe = { - "inputFiles", - "isLocal", - "semanticHash", - "isEmpty", - } - - # DataFrame missing method check - check_missing_methods( - ClassicDataFrame, ConnectDataFrame, "DataFrame", expected_missing_methods_for_dataframe + # Reverse compatibility check + self.assertEqual( + connect_only_methods, + expected_missing_classic_methods, + f"{cls_name}: Unexpected missing methods in Classic: {connect_only_methods}", ) - # Expected missing methods for Column (if any, replace with actual values) - expected_missing_methods_for_column = set() + def check_compatibility( + self, + classic_cls, + connect_cls, + cls_name, + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ): + """ + Main method for checking compatibility between classic and connect. + + This method performs the following checks: + - API signature comparison between classic and connect classes. + - Property comparison, identifying any missing properties between classic and connect. + - Method comparison, identifying any missing methods between classic and connect. + + Parameters + ---------- + classic_cls : type + The classic class to compare. + connect_cls : type + The connect class to compare. + cls_name : str + The name of the class. + expected_missing_connect_properties : set + A set of properties expected to be missing in the connect class. + expected_missing_classic_properties : set + A set of properties expected to be missing in the classic class. + expected_missing_connect_methods : set + A set of methods expected to be missing in the connect class. + expected_missing_classic_methods : set + A set of methods expected to be missing in the classic class. 
+ """ + self.compare_method_signatures(classic_cls, connect_cls, cls_name) + self.compare_property_lists( + classic_cls, + connect_cls, + cls_name, + expected_missing_connect_properties, + expected_missing_classic_properties, + ) + self.check_missing_methods( + classic_cls, + connect_cls, + cls_name, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) - # Column missing method check - check_missing_methods( - ClassicColumn, ConnectColumn, "Column", expected_missing_methods_for_column + def test_dataframe_compatibility(self): + """Test DataFrame compatibility between classic and connect.""" + expected_missing_connect_properties = {"sql_ctx"} + expected_missing_classic_properties = {"is_cached"} + expected_missing_connect_methods = set() + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicDataFrame, + ConnectDataFrame, + "DataFrame", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, ) - # Expected missing methods for SparkSession (if any, replace with actual values) - expected_missing_methods_for_spark_session = {"newSession"} + def test_column_compatibility(self): + """Test Column compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = {"to_plan"} + self.check_compatibility( + ClassicColumn, + ConnectColumn, + "Column", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) - # SparkSession missing method check - check_missing_methods( + def test_spark_session_compatibility(self): + """Test SparkSession compatibility between classic and connect.""" + expected_missing_connect_properties = {"sparkContext"} + expected_missing_classic_properties = {"is_stopped", "session_id"} + expected_missing_connect_methods = {"newSession"} + expected_missing_classic_methods = set() + self.check_compatibility( ClassicSparkSession, ConnectSparkSession, "SparkSession", - expected_missing_methods_for_spark_session, + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + + def test_catalog_compatibility(self): + """Test Catalog compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicCatalog, + ConnectCatalog, + "Catalog", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, ) From 74dbd9ab9cff0842be19ea0d800e55cff7612f12 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 17 Oct 2024 13:50:57 +0800 Subject: [PATCH 017/108] [SPARK-49996][SQL][TESTS] Upgrade `mysql-connector-j` to 9.1.0 ### What changes were proposed in this pull request? The pr aims to upgrade `mysql-connector-j` from `9.0.0` to `9.1.0`. ### Why are the changes needed? The full release notes of `mysql-connector-j` 9.1.0: https://dev.mysql.com/doc/relnotes/connector-j/en/news-9-1-0.html ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. 
### Was this patch authored or co-authored using generative AI tooling? No. Closes #48504 from panbingkun/SPARK-49996. Authored-by: panbingkun Signed-off-by: Kent Yao --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index cab7f7f595434..ff15f200e2bb1 100644 --- a/pom.xml +++ b/pom.xml @@ -327,7 +327,7 @@ -Dio.netty.tryReflectionSetAccessible=true 2.7.12 - 9.0.0 + 9.1.0 42.7.4 11.5.9.0 12.8.1.jre11 From 8405c9b57a1f4f94f8bb4d8ed5cb11e071edeb7e Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Thu, 17 Oct 2024 15:26:30 +0900 Subject: [PATCH 018/108] [SPARK-49770][SS][ROCKSDB HARDENING] Improve RocksDB SST file mapping management, and fix issue with reloading same version with existing snapshot ### What changes were proposed in this pull request? Currently, if a version X is loaded and there is an existing snapshot with version X, Spark will reuse the SST files from that existing snapshot, resulting in a RocksDB version ID mismatch error. This PR fixes this issue and simplifies RocksDB state management. The change eliminates most of the shared state between the task thread and the maintenance thread, simplifying the implementation. With this change, the task thread is solely responsible for keeping track of the local-file-to-DFS-file mapping; the maintenance thread does not access this mapping. The DFS file names are generated at `commit()`; the generated snapshot, along with the mapping containing the new SST files (with their generated DFS names), is handed over to the maintenance thread. The latest generated snapshot is appended to a ConcurrentLinkedQueue (named snapshotsToUploadQueue). The maintenance thread polls from the snapshotsToUploadQueue repeatedly until it is empty, uploads the new SST files for the last snapshot polled from the queue, and also clears all removed snapshot objects from disk (a simplified sketch of this hand-off is shown below). ### Why are the changes needed? These changes fix an issue where Spark Streaming fails with a RocksDB version ID mismatch error if there is an existing snapshot (not yet uploaded) for the RocksDB version being loaded. In this scenario, SST files from the existing snapshot are reused, resulting in a version ID mismatch error. In short, these changes: 1. Remove the shared fileMapping between the RocksDB maintenance thread and the task thread, simplifying the file mapping logic. After this change, the file mapping is only modified from the task thread. 2. Fix the issue where RocksDB SST files from the current snapshot (with the same version) are reused, resulting in a RocksDB version ID mismatch. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? 1. All existing testcases pass. 2. Added new testcases suggested in https://github.com/apache/spark/pull/47850/files, and ensured they pass with these changes. ### Was this patch authored or co-authored using generative AI tooling? No Closes #47875 from sahnib/master.
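For illustration only (not part of this patch): a minimal sketch of the snapshot hand-off described above, with stand-in names (`Snapshot`, `drainAndUploadLatest`). The task thread offers snapshots to a `ConcurrentLinkedQueue` at commit time; the maintenance thread drains the queue, frees superseded snapshots, and uploads only the most recent one. This is only the hand-off skeleton; the real `doMaintenance()` in the patch additionally tracks upload metrics and cleans up the local checkpoint directories of the snapshots it skips.

```scala
import java.util.concurrent.ConcurrentLinkedQueue

object SnapshotUploadSketch {
  // Stand-in for RocksDBSnapshot; the real object also carries the local
  // checkpoint directory and the DFS file mapping captured at commit().
  case class Snapshot(version: Long) {
    def close(): Unit = () // would delete the local checkpoint directory
  }

  // Producer side (task thread): snapshots are offered here at commit().
  val snapshotsToUploadQueue = new ConcurrentLinkedQueue[Snapshot]()

  // Consumer side (maintenance thread): drain the queue, free superseded
  // snapshots, and upload only the latest one.
  def drainAndUploadLatest(upload: Snapshot => Unit): Unit = {
    var latest: Snapshot = null
    var next = snapshotsToUploadQueue.poll() // returns null when the queue is empty
    while (next != null) {
      if (latest != null) latest.close() // older snapshot superseded by a newer one
      latest = next
      next = snapshotsToUploadQueue.poll()
    }
    if (latest != null) upload(latest)
  }
}
```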
Lead-authored-by: Bhuwan Sahni Co-authored-by: micheal-o Co-authored-by: Bhuwan Sahni Signed-off-by: Jungtaek Lim --- .../execution/streaming/state/RocksDB.scala | 336 ++++++++++++------ .../streaming/state/RocksDBFileManager.scala | 195 ++++------ .../streaming/state/RocksDBSuite.scala | 155 +++++--- 3 files changed, 405 insertions(+), 281 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index c7f8434e5345b..99f8e7b8f36e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File import java.util.Locale -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.Set +import java.util.UUID +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, TimeUnit} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import javax.annotation.concurrent.GuardedBy import scala.collection.{mutable, Map} -import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters.{ConcurrentMapHasAsScala, MapHasAsJava} import scala.ref.WeakReference import scala.util.Try @@ -51,7 +52,6 @@ case object RollbackStore extends RocksDBOpType("rollback_store") case object CloseStore extends RocksDBOpType("close_store") case object ReportStoreMetrics extends RocksDBOpType("report_store_metrics") case object StoreTaskCompletionListener extends RocksDBOpType("store_task_completion_listener") -case object StoreMaintenance extends RocksDBOpType("store_maintenance") /** * Class representing a RocksDB instance that checkpoints version of data to DFS. @@ -74,21 +74,22 @@ class RocksDB( loggingId: String = "", useColumnFamilies: Boolean = false) extends Logging { + import RocksDB._ + case class RocksDBSnapshot( checkpointDir: File, version: Long, numKeys: Long, - capturedFileMappings: RocksDBFileMappings, columnFamilyMapping: Map[String, Short], - maxColumnFamilyId: Short) { + maxColumnFamilyId: Short, + dfsFileSuffix: String, + fileMapping: Map[String, RocksDBSnapshotFile]) { def close(): Unit = { silentDeleteRecursively(checkpointDir, s"Free up local checkpoint of snapshot $version") } } - @volatile private var latestSnapshot: Option[RocksDBSnapshot] = None @volatile private var lastSnapshotVersion = 0L - private val oldSnapshots = new ListBuffer[RocksDBSnapshot] RocksDBLoader.loadLibrary() @@ -150,6 +151,7 @@ class RocksDB( hadoopConf, conf.compressionCodec, loggingId = loggingId) private val byteArrayPair = new ByteArrayPair() private val commitLatencyMs = new mutable.HashMap[String, Long]() + private val acquireLock = new Object @volatile private var db: NativeRocksDB = _ @@ -249,6 +251,16 @@ class RocksDB( } } + // Mapping of local SST files to DFS files for file reuse. + // This mapping should only be updated using the Task thread - at version load and commit time. + // If same mapping instance is updated from different threads, + // it will result in undefined behavior (and most likely incorrect mapping state). 
+ @GuardedBy("acquireLock") + private val rocksDBFileMapping: RocksDBFileMapping = new RocksDBFileMapping() + + // We send snapshots that needs to be uploaded by the maintenance thread to this queue + private val snapshotsToUploadQueue = new ConcurrentLinkedQueue[RocksDBSnapshot]() + /** * Load the given version of data in a native RocksDB instance. * Note that this will copy all the necessary file from DFS to local disk as needed, @@ -262,12 +274,15 @@ class RocksDB( try { if (loadedVersion != version) { closeDB(ignoreException = false) - // deep copy is needed to avoid race condition - // between maintenance and task threads - fileManager.copyFileMapping() val latestSnapshotVersion = fileManager.getLatestSnapshotVersion(version) - val metadata = fileManager.loadCheckpointFromDfs(latestSnapshotVersion, workingDir) + rocksDBFileMapping.currentVersion = latestSnapshotVersion + val metadata = fileManager.loadCheckpointFromDfs(latestSnapshotVersion, + workingDir, rocksDBFileMapping) loadedVersion = latestSnapshotVersion + + // reset the last snapshot version to the latest available snapshot version + lastSnapshotVersion = latestSnapshotVersion + // Initialize maxVersion upon successful load from DFS fileManager.setMaxSeenVersion(version) @@ -279,21 +294,7 @@ class RocksDB( metadata.maxColumnFamilyId.foreach { maxId => maxColumnFamilyId.set(maxId) } - // reset last snapshot version - if (lastSnapshotVersion > latestSnapshotVersion) { - // discard any newer snapshots - synchronized { - if (latestSnapshot.isDefined) { - oldSnapshots += latestSnapshot.get - latestSnapshot = None - } - } - } - - // reset the last snapshot version to the latest available snapshot version - lastSnapshotVersion = latestSnapshotVersion openDB() - numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) { // we don't track the total number of rows - discard the number being track -1L @@ -369,15 +370,11 @@ class RocksDB( */ private def replayFromCheckpoint(snapshotVersion: Long, endVersion: Long): Any = { closeDB() - val metadata = fileManager.loadCheckpointFromDfs(snapshotVersion, workingDir) + rocksDBFileMapping.currentVersion = snapshotVersion + val metadata = fileManager.loadCheckpointFromDfs(snapshotVersion, + workingDir, rocksDBFileMapping) loadedVersion = snapshotVersion - - // reset last snapshot version - if (lastSnapshotVersion > snapshotVersion) { - // discard any newer snapshots - lastSnapshotVersion = 0L - latestSnapshot = None - } + lastSnapshotVersion = snapshotVersion openDB() numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) { @@ -583,12 +580,13 @@ class RocksDB( def commit(): Long = { val newVersion = loadedVersion + 1 try { - logInfo(log"Flushing updates for ${MDC(LogKeys.VERSION_NUM, newVersion)}") var compactTimeMs = 0L var flushTimeMs = 0L var checkpointTimeMs = 0L + var snapshot: Option[RocksDBSnapshot] = None + if (shouldCreateSnapshot() || shouldForceSnapshot.get()) { // Need to flush the change to disk before creating a checkpoint // because rocksdb wal is disabled. @@ -620,19 +618,9 @@ class RocksDB( // inside the uploadSnapshot() called below. // If changelog checkpointing is enabled, snapshot will be uploaded asynchronously // during state store maintenance. 
- synchronized { - if (latestSnapshot.isDefined) { - oldSnapshots += latestSnapshot.get - } - latestSnapshot = Some( - RocksDBSnapshot(checkpointDir, - newVersion, - numKeysOnWritingVersion, - fileManager.captureFileMapReference(), - colFamilyNameToIdMap.asScala.toMap, - maxColumnFamilyId.get().toShort)) - lastSnapshotVersion = newVersion - } + snapshot = Some(createSnapshot(checkpointDir, newVersion, + colFamilyNameToIdMap.asScala.toMap, maxColumnFamilyId.get().toShort)) + lastSnapshotVersion = newVersion } } @@ -642,8 +630,16 @@ class RocksDB( // If we have changed the columnFamilyId mapping, we have set a new // snapshot and need to upload this to the DFS even if changelog checkpointing // is enabled. + var isUploaded = false if (shouldForceSnapshot.get()) { - uploadSnapshot() + assert(snapshot.isDefined) + fileManagerMetrics = uploadSnapshot( + snapshot.get, + fileManager, + rocksDBFileMapping.snapshotsPendingUpload, + loggingId + ) + isUploaded = true shouldForceSnapshot.set(false) } @@ -651,12 +647,21 @@ class RocksDB( try { assert(changelogWriter.isDefined) changelogWriter.foreach(_.commit()) + if (!isUploaded) { + snapshot.foreach(snapshotsToUploadQueue.offer) + } } finally { changelogWriter = None } } else { assert(changelogWriter.isEmpty) - uploadSnapshot() + assert(snapshot.isDefined) + fileManagerMetrics = uploadSnapshot( + snapshot.get, + fileManager, + rocksDBFileMapping.snapshotsPendingUpload, + loggingId + ) } } @@ -696,56 +701,6 @@ class RocksDB( } else true } - private def uploadSnapshot(): Unit = { - var oldSnapshotsImmutable: List[RocksDBSnapshot] = Nil - val localCheckpoint = synchronized { - val checkpoint = latestSnapshot - latestSnapshot = None - - // Convert mutable list buffer to immutable to prevent - // race condition with commit where old snapshot is added - oldSnapshotsImmutable = oldSnapshots.toList - oldSnapshots.clear() - - checkpoint - } - localCheckpoint match { - case Some( - RocksDBSnapshot( - localDir, - version, - numKeys, - capturedFileMappings, - columnFamilyMapping, - maxColumnFamilyId)) => - try { - val uploadTime = timeTakenMs { - fileManager.saveCheckpointToDfs( - localDir, - version, - numKeys, - capturedFileMappings, - Some(columnFamilyMapping.toMap), - Some(maxColumnFamilyId) - ) - fileManagerMetrics = fileManager.latestSaveCheckpointMetrics - } - logInfo(log"${MDC(LogKeys.LOG_ID, loggingId)}: Upload snapshot of version " + - log"${MDC(LogKeys.VERSION_NUM, version)}," + - log" time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms") - } finally { - localCheckpoint.foreach(_.close()) - - // Clean up old latestSnapshots - for (snapshot <- oldSnapshotsImmutable) { - snapshot.close() - } - - } - case _ => - } - } - /** * Drop uncommitted changes, and roll back to previous version. */ @@ -762,7 +717,26 @@ class RocksDB( def doMaintenance(): Unit = { if (enableChangelogCheckpointing) { - uploadSnapshot() + + var mostRecentSnapshot: Option[RocksDBSnapshot] = None + var snapshot = snapshotsToUploadQueue.poll() + + // We only want to upload the most recent snapshot and skip the previous ones. 
+ while (snapshot != null) { + logDebug(s"RocksDB Maintenance - polled snapshot ${snapshot.version}") + mostRecentSnapshot.foreach(_.close()) + mostRecentSnapshot = Some(snapshot) + snapshot = snapshotsToUploadQueue.poll() + } + + if (mostRecentSnapshot.isDefined) { + fileManagerMetrics = uploadSnapshot( + mostRecentSnapshot.get, + fileManager, + rocksDBFileMapping.snapshotsPendingUpload, + loggingId + ) + } } val cleanupTime = timeTakenMs { fileManager.deleteOldVersions(conf.minVersionsToRetain, conf.minVersionsToDelete) @@ -782,10 +756,13 @@ class RocksDB( flushOptions.close() rocksDbOptions.close() dbLogger.close() - synchronized { - latestSnapshot.foreach(_.close()) - latestSnapshot = None + + var snapshot = snapshotsToUploadQueue.poll() + while (snapshot != null) { + snapshot.close() + snapshot = snapshotsToUploadQueue.poll() } + silentDeleteRecursively(localRootDir, "closing RocksDB") } catch { case e: Exception => @@ -884,6 +861,18 @@ class RocksDB( rocksDBMetricsOpt } + private def createSnapshot( + checkpointDir: File, + version: Long, + columnFamilyMapping: Map[String, Short], + maxColumnFamilyId: Short): RocksDBSnapshot = { + val (dfsFileSuffix, immutableFileMapping) = rocksDBFileMapping.createSnapshotFileMapping( + fileManager, checkpointDir, version) + + RocksDBSnapshot(checkpointDir, version, numKeysOnWritingVersion, + columnFamilyMapping, maxColumnFamilyId, dfsFileSuffix, immutableFileMapping) + } + /** * Function to acquire RocksDB instance lock that allows for synchronized access to the state * store instance @@ -1005,10 +994,147 @@ class RocksDB( } } + override protected def logName: String = s"${super.logName} $loggingId" +} + +object RocksDB extends Logging { + + /** Upload the snapshot to DFS and remove it from snapshots pending */ + private def uploadSnapshot( + snapshot: RocksDB#RocksDBSnapshot, + fileManager: RocksDBFileManager, + snapshotsPendingUpload: Set[RocksDBVersionSnapshotInfo], + loggingId: String): RocksDBFileManagerMetrics = { + var fileManagerMetrics: RocksDBFileManagerMetrics = null + try { + val uploadTime = timeTakenMs { + fileManager.saveCheckpointToDfs(snapshot.checkpointDir, + snapshot.version, snapshot.numKeys, snapshot.fileMapping, + Some(snapshot.columnFamilyMapping), Some(snapshot.maxColumnFamilyId)) + fileManagerMetrics = fileManager.latestSaveCheckpointMetrics + + val snapshotInfo = RocksDBVersionSnapshotInfo(snapshot.version, snapshot.dfsFileSuffix) + // We are only removing the uploaded snapshot info from the pending set, + // to let the file mapping (i.e. query threads) know that the snapshot (i.e. and its files) + // have been uploaded to DFS. We don't touch the file mapping here to avoid corrupting it. + snapshotsPendingUpload.remove(snapshotInfo) + } + logInfo(log"${MDC(LogKeys.LOG_ID, loggingId)}: Upload snapshot of version " + + log"${MDC(LogKeys.VERSION_NUM, snapshot.version)}," + + log" time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms") + } finally { + snapshot.close() + } + + fileManagerMetrics + } + /** Records the duration of running `body` for the next query progress update. */ - protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 + private def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 +} - override protected def logName: String = s"${super.logName} $loggingId" +// uniquely identifies a Snapshot. 
Multiple snapshots created for same version will +// use a different dfsFilesUUID, and hence will have different RocksDBVersionSnapshotInfo +case class RocksDBVersionSnapshotInfo(version: Long, dfsFilesUUID: String) + +// Encapsulates a RocksDB immutable file, and the information whether it has been previously +// uploaded to DFS. Already uploaded files can be skipped during SST file upload. +case class RocksDBSnapshotFile(immutableFile: RocksDBImmutableFile, isUploaded: Boolean) + +// Encapsulates the mapping of local SST files to DFS files. This mapping prevents +// re-uploading the same SST file multiple times to DFS, saving I/O and reducing snapshot +// upload time. During version load, if a DFS file is already present on local file system, +// it will be reused. +// This mapping should only be updated using the Task thread - at version load and commit time. +// If same mapping instance is updated from different threads, it will result in undefined behavior +// (and most likely incorrect mapping state). +class RocksDBFileMapping { + + // Maps a local SST file to the DFS version and DFS file. + private val localFileMappings: mutable.Map[String, (Long, RocksDBImmutableFile)] = + mutable.HashMap[String, (Long, RocksDBImmutableFile)]() + + // Keeps track of all snapshots which have not been uploaded yet. This prevents Spark + // from reusing SST files which have not been yet persisted to DFS, + val snapshotsPendingUpload: Set[RocksDBVersionSnapshotInfo] = ConcurrentHashMap.newKeySet() + + // Current State Store version which has been loaded. + var currentVersion: Long = 0 + + // If the local file (with localFileName) has already been persisted to DFS, returns the + // DFS file, else returns None. + // If the currently mapped DFS file was committed in a newer version (or was generated + // in a version which has not been uploaded to DFS yet), the mapped DFS file is ignored (because + // it cannot be reused in this version). In this scenario, the local mapping to this DFS file + // will be cleared, and function will return None. + def getDfsFile( + fileManager: RocksDBFileManager, + localFileName: String): Option[RocksDBImmutableFile] = { + localFileMappings.get(localFileName).map { case (dfsFileCommitVersion, dfsFile) => + val dfsFileSuffix = fileManager.dfsFileSuffix(dfsFile) + val versionSnapshotInfo = RocksDBVersionSnapshotInfo(dfsFileCommitVersion, dfsFileSuffix) + if (dfsFileCommitVersion >= currentVersion || + snapshotsPendingUpload.contains(versionSnapshotInfo)) { + // the mapped dfs file cannot be used, delete from mapping + remove(localFileName) + None + } else { + Some(dfsFile) + } + }.getOrElse(None) + } + + private def mapToDfsFile( + localFileName: String, + dfsFile: RocksDBImmutableFile, + version: Long): Unit = { + localFileMappings.put(localFileName, (version, dfsFile)) + } + + def remove(localFileName: String): Unit = { + localFileMappings.remove(localFileName) + } + + def mapToDfsFileForCurrentVersion(localFileName: String, dfsFile: RocksDBImmutableFile): Unit = { + localFileMappings.put(localFileName, (currentVersion, dfsFile)) + } + + private def syncWithLocalState(localFiles: Seq[File]): Unit = { + val localFileNames = localFiles.map(_.getName).toSet + val deletedFiles = localFileMappings.keys.filterNot(localFileNames.contains) + + deletedFiles.foreach(localFileMappings.remove) + } + + // Generates the DFS file names for local Immutable files in checkpoint directory, and + // returns the mapping from local fileName in checkpoint directory to generated DFS file. 
+ // If the DFS file has been previously uploaded - the snapshot file isUploaded flag is set + // to true. + def createSnapshotFileMapping( + fileManager: RocksDBFileManager, + checkpointDir: File, + version: Long): (String, Map[String, RocksDBSnapshotFile]) = { + val (localImmutableFiles, _) = fileManager.listRocksDBFiles(checkpointDir) + // UUID used to prefix files uploaded to DFS as part of commit + val dfsFilesSuffix = UUID.randomUUID().toString + val snapshotFileMapping = localImmutableFiles.map { f => + val localFileName = f.getName + val existingDfsFile = getDfsFile(fileManager, localFileName) + val dfsFile = existingDfsFile.getOrElse { + val newDfsFileName = fileManager.newDFSFileName(localFileName, dfsFilesSuffix) + val newDfsFile = RocksDBImmutableFile(localFileName, newDfsFileName, sizeBytes = f.length()) + mapToDfsFile(localFileName, newDfsFile, version) + newDfsFile + } + localFileName -> RocksDBSnapshotFile(dfsFile, existingDfsFile.isDefined) + }.toMap + + syncWithLocalState(localImmutableFiles) + + val rocksDBSnapshotInfo = RocksDBVersionSnapshotInfo(version, dfsFilesSuffix) + snapshotsPendingUpload.add(rocksDBSnapshotInfo) + (dfsFilesSuffix, snapshotFileMapping) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala index 350a5797978b3..e503ea1737c0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala @@ -24,8 +24,7 @@ import java.util.UUID import java.util.concurrent.ConcurrentHashMap import java.util.zip.{ZipEntry, ZipOutputStream} -import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.collection.{mutable, Map} import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} @@ -146,33 +145,13 @@ class RocksDBFileManager( private def codec = CompressionCodec.createCodec(sparkConf, codecName) + // This is set when a version is loaded/committed. Hence only set by a task thread. private var maxSeenVersion: Option[Long] = None + // This is set during deletion of old versions. Hence only set by a maintenance thread. 
private var minSeenVersion = 1L @volatile private var rootDirChecked: Boolean = false - @volatile private var fileMappings = RocksDBFileMappings( - new ConcurrentHashMap[Long, Seq[RocksDBImmutableFile]], - new ConcurrentHashMap[String, RocksDBImmutableFile] - ) - - /** - * Make a deep copy of versionToRocksDBFiles and localFilesToDfsFiles to avoid - * current task thread from overwriting the file mapping whenever background maintenance - * thread attempts to upload a snapshot - */ - def copyFileMapping() : Unit = { - val newVersionToRocksDBFiles = new ConcurrentHashMap[Long, Seq[RocksDBImmutableFile]] - val newLocalFilesToDfsFiles = new ConcurrentHashMap[String, RocksDBImmutableFile] - - newVersionToRocksDBFiles.putAll(fileMappings.versionToRocksDBFiles) - newLocalFilesToDfsFiles.putAll(fileMappings.localFilesToDfsFiles) - - fileMappings = RocksDBFileMappings(newVersionToRocksDBFiles, newLocalFilesToDfsFiles) - } - - def captureFileMapReference(): RocksDBFileMappings = { - fileMappings - } + private val versionToRocksDBFiles = new ConcurrentHashMap[Long, Seq[RocksDBImmutableFile]] private def getChangelogVersion(useColumnFamilies: Boolean): Short = { val changelogVersion: Short = if (useColumnFamilies) { @@ -249,13 +228,13 @@ class RocksDBFileManager( checkpointDir: File, version: Long, numKeys: Long, - capturedFileMappings: RocksDBFileMappings, + fileMapping: Map[String, RocksDBSnapshotFile], columnFamilyMapping: Option[Map[String, Short]] = None, maxColumnFamilyId: Option[Short] = None): Unit = { logFilesInDir(checkpointDir, log"Saving checkpoint files " + log"for version ${MDC(LogKeys.VERSION_NUM, version)}") val (localImmutableFiles, localOtherFiles) = listRocksDBFiles(checkpointDir) - val rocksDBFiles = saveImmutableFilesToDfs(version, localImmutableFiles, capturedFileMappings) + val rocksDBFiles = saveImmutableFilesToDfs(version, localImmutableFiles, fileMapping) val metadata = RocksDBCheckpointMetadata( rocksDBFiles, numKeys, columnFamilyMapping, maxColumnFamilyId) val metadataFile = localMetadataFile(checkpointDir) @@ -286,15 +265,17 @@ class RocksDBFileManager( * ensures that only the exact files generated during checkpointing will be present in the * local directory. */ - def loadCheckpointFromDfs(version: Long, localDir: File): RocksDBCheckpointMetadata = { + def loadCheckpointFromDfs( + version: Long, + localDir: File, + rocksDBFileMapping: RocksDBFileMapping): RocksDBCheckpointMetadata = { logInfo(log"Loading checkpoint files for version ${MDC(LogKeys.VERSION_NUM, version)}") // The unique ids of SST files are checked when opening a rocksdb instance. The SST files // in larger versions can't be reused even if they have the same size and name because // they belong to another rocksdb instance. 
- fileMappings.versionToRocksDBFiles.keySet().removeIf(_ >= version) + versionToRocksDBFiles.keySet().removeIf(_ >= version) val metadata = if (version == 0) { if (localDir.exists) Utils.deleteRecursively(localDir) - fileMappings.localFilesToDfsFiles.clear() localDir.mkdirs() RocksDBCheckpointMetadata(Seq.empty, 0) } else { @@ -307,8 +288,8 @@ class RocksDBFileManager( val metadata = RocksDBCheckpointMetadata.readFromFile(metadataFile) logInfo(log"Read metadata for version ${MDC(LogKeys.VERSION_NUM, version)}:\n" + log"${MDC(LogKeys.METADATA_JSON, metadata.prettyJson)}") - loadImmutableFilesFromDfs(metadata.immutableFiles, localDir) - fileMappings.versionToRocksDBFiles.put(version, metadata.immutableFiles) + loadImmutableFilesFromDfs(metadata.immutableFiles, localDir, rocksDBFileMapping) + versionToRocksDBFiles.put(version, metadata.immutableFiles) metadataFile.delete() metadata } @@ -342,11 +323,11 @@ class RocksDBFileManager( if (fm.exists(path)) { val files = fm.list(path).map(_.getPath) val changelogFileVersions = files - .filter(onlyChangelogFiles.accept(_)) + .filter(onlyChangelogFiles.accept) .map(_.getName.stripSuffix(".changelog")) .map(_.toLong) val snapshotFileVersions = files - .filter(onlyZipFiles.accept(_)) + .filter(onlyZipFiles.accept) .map(_.getName.stripSuffix(".zip")) .map(_.toLong) val versions = changelogFileVersions ++ snapshotFileVersions @@ -516,9 +497,9 @@ class RocksDBFileManager( // Resolve RocksDB files for all the versions and find the max version each file is used val fileToMaxUsedVersion = new mutable.HashMap[String, Long] sortedSnapshotVersions.foreach { version => - val files = Option(fileMappings.versionToRocksDBFiles.get(version)).getOrElse { + val files = Option(versionToRocksDBFiles.get(version)).getOrElse { val newResolvedFiles = getImmutableFilesFromVersionZip(version) - fileMappings.versionToRocksDBFiles.put(version, newResolvedFiles) + versionToRocksDBFiles.put(version, newResolvedFiles) newResolvedFiles } files.foreach(f => fileToMaxUsedVersion(f.dfsFileName) = @@ -565,7 +546,7 @@ class RocksDBFileManager( val versionFile = dfsBatchZipFile(version) try { fm.delete(versionFile) - fileMappings.versionToRocksDBFiles.remove(version) + versionToRocksDBFiles.remove(version) logDebug(s"Deleted version $version") } catch { case e: Exception => @@ -591,7 +572,7 @@ class RocksDBFileManager( private def saveImmutableFilesToDfs( version: Long, localFiles: Seq[File], - capturedFileMappings: RocksDBFileMappings): Seq[RocksDBImmutableFile] = { + fileMappings: Map[String, RocksDBSnapshotFile]): Seq[RocksDBImmutableFile] = { // Get the immutable files used in previous versions, as some of those uploaded files can be // reused for this version logInfo(log"Saving RocksDB files to DFS for ${MDC(LogKeys.VERSION_NUM, version)}") @@ -601,49 +582,36 @@ class RocksDBFileManager( var filesReused = 0L val immutableFiles = localFiles.map { localFile => - val existingDfsFile = - capturedFileMappings.localFilesToDfsFiles.asScala.get(localFile.getName) - if (existingDfsFile.isDefined && existingDfsFile.get.sizeBytes == localFile.length()) { - val dfsFile = existingDfsFile.get - filesReused += 1 + val dfsFileMapping = fileMappings.get(localFile.getName) + assert(dfsFileMapping.isDefined) + val dfsFile = dfsFileMapping.get.immutableFile + val existsInDfs = dfsFileMapping.get.isUploaded + + if (existsInDfs) { logInfo(log"reusing file ${MDC(LogKeys.DFS_FILE, dfsFile)} for " + log"${MDC(LogKeys.FILE_NAME, localFile)}") - RocksDBImmutableFile(localFile.getName, dfsFile.dfsFileName, 
dfsFile.sizeBytes) + filesReused += 1 } else { - val localFileName = localFile.getName - val dfsFileName = newDFSFileName(localFileName) - val dfsFile = dfsFilePath(dfsFileName) // Note: The implementation of copyFromLocalFile() closes the output stream when there is // any exception while copying. So this may generate partial files on DFS. But that is // okay because until the main [version].zip file is written, those partial files are // not going to be used at all. Eventually these files should get cleared. fs.copyFromLocalFile( - new Path(localFile.getAbsoluteFile.toURI), dfsFile) + new Path(localFile.getAbsoluteFile.toURI), dfsFilePath(dfsFile.dfsFileName)) val localFileSize = localFile.length() logInfo(log"Copied ${MDC(LogKeys.FILE_NAME, localFile)} to " + log"${MDC(LogKeys.DFS_FILE, dfsFile)} - ${MDC(LogKeys.NUM_BYTES, localFileSize)} bytes") filesCopied += 1 bytesCopied += localFileSize - - val immutableDfsFile = RocksDBImmutableFile(localFile.getName, dfsFileName, localFileSize) - capturedFileMappings.localFilesToDfsFiles.put(localFileName, immutableDfsFile) - - immutableDfsFile } + + dfsFile } logInfo(log"Copied ${MDC(LogKeys.NUM_FILES_COPIED, filesCopied)} files " + log"(${MDC(LogKeys.NUM_BYTES, bytesCopied)} bytes) from local to" + log" DFS for version ${MDC(LogKeys.VERSION_NUM, version)}. " + log"${MDC(LogKeys.NUM_FILES_REUSED, filesReused)} files reused without copying.") - capturedFileMappings.versionToRocksDBFiles.put(version, immutableFiles) - - // Cleanup locally deleted files from the localFilesToDfsFiles map - // Locally, SST Files can be deleted due to RocksDB compaction. These files need - // to be removed rom the localFilesToDfsFiles map to ensure that if a older version - // regenerates them and overwrites the version.zip, SST files from the conflicting - // version (previously committed) are not reused. - removeLocallyDeletedSSTFilesFromDfsMapping(localFiles) - + versionToRocksDBFiles.put(version, immutableFiles) saveCheckpointMetrics = RocksDBFileManagerMetrics( bytesCopied = bytesCopied, filesCopied = filesCopied, @@ -658,43 +626,37 @@ class RocksDBFileManager( * necessary and non-existing files are copied from DFS. */ private def loadImmutableFilesFromDfs( - immutableFiles: Seq[RocksDBImmutableFile], localDir: File): Unit = { + immutableFiles: Seq[RocksDBImmutableFile], + localDir: File, + rocksDBFileMapping: RocksDBFileMapping): Unit = { val requiredFileNameToFileDetails = immutableFiles.map(f => f.localFileName -> f).toMap val localImmutableFiles = listRocksDBFiles(localDir)._1 - // Cleanup locally deleted files from the localFilesToDfsFiles map - // Locally, SST Files can be deleted due to RocksDB compaction. These files need - // to be removed rom the localFilesToDfsFiles map to ensure that if a older version - // regenerates them and overwrites the version.zip, SST files from the conflicting - // version (previously committed) are not reused. 
- removeLocallyDeletedSSTFilesFromDfsMapping(localImmutableFiles) - // Delete unnecessary local immutable files - localImmutableFiles - .foreach { existingFile => - val existingFileSize = existingFile.length() - val requiredFile = requiredFileNameToFileDetails.get(existingFile.getName) - val prevDfsFile = fileMappings.localFilesToDfsFiles.asScala.get(existingFile.getName) - val isSameFile = if (requiredFile.isDefined && prevDfsFile.isDefined) { - requiredFile.get.dfsFileName == prevDfsFile.get.dfsFileName && - existingFile.length() == requiredFile.get.sizeBytes - } else { - false - } + localImmutableFiles.foreach { existingFile => + val existingFileSize = existingFile.length() + val requiredFile = requiredFileNameToFileDetails.get(existingFile.getName) + val prevDfsFile = rocksDBFileMapping.getDfsFile(this, existingFile.getName) + val isSameFile = if (requiredFile.isDefined && prevDfsFile.isDefined) { + requiredFile.get.dfsFileName == prevDfsFile.get.dfsFileName && + existingFile.length() == requiredFile.get.sizeBytes + } else { + false + } - if (!isSameFile) { - existingFile.delete() - fileMappings.localFilesToDfsFiles.remove(existingFile.getName) - logInfo(log"Deleted local file ${MDC(LogKeys.FILE_NAME, existingFile)} " + - log"with size ${MDC(LogKeys.NUM_BYTES, existingFileSize)} mapped" + - log" to previous dfsFile ${MDC(LogKeys.DFS_FILE, prevDfsFile.getOrElse("null"))}") - } else { - logInfo(log"reusing ${MDC(LogKeys.DFS_FILE, prevDfsFile)} present at " + - log"${MDC(LogKeys.EXISTING_FILE, existingFile)} " + - log"for ${MDC(LogKeys.FILE_NAME, requiredFile)}") - } + if (!isSameFile) { + rocksDBFileMapping.remove(existingFile.getName) + existingFile.delete() + logInfo(log"Deleted local file ${MDC(LogKeys.FILE_NAME, existingFile)} " + + log"with size ${MDC(LogKeys.NUM_BYTES, existingFileSize)} mapped" + + log" to previous dfsFile ${MDC(LogKeys.DFS_FILE, prevDfsFile.getOrElse("null"))}") + } else { + logInfo(log"reusing ${MDC(LogKeys.DFS_FILE, prevDfsFile)} present at " + + log"${MDC(LogKeys.EXISTING_FILE, existingFile)} " + + log"for ${MDC(LogKeys.FILE_NAME, requiredFile)}") } + } var filesCopied = 0L var bytesCopied = 0L @@ -717,7 +679,7 @@ class RocksDBFileManager( } filesCopied += 1 bytesCopied += localFileSize - fileMappings.localFilesToDfsFiles.put(localFileName, file) + rocksDBFileMapping.mapToDfsFileForCurrentVersion(localFileName, file) logInfo(log"Copied ${MDC(LogKeys.DFS_FILE, dfsFile)} to " + log"${MDC(LogKeys.FILE_NAME, localFile)} - " + log"${MDC(LogKeys.NUM_BYTES, localFileSize)} bytes") @@ -735,19 +697,6 @@ class RocksDBFileManager( filesReused = filesReused) } - private def removeLocallyDeletedSSTFilesFromDfsMapping(localFiles: Seq[File]): Unit = { - // clean up deleted SST files from the localFilesToDfsFiles Map - val currentLocalFiles = localFiles.map(_.getName).toSet - val mappingsToClean = fileMappings.localFilesToDfsFiles.asScala - .keys - .filterNot(currentLocalFiles.contains) - - mappingsToClean.foreach { f => - logInfo(log"cleaning ${MDC(LogKeys.FILE_NAME, f)} from the localFilesToDfsFiles map") - fileMappings.localFilesToDfsFiles.remove(f) - } - } - /** Get the SST files required for a version from the version zip file in DFS */ private def getImmutableFilesFromVersionZip(version: Long): Seq[RocksDBImmutableFile] = { Utils.deleteRecursively(localTempDir) @@ -811,6 +760,19 @@ class RocksDBFileManager( s"$baseName-${UUID.randomUUID}.$extension" } + def newDFSFileName(localFileName: String, dfsFileSuffix: String): String = { + val baseName = 
FilenameUtils.getBaseName(localFileName) + val extension = FilenameUtils.getExtension(localFileName) + s"$baseName-$dfsFileSuffix.$extension" + } + + def dfsFileSuffix(immutableFile: RocksDBImmutableFile): String = { + val suffixStart = immutableFile.dfsFileName.indexOf('-') + val suffixEnd = immutableFile.dfsFileName.indexOf('.') + + immutableFile.dfsFileName.substring(suffixStart + 1, suffixEnd) + } + private def dfsBatchZipFile(version: Long): Path = new Path(s"$dfsRootDir/$version.zip") // We use changelog suffix intentionally so that we can tell the difference from changelog file of // HDFSBackedStateStore which is named version.delta. @@ -841,7 +803,7 @@ class RocksDBFileManager( /** * List all the RocksDB files that need be synced or recovered. */ - private def listRocksDBFiles(localDir: File): (Seq[File], Seq[File]) = { + def listRocksDBFiles(localDir: File): (Seq[File], Seq[File]) = { val topLevelFiles = localDir.listFiles.filter(!_.isDirectory) val archivedLogFiles = Option(new File(localDir, LOG_FILES_LOCAL_SUBDIR).listFiles()) @@ -854,20 +816,6 @@ class RocksDBFileManager( } } -/** - * Track file mappings in RocksDB across local and remote directories - * @param versionToRocksDBFiles Mapping of RocksDB files used across versions for maintenance - * @param localFilesToDfsFiles Mapping of the exact Dfs file used to create a local SST file - * The reason localFilesToDfsFiles is a separate map because versionToRocksDBFiles can contain - * multiple similar SST files to a particular local file (for example 1.sst can map to 1-UUID1.sst - * in v1 and 1-UUID2.sst in v2). We need to capture the exact file used to ensure Version ID - * compatibility across SST files and RocksDB manifest. - */ - -case class RocksDBFileMappings( - versionToRocksDBFiles: ConcurrentHashMap[Long, Seq[RocksDBImmutableFile]], - localFilesToDfsFiles: ConcurrentHashMap[String, RocksDBImmutableFile]) - /** * Metrics regarding RocksDB file sync between local and DFS. 
*/ @@ -1067,7 +1015,10 @@ object RocksDBImmutableFile { val LOG_FILES_DFS_SUBDIR = "logs" val LOG_FILES_LOCAL_SUBDIR = "archive" - def apply(localFileName: String, dfsFileName: String, sizeBytes: Long): RocksDBImmutableFile = { + def apply( + localFileName: String, + dfsFileName: String, + sizeBytes: Long): RocksDBImmutableFile = { if (isSstFile(localFileName)) { RocksDBSstFile(localFileName, dfsFileName, sizeBytes) } else if (isLogFile(localFileName)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 9fcd2001cce50..8fde216c14411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -802,10 +802,12 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared val cpFiles = Seq() generateFiles(verificationDir, cpFiles) assert(!dfsRootDir.exists()) - saveCheckpointFiles(fileManager, cpFiles, version = 1, numKeys = -1) + val fileMapping = new RocksDBFileMapping + saveCheckpointFiles(fileManager, cpFiles, version = 1, + numKeys = -1, fileMapping) // The dfs root dir is created even with unknown number of keys assert(dfsRootDir.exists()) - loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 1, Nil, -1) + loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 1, Nil, -1, fileMapping) } finally { Utils.deleteRecursively(dfsRootDir) } @@ -896,6 +898,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared // that checkpoint the same version of state val fileManager = new RocksDBFileManager( dfsRootDir, Utils.createTempDir(), new Configuration) + val rocksDBFileMapping = new RocksDBFileMapping() val fileManager_ = new RocksDBFileManager( dfsRootDir, Utils.createTempDir(), new Configuration) val sstDir = s"$dfsRootDir/SSTs" @@ -912,7 +915,10 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00001.log" -> 1000, "archive/00002.log" -> 2000 ) - saveCheckpointFiles(fileManager, cpFiles1, version = 1, numKeys = 101) + + rocksDBFileMapping.currentVersion = 1 + saveCheckpointFiles(fileManager, cpFiles1, version = 1, + numKeys = 101, rocksDBFileMapping) assert(fileManager.getLatestVersion() === 1) assert(numRemoteSSTFiles == 2) // 2 sst files copied assert(numRemoteLogFiles == 2) @@ -926,7 +932,8 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00002.log" -> 1000, "archive/00003.log" -> 2000 ) - saveCheckpointFiles(fileManager_, cpFiles1_, version = 1, numKeys = 101) + saveCheckpointFiles(fileManager_, cpFiles1_, version = 1, + numKeys = 101, new RocksDBFileMapping()) assert(fileManager_.getLatestVersion() === 1) assert(numRemoteSSTFiles == 4) assert(numRemoteLogFiles == 4) @@ -945,7 +952,8 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00004.log" -> 1000, "archive/00005.log" -> 2000 ) - saveCheckpointFiles(fileManager_, cpFiles2, version = 2, numKeys = 121) + saveCheckpointFiles(fileManager_, cpFiles2, + version = 2, numKeys = 121, new RocksDBFileMapping()) fileManager_.deleteOldVersions(1) assert(numRemoteSSTFiles <= 4) // delete files recorded in 1.zip assert(numRemoteLogFiles <= 5) // delete files recorded in 1.zip and orphan 00001.log @@ -959,7 +967,8 @@ class RocksDBSuite extends 
AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00006.log" -> 1000, "archive/00007.log" -> 2000 ) - saveCheckpointFiles(fileManager_, cpFiles3, version = 3, numKeys = 131) + saveCheckpointFiles(fileManager_, cpFiles3, + version = 3, numKeys = 131, new RocksDBFileMapping()) assert(fileManager_.getLatestVersion() === 3) fileManager_.deleteOldVersions(1) assert(numRemoteSSTFiles == 1) @@ -996,7 +1005,9 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00001.log" -> 1000, "archive/00002.log" -> 2000 ) - saveCheckpointFiles(fileManager, cpFiles1, version = 1, numKeys = 101) + val rocksDBFileMapping = new RocksDBFileMapping() + saveCheckpointFiles(fileManager, cpFiles1, + version = 1, numKeys = 101, rocksDBFileMapping) fileManager.deleteOldVersions(1) // Should not delete orphan files even when they are older than all existing files // when there is only 1 version. @@ -1013,7 +1024,8 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00003.log" -> 1000, "archive/00004.log" -> 2000 ) - saveCheckpointFiles(fileManager, cpFiles2, version = 2, numKeys = 101) + saveCheckpointFiles(fileManager, cpFiles2, + version = 2, numKeys = 101, rocksDBFileMapping) assert(numRemoteSSTFiles == 5) assert(numRemoteLogFiles == 5) fileManager.deleteOldVersions(1) @@ -1034,13 +1046,14 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared def numRemoteSSTFiles: Int = listFiles(sstDir).length val logDir = s"$dfsRootDir/logs" def numRemoteLogFiles: Int = listFiles(logDir).length + val fileMapping = new RocksDBFileMapping // Verify behavior before any saved checkpoints assert(fileManager.getLatestVersion() === 0) // Try to load incorrect versions intercept[FileNotFoundException] { - fileManager.loadCheckpointFromDfs(1, Utils.createTempDir()) + fileManager.loadCheckpointFromDfs(1, Utils.createTempDir(), fileMapping) } // Save a version of checkpoint files @@ -1052,7 +1065,8 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00001.log" -> 1000, "archive/00002.log" -> 2000 ) - saveCheckpointFiles(fileManager, cpFiles1, version = 1, numKeys = 101) + saveCheckpointFiles(fileManager, cpFiles1, + version = 1, numKeys = 101, fileMapping) assert(fileManager.getLatestVersion() === 1) assert(numRemoteSSTFiles == 2) // 2 sst files copied assert(numRemoteLogFiles == 2) // 2 log files copied @@ -1067,12 +1081,16 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "00005.log" -> 101, "archive/00007.log" -> 101 )) - loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 1, cpFiles1, 101) + + // as we are loading version 1 again, the previously committed 1,zip and + // SST files would not be reused. + loadAndVerifyCheckpointFiles(fileManager, verificationDir, + version = 1, cpFiles1, 101, fileMapping) // Save SAME version again with different checkpoint files and load back again to verify // whether files were overwritten. 
val cpFiles1_ = Seq( - "sst-file1.sst" -> 10, // same SST file as before, this should get reused + "sst-file1.sst" -> 10, // same SST file as before, but will be uploaded again "sst-file2.sst" -> 25, // new SST file with same name as before, but different length "sst-file3.sst" -> 30, // new SST file "other-file1" -> 100, // same non-SST file as before, should not get copied @@ -1082,33 +1100,51 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00002.log" -> 2500, // new log file with same name as before, but different length "archive/00003.log" -> 3000 // new log file ) - saveCheckpointFiles(fileManager, cpFiles1_, version = 1, numKeys = 1001) - assert(numRemoteSSTFiles === 4, "shouldn't copy same files again") // 2 old + 2 new SST files - assert(numRemoteLogFiles === 4, "shouldn't copy same files again") // 2 old + 2 new log files - loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 1, cpFiles1_, 1001) + + // upload version 1 again, new checkpoint will be created and SST files from + // previously committed version 1 will not be reused. + saveCheckpointFiles(fileManager, cpFiles1_, + version = 1, numKeys = 1001, fileMapping) + assert(numRemoteSSTFiles === 5, "shouldn't reuse old version 1 SST files" + + " while uploading version 1 again") // 2 old + 3 new SST files + assert(numRemoteLogFiles === 5, "shouldn't reuse old version 1 log files" + + " while uploading version 1 again") // 2 old + 3 new log files + + // verify checkpoint state is correct + loadAndVerifyCheckpointFiles(fileManager, verificationDir, + version = 1, cpFiles1_, 1001, fileMapping) // Save another version and verify val cpFiles2 = Seq( - "sst-file4.sst" -> 40, + "sst-file1.sst" -> 10, // same SST file as version 1, should be reused + "sst-file2.sst" -> 25, // same SST file as version 1, should be reused + "sst-file3.sst" -> 30, // same SST file as version 1, should be reused + "sst-file4.sst" -> 40, // new sst file, should be uploaded "other-file4" -> 400, "archive/00004.log" -> 4000 ) - saveCheckpointFiles(fileManager, cpFiles2, version = 2, numKeys = 1501) - assert(numRemoteSSTFiles === 5) // 1 new file over earlier 4 files - assert(numRemoteLogFiles === 5) // 1 new file over earlier 4 files - loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 2, cpFiles2, 1501) + + saveCheckpointFiles(fileManager, cpFiles2, + version = 2, numKeys = 1501, fileMapping) + assert(numRemoteSSTFiles === 6) // 1 new file over earlier 5 files + assert(numRemoteLogFiles === 6) // 1 new file over earlier 6 files + loadAndVerifyCheckpointFiles(fileManager, verificationDir, + version = 2, cpFiles2, 1501, fileMapping) // Loading an older version should work - loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 1, cpFiles1_, 1001) + loadAndVerifyCheckpointFiles(fileManager, verificationDir, + version = 1, cpFiles1_, 1001, fileMapping) // Loading incorrect version should fail intercept[FileNotFoundException] { - loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 3, Nil, 1001) + loadAndVerifyCheckpointFiles(fileManager, verificationDir, + version = 3, Nil, 1001, fileMapping) } // Loading 0 should delete all files require(verificationDir.list().length > 0) - loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 0, Nil, 0) + loadAndVerifyCheckpointFiles(fileManager, verificationDir, + version = 0, Nil, 0, fileMapping) } } @@ -1125,7 +1161,8 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared 
val cpFiles = Seq("sst-file1.sst" -> 10, "sst-file2.sst" -> 20, "other-file1" -> 100) CreateAtomicTestManager.shouldFailInCreateAtomic = true intercept[IOException] { - saveCheckpointFiles(fileManager, cpFiles, version = 1, numKeys = 101) + saveCheckpointFiles(fileManager, cpFiles, + version = 1, numKeys = 101, new RocksDBFileMapping()) } assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) } @@ -1779,37 +1816,39 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "validate successful RocksDB load when metadata file is not overwritten") { val fmClass = "org.apache.spark.sql.execution.streaming.state." + "NoOverwriteFileSystemBasedCheckpointFileManager" - withTempDir { dir => - val conf = dbConf.copy(minDeltasForSnapshot = 0) // create snapshot every commit - val hadoopConf = new Configuration() - hadoopConf.set(STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, fmClass) - - val remoteDir = dir.getCanonicalPath - withDB(remoteDir, conf = conf, hadoopConf = hadoopConf) { db => - db.load(0) - db.put("a", "1") - db.commit() + Seq(Some(fmClass), None).foreach { fm => + withTempDir { dir => + val conf = dbConf.copy(minDeltasForSnapshot = 0) // create snapshot every commit + val hadoopConf = new Configuration() + fm.foreach(value => + hadoopConf.set(STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, value)) + val remoteDir = dir.getCanonicalPath + withDB(remoteDir, conf = conf, hadoopConf = hadoopConf) { db => + db.load(0) + db.put("a", "1") + db.commit() - // load previous version, and recreate the snapshot - db.load(0) - db.put("a", "1") + // load previous version, will recreate snapshot on commit + db.load(0) + db.put("a", "1") - // do not upload version 1 snapshot created previously - db.doMaintenance() - assert(snapshotVersionsPresent(remoteDir) == Seq.empty) + // upload version 1 snapshot created previously + db.doMaintenance() + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) - db.commit() // create snapshot again + db.commit() // create snapshot again - // load version 1 - should succeed - withDB(remoteDir, version = 1, conf = conf, hadoopConf = hadoopConf) { db => - } + // load version 1 - should succeed + withDB(remoteDir, version = 1, conf = conf, hadoopConf = hadoopConf) { db => + } - // upload recently created snapshot - db.doMaintenance() - assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + // upload recently created snapshot + db.doMaintenance() + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) - // load version 1 again - should succeed - withDB(remoteDir, version = 1, conf = conf, hadoopConf = hadoopConf) { db => + // load version 1 again - should succeed + withDB(remoteDir, version = 1, conf = conf, hadoopConf = hadoopConf) { db => + } } } } @@ -2241,14 +2280,20 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared fileManager: RocksDBFileManager, fileToLengths: Seq[(String, Int)], version: Int, - numKeys: Int): Unit = { + numKeys: Int, + fileMapping: RocksDBFileMapping): Unit = { val checkpointDir = Utils.createTempDir().getAbsolutePath // local dir to create checkpoints generateFiles(checkpointDir, fileToLengths) + fileMapping.currentVersion = version - 1 + val (dfsFileSuffix, immutableFileMapping) = fileMapping.createSnapshotFileMapping( + fileManager, checkpointDir, version) fileManager.saveCheckpointToDfs( checkpointDir, version, numKeys, - fileManager.captureFileMapReference()) + immutableFileMapping) + val snapshotInfo = RocksDBVersionSnapshotInfo(version, dfsFileSuffix) + 
fileMapping.snapshotsPendingUpload.remove(snapshotInfo) } def loadAndVerifyCheckpointFiles( @@ -2256,8 +2301,10 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared verificationDir: String, version: Int, expectedFiles: Seq[(String, Int)], - expectedNumKeys: Int): Unit = { - val metadata = fileManager.loadCheckpointFromDfs(version, verificationDir) + expectedNumKeys: Int, + fileMapping: RocksDBFileMapping): Unit = { + val metadata = fileManager.loadCheckpointFromDfs(version, + verificationDir, fileMapping) val filesAndLengths = listFiles(verificationDir).map(f => f.getName -> f.length).toSet ++ listFiles(verificationDir + "/archive").map(f => s"archive/${f.getName}" -> f.length()).toSet

From f96a6f8e3a0b2863ad809b9154acbb1960c7c6e5 Mon Sep 17 00:00:00 2001 From: subham611 Date: Thu, 17 Oct 2024 16:01:43 +0900 Subject: [PATCH 019/108] [SPARK-49259][SS] Size based partition creation during kafka read MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit

### What changes were proposed in this pull request?
Adds support for size-based partition creation during Kafka read.

### Why are the changes needed?
Currently, Spark Structured Streaming provides the `minPartitions` config to create more partitions than Kafka has. This is helpful for increasing parallelism, but the value cannot be changed dynamically. It would be better to dynamically increase the number of Spark partitions based on input size: if the input size is high, create more partitions. With this change, we can dynamically create more partitions to handle varying loads.

### Does this PR introduce _any_ user-facing change?
An additional parameter (maxRecordsPerPartition) will be accepted by the Kafka source provider.

### How was this patch tested?
Added unit tests.

### Was this patch authored or co-authored using generative AI tooling? No

Closes #47927 from SubhamSinghal/SPARK-49259-structured-streaming-size-based-partition-creation-kafka. Authored-by: subham611 Signed-off-by: Jungtaek Lim --- .../kafka010/KafkaOffsetRangeCalculator.scala | 133 +++++++++++------- .../sql/kafka010/KafkaOffsetReaderAdmin.scala | 12 +- .../kafka010/KafkaOffsetReaderConsumer.scala | 12 +- .../sql/kafka010/KafkaSourceProvider.scala | 9 ++ .../KafkaOffsetRangeCalculatorSuite.scala | 77 ++++++++++ .../structured-streaming-kafka-integration.md | 14 ++ 6 files changed, 201 insertions(+), 56 deletions(-) diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala index 4c0620a35cc21..ae3c50f82e2d5 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -21,22 +21,26 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.util.CaseInsensitiveStringMap - /** - * Class to calculate offset ranges to process based on the from and until offsets, and - * the configured `minPartitions`. + * Class to calculate offset ranges to process based on the from and until offsets, and the + * configured `minPartitions` and `maxRecordsPerPartition`.
*/ -private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { +private[kafka010] class KafkaOffsetRangeCalculator( + val minPartitions: Option[Int], + val maxRecordsPerPartition: Option[Long]) { require(minPartitions.isEmpty || minPartitions.get > 0) + require(maxRecordsPerPartition.isEmpty || maxRecordsPerPartition.get > 0) /** - * Calculate the offset ranges that we are going to process this batch. If `minPartitions` - * is not set or is set less than or equal the number of `topicPartitions` that we're going to - * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If - * `minPartitions` is set higher than the number of our `topicPartitions`, then we will split up - * the read tasks of the skewed partitions to multiple Spark tasks. - * The number of Spark tasks will be *approximately* `minPartitions`. It can be less or more - * depending on rounding errors or Kafka partitions that didn't receive any new data. + * Calculate the offset ranges that we are going to process this batch. If `minPartitions` is + * not set or is set less than or equal the number of `topicPartitions` that we're going to + * consume and, `maxRecordsPerPartition` is not set then we fall back to a 1-1 mapping of Spark + * tasks to Kafka partitions. If `maxRecordsPerPartition` is set, then we will split up read + * task to multiple tasks as per `maxRecordsPerPartition` value. If `minPartitions` is set + * higher than the number of our `topicPartitions`, then we will split up the read tasks of the + * skewed partitions to multiple Spark tasks. The number of Spark tasks will be *approximately* + * max of `(recordsPerPartition/maxRecordsPerPartition)` and `minPartitions`. It can be less or + * more depending on rounding errors or Kafka partitions that didn't receive any new data. * * Empty (`KafkaOffsetRange.size == 0`) or invalid (`KafkaOffsetRange.size < 0`) ranges will be * dropped. @@ -47,51 +51,81 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int val offsetRanges = ranges.filter(_.size > 0) // If minPartitions not set or there are enough partitions to satisfy minPartitions - if (minPartitions.isEmpty || offsetRanges.size >= minPartitions.get) { + // and maxRecordsPerPartition is empty + if ((minPartitions.isEmpty || offsetRanges.size >= minPartitions.get) + && maxRecordsPerPartition.isEmpty) { // Assign preferred executor locations to each range such that the same topic-partition is // preferentially read from the same executor and the KafkaConsumer can be reused. offsetRanges.map { range => range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations)) } } else { + val dividedOffsetRanges = if (maxRecordsPerPartition.isDefined) { + val maxRecords = maxRecordsPerPartition.get + offsetRanges + .flatMap { range => + val size = range.size + // number of partitions to divvy up this topic partition to + val parts = math.ceil(size.toDouble / maxRecords).toInt + getDividedPartition(parts, range) + } + .filter(_.size > 0) + } else { + offsetRanges + } - // Splits offset ranges with relatively large amount of data to smaller ones. - val totalSize = offsetRanges.map(_.size).sum + if (minPartitions.isDefined && minPartitions.get > dividedOffsetRanges.size) { + // Splits offset ranges with relatively large amount of data to smaller ones. + val totalSize = dividedOffsetRanges.map(_.size).sum + + // First distinguish between any small (i.e. unsplit) ranges and large (i.e. 
split) ranges, + // in order to exclude the contents of unsplit ranges from the proportional math applied to + // split ranges + val unsplitRanges = dividedOffsetRanges.filter { range => + getPartCount(range.size, totalSize, minPartitions.get) == 1 + } - // First distinguish between any small (i.e. unsplit) ranges and large (i.e. split) ranges, - // in order to exclude the contents of unsplit ranges from the proportional math applied to - // split ranges - val unsplitRanges = offsetRanges.filter { range => - getPartCount(range.size, totalSize, minPartitions.get) == 1 + val unsplitRangeTotalSize = unsplitRanges.map(_.size).sum + val splitRangeTotalSize = totalSize - unsplitRangeTotalSize + val unsplitRangeTopicPartitions = unsplitRanges.map(_.topicPartition).toSet + val splitRangeMinPartitions = math.max(minPartitions.get - unsplitRanges.size, 1) + + // Now we can apply the main calculation logic + dividedOffsetRanges + .flatMap { range => + val tp = range.topicPartition + val size = range.size + // number of partitions to divvy up this topic partition to + val parts = if (unsplitRangeTopicPartitions.contains(tp)) { + 1 + } else { + getPartCount(size, splitRangeTotalSize, splitRangeMinPartitions) + } + getDividedPartition(parts, range) + } + .filter(_.size > 0) + } else { + dividedOffsetRanges } + } + } - val unsplitRangeTotalSize = unsplitRanges.map(_.size).sum - val splitRangeTotalSize = totalSize - unsplitRangeTotalSize - val unsplitRangeTopicPartitions = unsplitRanges.map(_.topicPartition).toSet - val splitRangeMinPartitions = math.max(minPartitions.get - unsplitRanges.size, 1) - - // Now we can apply the main calculation logic - offsetRanges.flatMap { range => - val tp = range.topicPartition - val size = range.size - // number of partitions to divvy up this topic partition to - val parts = if (unsplitRangeTopicPartitions.contains(tp)) { - 1 - } else { - getPartCount(size, splitRangeTotalSize, splitRangeMinPartitions) - } - var remaining = size - var startOffset = range.fromOffset - (0 until parts).map { part => - // Fine to do integer division. Last partition will consume all the round off errors - val thisPartition = remaining / (parts - part) - remaining -= thisPartition - val endOffset = math.min(startOffset + thisPartition, range.untilOffset) - val offsetRange = KafkaOffsetRange(tp, startOffset, endOffset, None) - startOffset = endOffset - offsetRange - } - }.filter(_.size > 0) + private def getDividedPartition( + parts: Int, + offsetRange: KafkaOffsetRange): IndexedSeq[KafkaOffsetRange] = { + var remaining = offsetRange.size + var startOffset = offsetRange.fromOffset + val tp = offsetRange.topicPartition + val untilOffset = offsetRange.untilOffset + + (0 until parts).map { part => + // Fine to do integer division. 
Last partition will consume all the round off errors + val thisPartition = remaining / (parts - part) + remaining -= thisPartition + val endOffset = math.min(startOffset + thisPartition, untilOffset) + val offsetRange = KafkaOffsetRange(tp, startOffset, endOffset, None) + startOffset = endOffset + offsetRange } } @@ -114,9 +148,12 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int private[kafka010] object KafkaOffsetRangeCalculator { def apply(options: CaseInsensitiveStringMap): KafkaOffsetRangeCalculator = { - val optionalValue = Option(options.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) + val minPartition = Option(options.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) .map(_.toInt) - new KafkaOffsetRangeCalculator(optionalValue) + val maxRecordsPerPartition = + Option(options.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY)) + .map(_.toLong) + new KafkaOffsetRangeCalculator(minPartition, maxRecordsPerPartition) } } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala index bb4f14686f976..0bdd931028aef 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala @@ -99,14 +99,18 @@ private[kafka010] class KafkaOffsetReaderAdmin( */ private val minPartitions = readerOptions.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY).map(_.toInt) + private val maxRecordsPerPartition = + readerOptions.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY).map(_.toLong) - private val rangeCalculator = new KafkaOffsetRangeCalculator(minPartitions) + private val rangeCalculator = + new KafkaOffsetRangeCalculator(minPartitions, maxRecordsPerPartition) /** * Whether we should divide Kafka TopicPartitions with a lot of data into smaller Spark tasks. 
*/ - private def shouldDivvyUpLargePartitions(numTopicPartitions: Int): Boolean = { - minPartitions.map(_ > numTopicPartitions).getOrElse(false) + private def shouldDivvyUpLargePartitions(offsetRanges: Seq[KafkaOffsetRange]): Boolean = { + minPartitions.map(_ > offsetRanges.size).getOrElse(false) || + offsetRanges.exists(_.size > maxRecordsPerPartition.getOrElse(Long.MaxValue)) } override def toString(): String = consumerStrategy.toString @@ -397,7 +401,7 @@ private[kafka010] class KafkaOffsetReaderAdmin( KafkaOffsetRange(tp, fromOffset, untilOffset, None) }.toSeq - if (shouldDivvyUpLargePartitions(offsetRangesBase.size)) { + if (shouldDivvyUpLargePartitions(offsetRangesBase)) { val fromOffsetsMap = offsetRangesBase.map(range => (range.topicPartition, range.fromOffset)).toMap val untilOffsetsMap = diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala index fa53d6373176e..f7530dcba6b85 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala @@ -98,8 +98,11 @@ private[kafka010] class KafkaOffsetReaderConsumer( */ private val minPartitions = readerOptions.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY).map(_.toInt) + private val maxRecordsPerPartition = + readerOptions.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY).map(_.toLong) - private val rangeCalculator = new KafkaOffsetRangeCalculator(minPartitions) + private val rangeCalculator = + new KafkaOffsetRangeCalculator(minPartitions, maxRecordsPerPartition) private[kafka010] val offsetFetchAttemptIntervalMs = readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong @@ -107,8 +110,9 @@ private[kafka010] class KafkaOffsetReaderConsumer( /** * Whether we should divide Kafka TopicPartitions with a lot of data into smaller Spark tasks. 
*/ - private def shouldDivvyUpLargePartitions(numTopicPartitions: Int): Boolean = { - minPartitions.map(_ > numTopicPartitions).getOrElse(false) + private def shouldDivvyUpLargePartitions(offsetRanges: Seq[KafkaOffsetRange]): Boolean = { + minPartitions.map(_ > offsetRanges.size).getOrElse(false) || + offsetRanges.exists(_.size > maxRecordsPerPartition.getOrElse(Long.MaxValue)) } private def nextGroupId(): String = { @@ -446,7 +450,7 @@ private[kafka010] class KafkaOffsetReaderConsumer( KafkaOffsetRange(tp, fromOffset, untilOffset, None) }.toSeq - if (shouldDivvyUpLargePartitions(offsetRangesBase.size)) { + if (shouldDivvyUpLargePartitions(offsetRangesBase)) { val fromOffsetsMap = offsetRangesBase.map(range => (range.topicPartition, range.fromOffset)).toMap val untilOffsetsMap = diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index e1fdbfb183c39..4cb9fa8df8052 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -271,6 +271,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive") } + if (params.contains(MAX_RECORDS_PER_PARTITION_OPTION_KEY)) { + val p = params(MAX_RECORDS_PER_PARTITION_OPTION_KEY).toLong + if (p <= 0) { + throw new IllegalArgumentException( + s"$MAX_RECORDS_PER_PARTITION_OPTION_KEY must be positive") + } + } + // Validate user-specified Kafka options if (params.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { @@ -557,6 +565,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val ENDING_TIMESTAMP_OPTION_KEY = "endingtimestamp" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" private[kafka010] val MIN_PARTITIONS_OPTION_KEY = "minpartitions" + private[kafka010] val MAX_RECORDS_PER_PARTITION_OPTION_KEY = "maxrecordsperpartition" private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxoffsetspertrigger" private[kafka010] val MIN_OFFSET_PER_TRIGGER = "minoffsetspertrigger" private[kafka010] val MAX_TRIGGER_DELAY = "maxtriggerdelay" diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala index 89ab0902f4d6f..516aee6ad537d 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -34,6 +34,30 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { } } + def testWithMaxRecordsPerPartition(name: String, maxRecordsPerPartition: Long)( + f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new CaseInsensitiveStringMap( + Map("maxRecordsPerPartition" -> maxRecordsPerPartition.toString).asJava) + test(s"with maxRecordsPerPartition = $maxRecordsPerPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + + def testWithMinPartitionsAndMaxRecordsPerPartition( + name: String, + minPartitions: Int, + maxRecordsPerPartition: Long)(f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new CaseInsensitiveStringMap( + Map( + "minPartitions" -> 
minPartitions.toString, + "maxRecordsPerPartition" -> maxRecordsPerPartition.toString).asJava) + test( + s"with minPartitions = $minPartitions " + + s"and maxRecordsPerPartition = $maxRecordsPerPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + test("with no minPartition: N TopicPartitions to N offset ranges") { val calc = KafkaOffsetRangeCalculator(CaseInsensitiveStringMap.empty()) assert( @@ -253,6 +277,59 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { KafkaOffsetRange(tp3, 7500, 10000, None))) } + testWithMaxRecordsPerPartition("SPARK-49259: 1 TopicPartition to N offset ranges", 4) { calc => + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 5))) == Seq(KafkaOffsetRange(tp1, 1, 5, None))) + + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 2))) == Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 6)), executorLocations = Seq("location")) == + Seq(KafkaOffsetRange(tp1, 1, 3, None), KafkaOffsetRange(tp1, 3, 6, None)) + ) // location pref not set when maxRecordsPerPartition is set + } + + testWithMaxRecordsPerPartition("SPARK-49259: N TopicPartition to N offset ranges", 20) { calc => + assert( + calc.getRanges( + Seq( + KafkaOffsetRange(tp1, 1, 40), + KafkaOffsetRange(tp2, 1, 50), + KafkaOffsetRange(tp3, 1, 60))) == + Seq( + KafkaOffsetRange(tp1, 1, 20, None), + KafkaOffsetRange(tp1, 20, 40, None), + KafkaOffsetRange(tp2, 1, 17, None), + KafkaOffsetRange(tp2, 17, 33, None), + KafkaOffsetRange(tp2, 33, 50, None), + KafkaOffsetRange(tp3, 1, 20, None), + KafkaOffsetRange(tp3, 20, 40, None), + KafkaOffsetRange(tp3, 40, 60, None))) + } + + testWithMinPartitionsAndMaxRecordsPerPartition( + "SPARK-49259: 1 TopicPartition with low minPartitions value", + 1, + 20) { calc => + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 40))) == + Seq(KafkaOffsetRange(tp1, 1, 20, None), KafkaOffsetRange(tp1, 20, 40, None))) + } + + testWithMinPartitionsAndMaxRecordsPerPartition( + "SPARK-49259: 1 TopicPartition with high minPartitions value", + 4, + 20) { calc => + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 40))) == + Seq( + KafkaOffsetRange(tp1, 1, 10, None), + KafkaOffsetRange(tp1, 10, 20, None), + KafkaOffsetRange(tp1, 20, 30, None), + KafkaOffsetRange(tp1, 30, 40, None))) + } + private val tp1 = new TopicPartition("t1", 1) private val tp2 = new TopicPartition("t2", 1) private val tp3 = new TopicPartition("t3", 1) diff --git a/docs/streaming/structured-streaming-kafka-integration.md b/docs/streaming/structured-streaming-kafka-integration.md index 37846216fc758..a8f2bcdeb9bc1 100644 --- a/docs/streaming/structured-streaming-kafka-integration.md +++ b/docs/streaming/structured-streaming-kafka-integration.md @@ -518,6 +518,20 @@ The following configurations are optional: number of Spark tasks will be approximately minPartitions. It can be less or more depending on rounding errors or Kafka partitions that didn't receive any new data. + + maxRecordsPerPartition + long + none + streaming and batch + Limit maximum number of records present in a partition. + By default, Spark has a 1-1 mapping of topicPartitions to Spark partitions consuming from Kafka. + If you set this option, Spark will divvy up Kafka partitions to smaller pieces so that each partition + has upto maxRecordsPerPartition records. When both minPartitions and + maxRecordsPerPartition are set, number of partitions will be approximately + max of (recordsPerPartition / maxRecordsPerPartition) and minPartitions. 
In such case spark + will divvy up partitions based on maxRecordsPerPartition and if the final partition count is less than + minPartitions it will divvy up partitions again based on minPartitions. + groupIdPrefix string From 2469a6f29c0e6f59e928f4f4e7aa9709768d4757 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 17 Oct 2024 16:49:00 +0900 Subject: [PATCH 020/108] [SPARK-49495][INFRA][FOLLOW-UP] Disable GitHub Pages workflow in forked repository ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/commit/4f640e2485d24088345b3f2d894c696ef29e2923 that disables GitHub Pages workflow in forked repository ### Why are the changes needed? To automatically disable GitHub packages workflow in developers' forked repository. We can manually disable them too but this is a bit easier. ### Does this PR introduce _any_ user-facing change? No, dev-only. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48515 from HyukjinKwon/SPARK-49495-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .github/workflows/pages.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml index f78f7895a183f..8729012c2b8d2 100644 --- a/.github/workflows/pages.yml +++ b/.github/workflows/pages.yml @@ -40,6 +40,7 @@ jobs: env: SPARK_TESTING: 1 # Reduce some noise in the logs RELEASE_VERSION: 'In-Progress' + if: github.repository == 'apache/spark' steps: - name: Checkout Spark repository uses: actions/checkout@v4 From 371a09a292efe126b71f4086bb2d6378be6a6a9f Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 17 Oct 2024 10:47:22 +0200 Subject: [PATCH 021/108] [SPARK-49998][SQL] Integrate `_LEGACY_ERROR_TEMP_1252` into `EXPECT_TABLE_NOT_VIEW` ### What changes were proposed in this pull request? This PR proposes to Integrate `_LEGACY_ERROR_TEMP_1252` into `EXPECT_TABLE_NOT_VIEW` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48510 from itholic/SPARK-49998. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 5 ----- .../spark/sql/errors/QueryCompilationErrors.scala | 10 +++++++--- .../org/apache/spark/sql/execution/command/ddl.scala | 4 +++- .../apache/spark/sql/hive/execution/HiveDDLSuite.scala | 4 ++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 3e4848658f14a..ccf5d123fb1e7 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -6520,11 +6520,6 @@ " is not allowed on since its partition metadata is not stored in the Hive metastore. To import this information into the metastore, run `msck repair table `." ] }, - "_LEGACY_ERROR_TEMP_1252" : { - "message" : [ - "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead." - ] - }, "_LEGACY_ERROR_TEMP_1253" : { "message" : [ "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 431983214c482..40da52fa7c3d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2823,10 +2823,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "tableName" -> tableName)) } - def cannotAlterViewWithAlterTableError(): Throwable = { + def cannotAlterViewWithAlterTableError(viewName: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1252", - messageParameters = Map.empty) + errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + messageParameters = Map( + "operation" -> "ALTER TABLE", + "viewName" -> toSQLId(viewName) + ) + ) } def cannotAlterTableWithAlterViewError(): Throwable = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 814e56b204f9e..0b3469d3eb52d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -1008,7 +1008,9 @@ object DDLUtils extends Logging { if (!catalog.isTempView(tableMetadata.identifier)) { tableMetadata.tableType match { case CatalogTableType.VIEW if !isView => - throw QueryCompilationErrors.cannotAlterViewWithAlterTableError() + throw QueryCompilationErrors.cannotAlterViewWithAlterTableError( + viewName = tableMetadata.identifier.table + ) case o if o != CatalogTableType.VIEW && isView => throw QueryCompilationErrors.cannotAlterTableWithAlterViewError() case _ => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 69d54a746b55d..94501d4c1c087 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -843,8 +843,8 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName RENAME TO $newViewName") }, - condition = "_LEGACY_ERROR_TEMP_1252", - parameters = Map.empty + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + parameters = Map("operation" -> "ALTER TABLE", "viewName" -> "`view1`") ) checkError( From 0b7cc3709e6bafac3cb2a2add04771d566891575 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 17 Oct 2024 11:03:07 +0200 Subject: [PATCH 022/108] [SPARK-49971][SQL] Assign proper error condition for _LEGACY_ERROR_TEMP_1097 ### What changes were proposed in this pull request? This PR proposes to assign proper error condition & sqlstate for` _LEGACY_ERROR_TEMP_1097` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48471 from itholic/LEGACY_1097. 
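For context, a minimal PySpark sketch (not part of this patch) of the user-facing scenario the renamed `INVALID_CORRUPT_RECORD_TYPE` condition covers; the schema, column name and input path are illustrative placeholders, and a running `spark` session is assumed:
```python
from pyspark.sql.types import StructType, StructField, IntegerType

# The corrupt-record column is declared as INT instead of a nullable STRING,
# so Spark rejects the schema with an AnalysisException that, after this patch,
# carries INVALID_CORRUPT_RECORD_TYPE instead of _LEGACY_ERROR_TEMP_1097.
schema = StructType([
    StructField("i", IntegerType(), True),
    StructField("_corrupt_record", IntegerType(), True),
])

(spark.read
    .schema(schema)
    .option("columnNameOfCorruptRecord", "_corrupt_record")
    .json("/tmp/corrupt_records.json")  # placeholder path
    .collect())
```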
Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 11 ++++++----- .../spark/sql/catalyst/expressions/ExprUtils.scala | 3 ++- .../spark/sql/errors/QueryCompilationErrors.scala | 7 ++++--- .../catalyst/expressions/CsvExpressionsSuite.scala | 6 +++--- .../catalyst/expressions/JsonExpressionsSuite.scala | 6 +++--- .../catalyst/expressions/XmlExpressionsSuite.scala | 8 ++++---- .../sql/execution/datasources/csv/CSVSuite.scala | 5 +++-- .../sql/execution/datasources/json/JsonSuite.scala | 10 ++++++---- 8 files changed, 31 insertions(+), 25 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index ccf5d123fb1e7..2f30fa2c70f6c 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2150,6 +2150,12 @@ }, "sqlState" : "22022" }, + "INVALID_CORRUPT_RECORD_TYPE" : { + "message" : [ + "The column for corrupt records must have the nullable STRING type, but got ." + ], + "sqlState" : "42804" + }, "INVALID_CURSOR" : { "message" : [ "The cursor is invalid." @@ -6041,11 +6047,6 @@ "Column statistics serialization is not supported for column of data type: ." ] }, - "_LEGACY_ERROR_TEMP_1097" : { - "message" : [ - "The field for corrupt records must be string type and nullable." - ] - }, "_LEGACY_ERROR_TEMP_1098" : { "message" : [ "DataType '' is not supported by ." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 38b927f5bbf38..9c617b51df62f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -83,7 +83,8 @@ object ExprUtils extends EvalHelper with QueryErrorsBase { schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => val f = schema(corruptFieldIndex) if (!f.dataType.isInstanceOf[StringType] || !f.nullable) { - throw QueryCompilationErrors.invalidFieldTypeForCorruptRecordError() + throw QueryCompilationErrors.invalidFieldTypeForCorruptRecordError( + columnNameOfCorruptRecord, f.dataType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 40da52fa7c3d8..5e08d463b9e07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1327,10 +1327,11 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map.empty) } - def invalidFieldTypeForCorruptRecordError(): Throwable = { + def invalidFieldTypeForCorruptRecordError(columnName: String, actualType: DataType): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1097", - messageParameters = Map.empty) + errorClass = "INVALID_CORRUPT_RECORD_TYPE", + messageParameters = Map( + "columnName" -> toSQLId(columnName), "actualType" -> toSQLType(actualType))) } def dataTypeUnsupportedByClassError(x: DataType, className: String): Throwable = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 249975f9c0d4c..81dd8242c600b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -233,13 +233,13 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P } test("verify corrupt column") { - checkExceptionInExpression[AnalysisException]( + checkErrorInExpression[AnalysisException]( CsvToStructs( schema = StructType.fromDDL("i int, _unparsed boolean"), options = Map("columnNameOfCorruptRecord" -> "_unparsed"), child = Literal.create("a"), - timeZoneId = UTC_OPT), - expectedErrMsg = "The field for corrupt records must be string type and nullable") + timeZoneId = UTC_OPT), null, "INVALID_CORRUPT_RECORD_TYPE", + Map("columnName" -> "`_unparsed`", "actualType" -> "\"BOOLEAN\"")) } test("from/to csv with intervals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index adb39fcd568c9..0afaf4ec097c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -791,13 +791,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with } test("verify corrupt column") { - checkExceptionInExpression[AnalysisException]( + checkErrorInExpression[AnalysisException]( JsonToStructs( schema = StructType.fromDDL("i int, _unparsed boolean"), options = Map("columnNameOfCorruptRecord" -> "_unparsed"), child = Literal.create("""{"i":"a"}"""), - timeZoneId = UTC_OPT), - expectedErrMsg = "The field for corrupt records must be string type and nullable") + timeZoneId = UTC_OPT), null, "INVALID_CORRUPT_RECORD_TYPE", + Map("columnName" -> "`_unparsed`", "actualType" -> "\"BOOLEAN\"")) } def decimalInput(langTag: String): (Decimal, String) = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/XmlExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/XmlExpressionsSuite.scala index 66baf6b1430fa..4f38cd0630f2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/XmlExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/XmlExpressionsSuite.scala @@ -391,13 +391,13 @@ class XmlExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P } test("verify corrupt column") { - checkExceptionInExpression[AnalysisException]( - XmlToStructs( + checkErrorInExpression[AnalysisException]( + JsonToStructs( schema = StructType.fromDDL("i int, _unparsed boolean"), options = Map("columnNameOfCorruptRecord" -> "_unparsed"), child = Literal.create("""{"i":"a"}"""), - timeZoneId = UTC_OPT), - expectedErrMsg = "The field for corrupt records must be string type and nullable") + timeZoneId = UTC_OPT), null, "INVALID_CORRUPT_RECORD_TYPE", + Map("columnName" -> "`_unparsed`", "actualType" -> "\"BOOLEAN\"")) } def decimalInput(langTag: String): (Decimal, String) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 
422ae02a18322..7cacd8ea2dc50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1510,8 +1510,9 @@ abstract class CSVSuite .csv(testFile(valueMalformedFile)) .collect() }, - condition = "_LEGACY_ERROR_TEMP_1097", - parameters = Map.empty + condition = "INVALID_CORRUPT_RECORD_TYPE", + parameters = Map( + "columnName" -> toSQLId(columnNameOfCorruptRecord), "actualType" -> "\"INT\"") ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index aea95f0af117a..a5b5fe3bb5d41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2117,8 +2117,9 @@ abstract class JsonSuite .schema(schema) .json(corruptRecords) }, - condition = "_LEGACY_ERROR_TEMP_1097", - parameters = Map.empty + condition = "INVALID_CORRUPT_RECORD_TYPE", + parameters = Map( + "columnName" -> toSQLId(columnNameOfCorruptRecord), "actualType" -> "\"INT\"") ) // We use `PERMISSIVE` mode by default if invalid string is given. @@ -2134,8 +2135,9 @@ abstract class JsonSuite .json(path) .collect() }, - condition = "_LEGACY_ERROR_TEMP_1097", - parameters = Map.empty + condition = "INVALID_CORRUPT_RECORD_TYPE", + parameters = Map( + "columnName" -> toSQLId(columnNameOfCorruptRecord), "actualType" -> "\"INT\"") ) } } From 0dd6a2a464418507d04df8c6b8d58795aa381ea5 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 17 Oct 2024 11:13:29 +0200 Subject: [PATCH 023/108] [SPARK-49997][SQL] Integrate `_LEGACY_ERROR_TEMP_2165` into `MALFORMED_RECORD_IN_PARSING` ### What changes were proposed in this pull request? This PR proposes to Integrate `_LEGACY_ERROR_TEMP_2165` into `MALFORMED_RECORD_IN_PARSING` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48508 from itholic/SPARK-49997. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../main/resources/error/error-conditions.json | 5 ----- .../sql/catalyst/json/JsonInferSchema.scala | 3 ++- .../spark/sql/catalyst/xml/XmlInferSchema.scala | 3 ++- .../spark/sql/errors/QueryExecutionErrors.scala | 5 +++-- .../execution/datasources/json/JsonSuite.scala | 6 ++++-- .../execution/datasources/xml/XmlSuite.scala | 17 +++++++++++------ 6 files changed, 22 insertions(+), 17 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 2f30fa2c70f6c..fdf3cf7ccbeb3 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -7258,11 +7258,6 @@ "Initial type must be an , a or a ." ] }, - "_LEGACY_ERROR_TEMP_2165" : { - "message" : [ - "Malformed records are detected in schema inference. Parse Mode: ." - ] - }, "_LEGACY_ERROR_TEMP_2166" : { "message" : [ "Malformed JSON." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 9c291634401ee..b509c55ed6a3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -69,7 +69,8 @@ class JsonInferSchema(options: JSONOptions) extends Serializable with Logging { case DropMalformedMode => None case FailFastMode => - throw QueryExecutionErrors.malformedRecordsDetectedInSchemaInferenceError(e) + throw QueryExecutionErrors.malformedRecordsDetectedInSchemaInferenceError( + e, columnNameOfCorruptRecord) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index 4640f86d5997a..848e6ff45c5a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -79,7 +79,8 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) case DropMalformedMode => None case FailFastMode => - throw QueryExecutionErrors.malformedRecordsDetectedInSchemaInferenceError(e) + throw QueryExecutionErrors.malformedRecordsDetectedInSchemaInferenceError( + e, columnNameOfCorruptRecord) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index edc1b909292df..26ed25ba90167 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1398,10 +1398,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "mapType" -> MapType.simpleString)) } - def malformedRecordsDetectedInSchemaInferenceError(e: Throwable): Throwable = { + def malformedRecordsDetectedInSchemaInferenceError(e: Throwable, badRecord: String): Throwable = { new SparkException( - errorClass = "_LEGACY_ERROR_TEMP_2165", + errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", messageParameters = Map( + "badRecord" -> badRecord, "failFastMode" -> FailFastMode.name), cause = e) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index a5b5fe3bb5d41..06183596a54ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1072,8 +1072,10 @@ abstract class JsonSuite .option("mode", "FAILFAST") .json(corruptRecords) }, - condition = "_LEGACY_ERROR_TEMP_2165", - parameters = Map("failFastMode" -> "FAILFAST") + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map( + "badRecord" -> "_corrupt_record", + "failFastMode" -> "FAILFAST") ) checkError( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 059e4aadef2bd..fe910c21cb0c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -255,8 +255,10 @@ class XmlSuite .option("mode", FailFastMode.name) .xml(inputFile) }, - condition = "_LEGACY_ERROR_TEMP_2165", - parameters = Map("failFastMode" -> "FAILFAST") + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map( + "badRecord" -> "_corrupt_record", + "failFastMode" -> "FAILFAST") ) val exceptionInParsing = intercept[SparkException] { spark.read @@ -288,8 +290,10 @@ class XmlSuite .option("mode", FailFastMode.name) .xml(inputFile) }, - condition = "_LEGACY_ERROR_TEMP_2165", - parameters = Map("failFastMode" -> "FAILFAST")) + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map( + "badRecord" -> "_corrupt_record", + "failFastMode" -> "FAILFAST")) val exceptionInParsing = intercept[SparkException] { spark.read .schema("_id string") @@ -1328,9 +1332,10 @@ class XmlSuite spark.sql(s"""SELECT schema_of_xml('1', map('mode', 'FAILFAST'))""") .collect() }, - condition = "_LEGACY_ERROR_TEMP_2165", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( - "failFastMode" -> FailFastMode.name) + "badRecord" -> "_corrupt_record", + "failFastMode" -> "FAILFAST") ) } From b078c0d6e2adf7eb0ee7d4742a6c52864440226e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 17 Oct 2024 20:24:09 +0800 Subject: [PATCH 024/108] [SPARK-50009][INFRA] Exclude pandas/resource/testing from `pyspark-core` module ### What changes were proposed in this pull request? Exclude pandas/resource/testing from `pyspark-core` module ### Why are the changes needed? avoid unnecessary tests, e.g. in https://github.com/apache/spark/pull/48516, a pyspark-pandas only change trigger `spark-core` and then all the pyspark tests. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manually check like: ``` In [6]: re.match("python/(?!pyspark/(ml|mllib|sql|streaming))", "python/pyspark/pandas/plots") Out[6]: In [7]: re.match("python/(?!pyspark/(ml|mllib|sql|streaming|pandas))", "python/pyspark/pandas/plots") ``` ### Was this patch authored or co-authored using generative AI tooling? no Closes #48518 from zhengruifeng/infra_pyspark_core. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- dev/sparktestsupport/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d2c000b702a64..92b7d9aa25c07 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -427,7 +427,7 @@ def __hash__(self): pyspark_core = Module( name="pyspark-core", dependencies=[core], - source_file_regexes=["python/(?!pyspark/(ml|mllib|sql|streaming))"], + source_file_regexes=["python/(?!pyspark/(ml|mllib|sql|streaming|pandas|resource|testing))"], python_test_goals=[ # doctests "pyspark.conf", From f1f04b097d2ff8850426640c77dd43a53f449ad6 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Thu, 17 Oct 2024 07:27:40 -0700 Subject: [PATCH 025/108] [SPARK-49988][BUILD] Remove unused Hadoop dependency management ### What changes were proposed in this pull request? Remove unused vanilla hadoop dependency(and transitive deps) management, i.e. `hadoop-client`, `xerces:xercesImpl`, and inline deps defined in `hadoop3` because it's the only supported hadoop profile. ### Why are the changes needed? Simplify pom.xml. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
Pass CI and verified runtime jars are not affected by running`dev/test-dependencies.sh`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48491 from pan3793/SPARK-49988. Authored-by: Cheng Pan Signed-off-by: Dongjoon Hyun --- LICENSE-binary | 1 - NOTICE-binary | 21 -- pom.xml | 297 ------------------ project/SparkBuild.scala | 1 - .../kubernetes/integration-tests/pom.xml | 20 +- 5 files changed, 6 insertions(+), 334 deletions(-) diff --git a/LICENSE-binary b/LICENSE-binary index 89826482d363a..40d28fbe71e6b 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -402,7 +402,6 @@ org.xerial.snappy:snappy-java org.yaml:snakeyaml oro:oro stax:stax-api -xerces:xercesImpl core/src/main/java/org/apache/spark/util/collection/TimSort.java core/src/main/resources/org/apache/spark/ui/static/bootstrap* diff --git a/NOTICE-binary b/NOTICE-binary index c4cfe0e9f8b31..3f36596b9d6d6 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -448,27 +448,6 @@ which has the following notices: * Alec Wysoker * Performance and memory usage improvement -The binary distribution of this product bundles binaries of -Xerces2 Java Parser 2.9.1, -which has the following notices: - * ========================================================================= - == NOTICE file corresponding to section 4(d) of the Apache License, == - == Version 2.0, in this case for the Apache Xerces Java distribution. == - ========================================================================= - - Apache Xerces Java - Copyright 1999-2007 The Apache Software Foundation - - This product includes software developed at - The Apache Software Foundation (http://www.apache.org/). - - Portions of this software were originally based on the following: - - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. - - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. - - voluntary contributions made by Paul Eng on behalf of the - Apache Software Foundation that were originally developed at iClick, Inc., - software copyright (c) 1999. 
- Apache Commons Collections Copyright 2001-2015 The Apache Software Foundation diff --git a/pom.xml b/pom.xml index ff15f200e2bb1..2e169df7201c2 100644 --- a/pom.xml +++ b/pom.xml @@ -1418,92 +1418,6 @@ test - - org.apache.hadoop - hadoop-client - ${hadoop.version} - ${hadoop.deps.scope} - - - org.fusesource.leveldbjni - leveldbjni-all - - - asm - asm - - - org.ow2.asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - - commons-beanutils - commons-beanutils-core - - - commons-logging - commons-logging - - - org.mockito - mockito-all - - - org.mortbay.jetty - servlet-api-2.5 - - - javax.servlet - servlet-api - - - junit - junit - - - com.sun.jersey - * - - - com.sun.jersey.jersey-test-framework - * - - - com.sun.jersey.contribs - * - - - net.java.dev.jets3t - jets3t - - - - javax.ws.rs - jsr311-api - - - org.eclipse.jetty - jetty-webapp - - - log4j - log4j - - - org.slf4j - slf4j-log4j12 - - - org.apache.hadoop hadoop-minikdc @@ -1544,16 +1458,6 @@ ${bouncycastle.version} test - - - - xerces - xercesImpl - 2.12.2 - org.apache.avro avro @@ -1636,207 +1540,6 @@ 1.1.1 ${hadoop.deps.scope} - - org.apache.hadoop - hadoop-yarn-api - ${yarn.version} - ${hadoop.deps.scope} - - - javax.servlet - servlet-api - - - asm - asm - - - org.ow2.asm - asm - - - org.jboss.netty - netty - - - commons-logging - commons-logging - - - com.sun.jersey - * - - - com.sun.jersey.jersey-test-framework - * - - - com.sun.jersey.contribs - * - - - jdk.tools - jdk.tools - - - - - org.apache.hadoop - hadoop-yarn-common - ${yarn.version} - ${hadoop.deps.scope} - - - asm - asm - - - org.ow2.asm - asm - - - org.jboss.netty - netty - - - javax.servlet - servlet-api - - - commons-logging - commons-logging - - - com.sun.jersey - * - - - com.sun.jersey.jersey-test-framework - * - - - com.sun.jersey.contribs - * - - - log4j - log4j - - - org.slf4j - slf4j-log4j12 - - - - - org.apache.hadoop - hadoop-yarn-server-tests - ${yarn.version} - tests - test - - - org.fusesource.leveldbjni - leveldbjni-all - - - asm - asm - - - org.ow2.asm - asm - - - org.jboss.netty - netty - - - javax.servlet - servlet-api - - - commons-logging - commons-logging - - - com.sun.jersey - * - - - com.sun.jersey.jersey-test-framework - * - - - com.sun.jersey.contribs - * - - - org.apache.hadoop - hadoop-yarn-server-resourcemanager - - - - - - org.apache.hadoop - hadoop-yarn-server-resourcemanager - ${yarn.version} - test - - - org.apache.hadoop - hadoop-yarn-client - ${yarn.version} - ${hadoop.deps.scope} - - - asm - asm - - - org.ow2.asm - asm - - - org.jboss.netty - netty - - - javax.servlet - servlet-api - - - commons-logging - commons-logging - - - com.sun.jersey - * - - - com.sun.jersey.jersey-test-framework - * - - - com.sun.jersey.contribs - * - - - log4j - log4j - - - org.slf4j - slf4j-log4j12 - - - org.apache.zookeeper zookeeper diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a87e0af0b542f..e7f7d68e98483 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1060,7 +1060,6 @@ object DependencyOverrides { lazy val guavaVersion = sys.props.get("guava.version").getOrElse("33.1.0-jre") lazy val settings = Seq( dependencyOverrides += "com.google.guava" % "guava" % guavaVersion, - dependencyOverrides += "xerces" % "xercesImpl" % "2.12.2", dependencyOverrides += "jline" % "jline" % "2.14.6", dependencyOverrides += "org.apache.avro" % "avro" % "1.11.3") } diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 
45ce25b8e037a..cebef07821f39 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -79,6 +79,12 @@ ${project.version} test + + software.amazon.awssdk + bundle + ${aws.java.sdk.v2.version} + test + @@ -189,20 +195,6 @@ - - hadoop-3 - - true - - - - software.amazon.awssdk - bundle - ${aws.java.sdk.v2.version} - test - - - volcano From 360ce0a82bc9675982ef77dd24310a4432e74b62 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 17 Oct 2024 07:31:54 -0700 Subject: [PATCH 026/108] [SPARK-50011][INFRA] Add a separate docker file for doc build ### What changes were proposed in this pull request? Add a separate docker file for doc build ### Why are the changes needed? currently we only have single test image, for `pyspark`, `sparkr`, `lint` and `docs`, it has two major issues: 1, disk space limitation: we are adding more and more packages in it, the disk space left for testing is very limited, and cause `No space left on device` from time to time; 2, environment conflicts: for example, even though we already install some packages for `docs` in the docker file, we still need to install some additional python packages in `build_and_test`, due to the conflicts between `docs` and `pyspark`. It is hard to maintain because the related packages are installed in two different places. so I am thinking of spinning off some installations (e.g. `docs`) from the base image, so that: 1, we can completely cache all the dependencies for `docs`; 2, the related installations are centralized; 3, we can free up disk space on the base image (after we spin off other dependency, we can remove unneeded packages from it); Furthermore, if we want to apply multiple images, we can easily support different environments, e.g. adding a separate image for old versions of `pandas/pyarrow/etc`. ### Does this PR introduce _any_ user-facing change? no, infra-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48520 from zhengruifeng/infra_multiple_docker_file. 
Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- .github/workflows/build_and_test.yml | 38 +++++++----- dev/infra/{ => base}/Dockerfile | 0 dev/infra/docs/Dockerfile | 91 ++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 14 deletions(-) rename dev/infra/{ => base}/Dockerfile (100%) create mode 100644 dev/infra/docs/Dockerfile diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 14d93a498fc59..eadcaaedc5829 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -58,6 +58,7 @@ jobs: outputs: required: ${{ steps.set-outputs.outputs.required }} image_url: ${{ steps.infra-image-outputs.outputs.image_url }} + image_docs_url: ${{ steps.infra-image-docs-outputs.outputs.image_docs_url }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -134,6 +135,14 @@ jobs: IMG_NAME="apache-spark-ci-image:${{ inputs.branch }}-${{ github.run_id }}" IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" echo "image_url=$IMG_URL" >> $GITHUB_OUTPUT + - name: Generate infra image URL (Documentation) + id: infra-image-docs-outputs + run: | + # Convert to lowercase to meet Docker repo name requirement + REPO_OWNER=$(echo "${{ github.repository_owner }}" | tr '[:upper:]' '[:lower:]') + IMG_NAME="apache-spark-ci-image-docs:${{ inputs.branch }}-${{ github.run_id }}" + IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" + echo "image_docs_url=$IMG_URL" >> $GITHUB_OUTPUT # Build: build Spark and run the tests for specified modules. build: @@ -345,12 +354,23 @@ jobs: id: docker_build uses: docker/build-push-action@v6 with: - context: ./dev/infra/ + context: ./dev/infra/base/ push: true tags: | ${{ needs.precondition.outputs.image_url }} # Use the infra image cache to speed up cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ inputs.branch }} + - name: Build and push (Documentation) + id: docker_build_docs + uses: docker/build-push-action@v6 + with: + context: ./dev/infra/docs/ + push: true + tags: | + ${{ needs.precondition.outputs.image_docs_url }} + # Use the infra image cache to speed up + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ inputs.branch }} + pyspark: needs: [precondition, infra-image] @@ -783,7 +803,7 @@ jobs: PYSPARK_PYTHON: python3.9 GITHUB_PREV_SHA: ${{ github.event.before }} container: - image: ${{ needs.precondition.outputs.image_url }} + image: ${{ needs.precondition.outputs.image_docs_url }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -833,18 +853,8 @@ jobs: with: distribution: zulu java-version: ${{ inputs.java }} - - name: Install Python dependencies for python linter and documentation generation - if: inputs.branch != 'branch-3.4' && inputs.branch != 'branch-3.5' - run: | - # Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 - # See 'ipython_genutils' in SPARK-38517 - # See 'docutils<0.18.0' in SPARK-39421 - python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ - 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ - 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ - 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 
'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' - python3.9 -m pip list + - name: List Python packages + run: python3.9 -m pip list - name: Install dependencies for documentation generation for branch-3.4, branch-3.5 if: inputs.branch == 'branch-3.4' || inputs.branch == 'branch-3.5' run: | diff --git a/dev/infra/Dockerfile b/dev/infra/base/Dockerfile similarity index 100% rename from dev/infra/Dockerfile rename to dev/infra/base/Dockerfile diff --git a/dev/infra/docs/Dockerfile b/dev/infra/docs/Dockerfile new file mode 100644 index 0000000000000..8a8e1680182c5 --- /dev/null +++ b/dev/infra/docs/Dockerfile @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building and testing Spark branches. Based on Ubuntu 22.04. +# See also in https://hub.docker.com/_/ubuntu +FROM ubuntu:jammy-20240227 +LABEL org.opencontainers.image.authors="Apache Spark project " +LABEL org.opencontainers.image.licenses="Apache-2.0" +LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image for Documentation" +# Overwrite this label to avoid exposing the underlying Ubuntu OS version label +LABEL org.opencontainers.image.version="" + +ENV FULL_REFRESH_DATE 20241016 + +ENV DEBIAN_FRONTEND noninteractive +ENV DEBCONF_NONINTERACTIVE_SEEN true + +RUN apt-get update && apt-get install -y \ + build-essential \ + ca-certificates \ + curl \ + gfortran \ + git \ + gnupg \ + libcurl4-openssl-dev \ + libfontconfig1-dev \ + libfreetype6-dev \ + libfribidi-dev \ + libgit2-dev \ + libharfbuzz-dev \ + libjpeg-dev \ + liblapack-dev \ + libopenblas-dev \ + libpng-dev \ + libpython3-dev \ + libssl-dev \ + libtiff5-dev \ + libxml2-dev \ + nodejs \ + npm \ + openjdk-17-jdk-headless \ + pandoc \ + pkg-config \ + qpdf \ + r-base \ + ruby \ + ruby-dev \ + software-properties-common \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + + +# See more in SPARK-39959, roxygen2 < 7.2.1 +RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', 'rmarkdown', 'testthat'), repos='https://cloud.r-project.org/')" && \ + Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='https://cloud.r-project.org')" && \ + Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" && \ + Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" + +# See more in SPARK-39735 +ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library" + +# Install Python 3.9 +RUN add-apt-repository ppa:deadsnakes/ppa +RUN apt-get update && apt-get install -y python3.9 python3.9-distutils \ + && rm -rf /var/lib/apt/lists/* +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9 + +# Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 +# 
See 'ipython_genutils' in SPARK-38517 +# See 'docutils<0.18.0' in SPARK-39421 +RUN python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ + 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ + 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ + 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' \ + && python3.9 -m pip cache purge From 24d9c97916c109b9228706ea71d9e8e588beea95 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 17 Oct 2024 07:40:58 -0700 Subject: [PATCH 027/108] [SPARK-50002][PYTHON][CONNECT] API compatibility check for I/O ### What changes were proposed in this pull request? This PR proposes to add API compatibility check for I/O ### Why are the changes needed? To guarantee of the same behavior between Spark Classic and Spark Connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48511 from itholic/compat_readwriter. Authored-by: Haejoon Lee Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/connect/readwriter.py | 2 +- python/pyspark/sql/readwriter.py | 5 +- .../sql/tests/test_connect_compatibility.py | 58 ++++++++++++++++++- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 826cfdea8a9e3..aeb0f98d71076 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -229,7 +229,7 @@ def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame def text( self, paths: PathOrPaths, - wholetext: Optional[bool] = None, + wholetext: bool = False, lineSep: Optional[str] = None, pathGlobFilter: Optional[Union[bool, str]] = None, recursiveFileLookup: Optional[Union[bool, str]] = None, diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3a0b5cdfead91..4744bdf861d37 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -18,7 +18,6 @@ from typing import cast, overload, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union from pyspark.util import is_remote_only -from pyspark.sql.column import Column from pyspark.sql.types import StructType from pyspark.sql import utils from pyspark.sql.utils import to_str @@ -2400,7 +2399,7 @@ def tableProperty(self, property: str, value: str) -> "DataFrameWriterV2": self._jwriter.tableProperty(property, value) return self - def partitionedBy(self, col: Column, *cols: Column) -> "DataFrameWriterV2": + def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFrameWriterV2": """ Partition the output table created by `create`, `createOrReplace`, or `replace` using the given columns or transforms. @@ -2487,7 +2486,7 @@ def append(self) -> None: """ self._jwriter.append() - def overwrite(self, condition: Column) -> None: + def overwrite(self, condition: "ColumnOrName") -> None: """ Overwrite rows matching the given filter condition with the contents of the data frame in the output table. 
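As a reviewer aid, a hedged sketch of the kind of spot check the updated compatibility suite (next diff) automates; it assumes an environment where the Spark Connect Python dependencies are importable:
```python
import inspect

from pyspark.sql.readwriter import DataFrameWriterV2 as ClassicWriterV2
from pyspark.sql.connect.readwriter import DataFrameWriterV2 as ConnectWriterV2

# With the classic type hints relaxed to "ColumnOrName" above, the classic and
# connect signatures of the same public method should now line up.
classic_sig = inspect.signature(ClassicWriterV2.overwrite)
connect_sig = inspect.signature(ConnectWriterV2.overwrite)
print(classic_sig)
print(connect_sig)
assert classic_sig == connect_sig
```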
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index c125c905604e9..efef85862633e 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -25,12 +25,18 @@ from pyspark.sql.classic.column import Column as ClassicColumn from pyspark.sql.session import SparkSession as ClassicSparkSession from pyspark.sql.catalog import Catalog as ClassicCatalog +from pyspark.sql.readwriter import DataFrameReader as ClassicDataFrameReader +from pyspark.sql.readwriter import DataFrameWriter as ClassicDataFrameWriter +from pyspark.sql.readwriter import DataFrameWriterV2 as ClassicDataFrameWriterV2 if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame from pyspark.sql.connect.column import Column as ConnectColumn from pyspark.sql.connect.session import SparkSession as ConnectSparkSession from pyspark.sql.connect.catalog import Catalog as ConnectCatalog + from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader + from pyspark.sql.connect.readwriter import DataFrameWriter as ConnectDataFrameWriter + from pyspark.sql.connect.readwriter import DataFrameWriterV2 as ConnectDataFrameWriterV2 class ConnectCompatibilityTestsMixin: @@ -63,7 +69,9 @@ def compare_method_signatures(self, classic_cls, connect_cls, cls_name): classic_signature = inspect.signature(classic_methods[method]) connect_signature = inspect.signature(connect_methods[method]) - if not method == "createDataFrame": + # Cannot support RDD arguments from Spark Connect + has_rdd_arguments = ("createDataFrame", "xml", "json") + if method not in has_rdd_arguments: self.assertEqual( classic_signature, connect_signature, @@ -247,6 +255,54 @@ def test_catalog_compatibility(self): expected_missing_classic_methods, ) + def test_dataframe_reader_compatibility(self): + """Test DataFrameReader compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicDataFrameReader, + ConnectDataFrameReader, + "DataFrameReader", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + + def test_dataframe_writer_compatibility(self): + """Test DataFrameWriter compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicDataFrameWriter, + ConnectDataFrameWriter, + "DataFrameWriter", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + + def test_dataframe_writer_v2_compatibility(self): + """Test DataFrameWriterV2 compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicDataFrameWriterV2, + ConnectDataFrameWriterV2, + "DataFrameWriterV2", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + 
expected_missing_classic_methods, + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): From d144144e3133453931ae9f9500231b2289f532fa Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 17 Oct 2024 07:55:36 -0700 Subject: [PATCH 028/108] [SPARK-50003][PYTHON][DOCS] Add pie plot doc examples and correct area plot parameter type hints ### What changes were proposed in this pull request? Add doc examples to `pie` plot and correct type hints of `area` plot parameter. ### Why are the changes needed? Improve readability and typing. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48512 from xinrong-meng/minor_fix. Authored-by: Xinrong Meng Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/plot/core.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index 4bf75474d92c3..e61af4ae3fa5d 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -270,7 +270,7 @@ def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure": """ return self(kind="scatter", x=x, y=y, **kwargs) - def area(self, x: str, y: str, **kwargs: Any) -> "Figure": + def area(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": """ Draw a stacked area plot. @@ -326,6 +326,16 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": Examples -------- + >>> from datetime import datetime + >>> data = [ + ... (3, 5, 20, datetime(2018, 1, 31)), + ... (2, 5, 42, datetime(2018, 2, 28)), + ... (3, 6, 28, datetime(2018, 3, 31)), + ... (9, 12, 62, datetime(2018, 4, 30)) + ... ] + >>> columns = ["sales", "signups", "visits", "date"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.pie(x='date', y='sales') # doctest: +SKIP """ schema = self.data.schema From 6bcad771246bf09a775fd5edb0874504f7057297 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 17 Oct 2024 08:13:10 -0700 Subject: [PATCH 029/108] [SPARK-50008][PS][CONNECT] Avoid unnecessary operations in `attach_distributed_sequence_column` ### What changes were proposed in this pull request? Avoid unnecessary operations in `attach_distributed_sequence_column` ### Why are the changes needed? 1, `attach_distributed_sequence_column` always needs `sdf.columns`, which may trigger an analysis task in Spark Connect if the `sdf.schema` has not been cached; 2, for zero columns dataframe, it trigger `sdf.count` which seems redundant; ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? added ut ### Was this patch authored or co-authored using generative AI tooling? no Closes #48516 from zhengruifeng/ps_attach_distributed_column. 
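For context, a hedged sketch (not part of the patch) of how this helper is typically reached from the pandas API on Spark; the distributed-sequence default index is the caller of `attach_distributed_sequence_column`, and a running `spark` session is assumed:
```python
import pyspark.pandas as ps

# With the distributed-sequence default index, converting a Spark DataFrame
# attaches the sequence column; after this change the helper no longer branches
# on the column count or issues an extra count() for zero-column frames.
ps.set_option("compute.default_index_type", "distributed-sequence")
psdf = spark.range(10).pandas_api()
print(psdf.head(3))
```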
Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- python/pyspark/pandas/internal.py | 14 +--------- python/pyspark/pandas/tests/test_internal.py | 27 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 90c361547b814..363ef73302547 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -930,19 +930,7 @@ def attach_distributed_sequence_column( | 2| c| +--------+---+ """ - if len(sdf.columns) > 0: - return sdf.select( - SF.distributed_sequence_id().alias(column_name), - "*", - ) - else: - cnt = sdf.count() - if cnt > 0: - return default_session().range(cnt).toDF(column_name) - else: - return default_session().createDataFrame( - [], schema=StructType().add(column_name, data_type=LongType(), nullable=False) - ) + return sdf.select(SF.distributed_sequence_id().alias(column_name), "*") def spark_column_for(self, label: Label) -> PySparkColumn: """Return Spark Column for the given column label.""" diff --git a/python/pyspark/pandas/tests/test_internal.py b/python/pyspark/pandas/tests/test_internal.py index 5a936d2dcd634..741b81c94440d 100644 --- a/python/pyspark/pandas/tests/test_internal.py +++ b/python/pyspark/pandas/tests/test_internal.py @@ -17,6 +17,7 @@ import pandas as pd +from pyspark.sql.types import LongType, StructType, StructField from pyspark.pandas.internal import ( InternalFrame, SPARK_DEFAULT_INDEX_NAME, @@ -106,6 +107,32 @@ def test_from_pandas(self): self.assert_eq(internal.to_pandas_frame, pdf) + def test_attach_distributed_column(self): + sdf1 = self.spark.range(10) + self.assert_eq( + InternalFrame.attach_distributed_sequence_column(sdf1, "index").schema, + StructType( + [ + StructField("index", LongType(), False), + StructField("id", LongType(), False), + ] + ), + ) + + # zero columns + sdf2 = self.spark.range(10).select() + self.assert_eq( + InternalFrame.attach_distributed_sequence_column(sdf2, "index").schema, + StructType([StructField("index", LongType(), False)]), + ) + + # empty dataframe, zero columns + sdf3 = self.spark.range(10).where("id < 0").select() + self.assert_eq( + InternalFrame.attach_distributed_sequence_column(sdf3, "index").schema, + StructType([StructField("index", LongType(), False)]), + ) + class InternalFrameTests(InternalFrameTestsMixin, PandasOnSparkTestCase, SQLTestUtils): pass From 120ae9ac889a47aae5db9ae86dd050fe000579b4 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Thu, 17 Oct 2024 17:46:11 +0200 Subject: [PATCH 030/108] [SPARK-49958][PYTHON] Python API for string validation functions ### What changes were proposed in this pull request? Adding the Python API for the 4 new string validation expressions: - is_valid_utf8 - make_valid_utf8 - validate_utf8 - try_validate_utf8 ### Why are the changes needed? Offer a complete Python API for the new expressions in Spark 4.0. ### Does this PR introduce _any_ user-facing change? Yes, adding Python API for the 4 new Spark expressions. ### How was this patch tested? New tests for the Python API. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48455 from uros-db/api-validation-python. 
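For a quick end-to-end check, a hedged smoke-test sketch (not part of the patch) exercising the four new wrappers together, mirroring the docstring examples added below; a running `spark` session is assumed:
```python
from pyspark.sql import functions as sf

# All four validation helpers applied to a well-formed UTF-8 string; the
# interesting (invalid) byte sequences typically come from binary sources,
# which this smoke test does not cover.
spark.range(1).select(
    sf.is_valid_utf8(sf.lit("SparkSQL")),
    sf.make_valid_utf8(sf.lit("SparkSQL")),
    sf.validate_utf8(sf.lit("SparkSQL")),
    sf.try_validate_utf8(sf.lit("SparkSQL")),
).show(truncate=False)
```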
Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- .../reference/pyspark.sql/functions.rst | 4 + .../pyspark/sql/connect/functions/builtin.py | 28 ++++ python/pyspark/sql/functions/builtin.py | 121 ++++++++++++++++++ python/pyspark/sql/tests/test_functions.py | 19 ++- 4 files changed, 169 insertions(+), 3 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 6248e71331656..53904718fff6a 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -177,6 +177,7 @@ String Functions format_string initcap instr + is_valid_utf8 lcase left length @@ -185,6 +186,7 @@ String Functions lower lpad ltrim + make_valid_utf8 mask octet_length overlay @@ -218,9 +220,11 @@ String Functions trim try_to_binary try_to_number + try_validate_utf8 ucase unbase64 upper + validate_utf8 Bitwise Functions diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index db12e085468a0..9341442a1733b 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2445,6 +2445,34 @@ def encode(col: "ColumnOrName", charset: str) -> Column: encode.__doc__ = pysparkfuncs.encode.__doc__ +def is_valid_utf8(str: "ColumnOrName") -> Column: + return _invoke_function_over_columns("is_valid_utf8", _to_col(str)) + + +is_valid_utf8.__doc__ = pysparkfuncs.is_valid_utf8.__doc__ + + +def make_valid_utf8(str: "ColumnOrName") -> Column: + return _invoke_function_over_columns("make_valid_utf8", _to_col(str)) + + +make_valid_utf8.__doc__ = pysparkfuncs.make_valid_utf8.__doc__ + + +def validate_utf8(str: "ColumnOrName") -> Column: + return _invoke_function_over_columns("validate_utf8", _to_col(str)) + + +validate_utf8.__doc__ = pysparkfuncs.validate_utf8.__doc__ + + +def try_validate_utf8(str: "ColumnOrName") -> Column: + return _invoke_function_over_columns("try_validate_utf8", _to_col(str)) + + +try_validate_utf8.__doc__ = pysparkfuncs.try_validate_utf8.__doc__ + + def format_number(col: "ColumnOrName", d: int) -> Column: return _invoke_function("format_number", _to_col(col), lit(d)) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index b75d1b2f59faf..67c2e23b40ed8 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11197,6 +11197,127 @@ def encode(col: "ColumnOrName", charset: str) -> Column: return _invoke_function("encode", _to_java_column(col), _enum_to_value(charset)) +@_try_remote_functions +def is_valid_utf8(str: "ColumnOrName") -> Column: + """ + Returns true if the input is a valid UTF-8 string, otherwise returns false. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + str : :class:`~pyspark.sql.Column` or str + A column of strings, each representing a UTF-8 byte sequence. + + Returns + ------- + :class:`~pyspark.sql.Column` + whether the input string is a valid UTF-8 string. 
+ + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.is_valid_utf8(sf.lit("SparkSQL"))).show() + +-----------------------+ + |is_valid_utf8(SparkSQL)| + +-----------------------+ + | true| + +-----------------------+ + """ + return _invoke_function_over_columns("is_valid_utf8", str) + + +@_try_remote_functions +def make_valid_utf8(str: "ColumnOrName") -> Column: + """ + Returns a new string in which all invalid UTF-8 byte sequences, if any, are replaced by the + Unicode replacement character (U+FFFD). + + .. versionadded:: 4.0.0 + + Parameters + ---------- + str : :class:`~pyspark.sql.Column` or str + A column of strings, each representing a UTF-8 byte sequence. + + Returns + ------- + :class:`~pyspark.sql.Column` + the valid UTF-8 version of the given input string. + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.make_valid_utf8(sf.lit("SparkSQL"))).show() + +-------------------------+ + |make_valid_utf8(SparkSQL)| + +-------------------------+ + | SparkSQL| + +-------------------------+ + """ + return _invoke_function_over_columns("make_valid_utf8", str) + + +@_try_remote_functions +def validate_utf8(str: "ColumnOrName") -> Column: + """ + Returns the input value if it corresponds to a valid UTF-8 string, or emits an error otherwise. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + str : :class:`~pyspark.sql.Column` or str + A column of strings, each representing a UTF-8 byte sequence. + + Returns + ------- + :class:`~pyspark.sql.Column` + the input string if it is a valid UTF-8 string, error otherwise. + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.validate_utf8(sf.lit("SparkSQL"))).show() + +-----------------------+ + |validate_utf8(SparkSQL)| + +-----------------------+ + | SparkSQL| + +-----------------------+ + """ + return _invoke_function_over_columns("validate_utf8", str) + + +@_try_remote_functions +def try_validate_utf8(str: "ColumnOrName") -> Column: + """ + Returns the input value if it corresponds to a valid UTF-8 string, or NULL otherwise. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + str : :class:`~pyspark.sql.Column` or str + A column of strings, each representing a UTF-8 byte sequence. + + Returns + ------- + :class:`~pyspark.sql.Column` + the input string if it is a valid UTF-8 string, null otherwise. 
+ + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.try_validate_utf8(sf.lit("SparkSQL"))).show() + +---------------------------+ + |try_validate_utf8(SparkSQL)| + +---------------------------+ + | SparkSQL| + +---------------------------+ + """ + return _invoke_function_over_columns("try_validate_utf8", str) + + @_try_remote_functions def format_number(col: "ColumnOrName", d: int) -> Column: """ diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index f6c1278c0dc7a..cec6e2ababbdc 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -83,9 +83,7 @@ def test_function_parity(self): missing_in_py = jvm_fn_set.difference(py_fn_set) # Functions that we expect to be missing in python until they are added to pyspark - expected_missing_in_py = set( - ["is_valid_utf8", "make_valid_utf8", "validate_utf8", "try_validate_utf8"] - ) + expected_missing_in_py = set() self.assertEqual( expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected" @@ -1631,6 +1629,21 @@ def test_randstr_uniform(self): result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect() self.assertEqual([Row(True)], result) + def test_string_validation(self): + df = self.spark.createDataFrame([("abc",)], ["a"]) + # test is_valid_utf8 + result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r")).collect() + self.assertEqual([Row(r=True)], result_is_valid_utf8) + # test make_valid_utf8 + result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r")).collect() + self.assertEqual([Row(r="abc")], result_make_valid_utf8) + # test validate_utf8 + result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r")).collect() + self.assertEqual([Row(r="abc")], result_validate_utf8) + # test try_validate_utf8 + result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r")).collect() + self.assertEqual([Row(r="abc")], result_try_validate_utf8) + class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin): pass From 6362e0c30fab11bb2bd13df21484982fc86e91e3 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 17 Oct 2024 08:56:36 -0700 Subject: [PATCH 031/108] [SPARK-49936][BUILD] Upgrade `datasketches-java` to 6.1.1 ### What changes were proposed in this pull request? This pr aims to upgrade `datasketches-java` from 6.0.0 to 6.1.1. ### Why are the changes needed? The new version is now dependent on `datasketches-memory` 3.x. The full release notes as follows: - https://github.com/apache/datasketches-java/releases/tag/6.1.0 - https://github.com/apache/datasketches-java/releases/tag/6.1.1 - https://github.com/apache/datasketches-memory/releases/tag/3.0.0 - https://github.com/apache/datasketches-memory/releases/tag/3.0.1 - https://github.com/apache/datasketches-memory/releases/tag/3.0.2 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #48380 from LuciferYang/test-dm-3.0.2. 
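Since this is only a dependency bump, one quick end-to-end way to exercise the upgraded datasketches-java path is the existing HLL sketch aggregate (a hedged sketch assuming an active `spark` session; `hll_sketch_agg` and `hll_sketch_estimate` are pre-existing Spark SQL functions, not part of this change):

```python
# Approximate distinct count backed by the datasketches-java HLL implementation.
df = spark.range(1000)
df.selectExpr(
    "hll_sketch_estimate(hll_sketch_agg(id)) AS approx_distinct_ids"
).show()
```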
Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 4 ++-- pom.xml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 91e84b0780798..4620a51904461 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -57,8 +57,8 @@ curator-recipes/5.7.0//curator-recipes-5.7.0.jar datanucleus-api-jdo/4.2.4//datanucleus-api-jdo-4.2.4.jar datanucleus-core/4.1.17//datanucleus-core-4.1.17.jar datanucleus-rdbms/4.1.19//datanucleus-rdbms-4.1.19.jar -datasketches-java/6.0.0//datasketches-java-6.0.0.jar -datasketches-memory/2.2.0//datasketches-memory-2.2.0.jar +datasketches-java/6.1.1//datasketches-java-6.1.1.jar +datasketches-memory/3.0.2//datasketches-memory-3.0.2.jar derby/10.16.1.1//derby-10.16.1.1.jar derbyshared/10.16.1.1//derbyshared-10.16.1.1.jar derbytools/10.16.1.1//derbytools-10.16.1.1.jar diff --git a/pom.xml b/pom.xml index 2e169df7201c2..fe49568d744a0 100644 --- a/pom.xml +++ b/pom.xml @@ -214,7 +214,7 @@ 1.9.0 1.78 1.15.0 - 6.0.0 + 6.1.1 4.1.110.Final 2.0.66.Final 75.1 From 91becf140cfe57d19c450876c48145f1c93b54e7 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 17 Oct 2024 19:58:31 +0200 Subject: [PATCH 032/108] [SPARK-49962][SQL] Simplify AbstractStringTypes class hierarchy ### What changes were proposed in this pull request? Simplifying the AbstractStringType hierarchy. ### Why are the changes needed? The addition of trim-sensitive collation (#48336) highlighted the complexity of extending the existing AbstractStringType structure. Besides adding a new parameter to all types inheriting from AbstractStringType, it caused changing the logic of every subclass as well as changing the name of a derived class StringTypeAnyCollation into StringTypeWithCaseAccentSensitivity which could again be subject to change if we keep adding new specifiers. Looking ahead, the introduction of support for indeterminate collation would further complicate these types. To address this, the proposed changes simplify the design by consolidating common logic into a single base class. This base class will handle core functionality such as trim or indeterminate collation, while a derived class, StringTypeWithCollation (previously awkwardly called StringTypeWithCaseAccentSensitivity), will manage collation specifiers. This approach allows for easier future extensions: fundamental checks can be handled in the base class, while any new specifiers can be added as optional fields in StringTypeWithCollation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? With existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48459 from stefankandic/refactorStringTypes. 
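As a rough illustration of the consolidation described above, here is a small language-agnostic Python model of the acceptance logic (the actual Scala implementation is in the diff below; the names and flags only mirror that code):

```python
# Base class owns the common trim-collation check; derived classes only add their
# own rule via accepts_string_type (modelled on the Scala hierarchy in this patch).
class AbstractStringType:
    def __init__(self, supports_trim_collation=False):
        self.supports_trim_collation = supports_trim_collation

    def accepts_type(self, st):
        return (self.supports_trim_collation or not st["uses_trim"]) \
            and self.accepts_string_type(st)

    def accepts_string_type(self, st):
        raise NotImplementedError


class StringTypeWithCollation(AbstractStringType):
    def __init__(self, supports_trim_collation=False,
                 supports_case_specifier=True, supports_accent_specifier=True):
        super().__init__(supports_trim_collation)
        self.supports_case_specifier = supports_case_specifier
        self.supports_accent_specifier = supports_accent_specifier

    def accepts_string_type(self, st):
        return (self.supports_case_specifier or not st["case_insensitive"]) \
            and (self.supports_accent_specifier or not st["accent_insensitive"])


# A trim-sensitive collation is rejected by default and accepted only when opted in.
rtrim_collation = {"uses_trim": True, "case_insensitive": False, "accent_insensitive": False}
assert not StringTypeWithCollation().accepts_type(rtrim_collation)
assert StringTypeWithCollation(supports_trim_collation=True).accepts_type(rtrim_collation)
```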
Authored-by: Stefan Kandic Signed-off-by: Max Gekk --- .../sql/catalyst/util/CollationFactory.java | 14 ++- .../internal/types/AbstractStringType.scala | 82 +++++++++++------- .../apache/spark/sql/types/StringType.scala | 7 +- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../expressions/CallMethodViaReflection.scala | 8 +- .../catalyst/expressions/CollationKey.scala | 4 +- .../sql/catalyst/expressions/ExprUtils.scala | 4 +- .../aggregate/datasketchesAggregates.scala | 4 +- .../expressions/collationExpressions.scala | 6 +- .../expressions/collectionOperations.scala | 16 ++-- .../catalyst/expressions/csvExpressions.scala | 4 +- .../expressions/datetimeExpressions.scala | 30 +++---- .../expressions/jsonExpressions.scala | 12 +-- .../expressions/maskExpressions.scala | 12 +-- .../expressions/mathExpressions.scala | 8 +- .../spark/sql/catalyst/expressions/misc.scala | 12 +-- .../expressions/numberFormatExpressions.scala | 6 +- .../expressions/regexpExpressions.scala | 16 ++-- .../expressions/stringExpressions.scala | 86 +++++++++---------- .../catalyst/expressions/urlExpressions.scala | 12 +-- .../variant/variantExpressions.scala | 6 +- .../sql/catalyst/expressions/xml/xpath.scala | 6 +- .../catalyst/expressions/xmlExpressions.scala | 4 +- .../analysis/AnsiTypeCoercionSuite.scala | 18 ++-- .../expressions/StringExpressionsSuite.scala | 4 +- .../sql/CollationExpressionWalkerSuite.scala | 44 +++++----- 26 files changed, 226 insertions(+), 203 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 50bb93465921e..4a61e630fef39 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -1157,15 +1157,13 @@ public static int collationNameToId(String collationName) throws SparkException return Collation.CollationSpec.collationNameToId(collationName); } - /** - * Returns whether the ICU collation is not Case Sensitive Accent Insensitive - * for the given collation id. - * This method is used in expressions which do not support CS_AI collations. - */ - public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { + public static boolean isCaseInsensitive(int collationId) { return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity == - Collation.CollationSpecICU.CaseSensitivity.CS && - Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity == + Collation.CollationSpecICU.CaseSensitivity.CI; + } + + public static boolean isAccentInsensitive(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity == Collation.CollationSpecICU.AccentSensitivity.AI; } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index c3643f4bd15be..3a25bba32b530 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -21,25 +21,34 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** - * AbstractStringType is an abstract class for StringType with collation support. 
As every type of - * collation can support trim specifier this class is parametrized with it. + * AbstractStringType is an abstract class for StringType with collation support. */ -abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false) +abstract class AbstractStringType(supportsTrimCollation: Boolean = false) extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType override private[sql] def simpleString: String = "string" - private[sql] def canUseTrimCollation(other: DataType): Boolean = - supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation + + override private[sql] def acceptsType(other: DataType): Boolean = other match { + case st: StringType => + canUseTrimCollation(st) && acceptsStringType(st) + case _ => + false + } + + private[sql] def canUseTrimCollation(other: StringType): Boolean = + supportsTrimCollation || !other.usesTrimCollation + + def acceptsStringType(other: StringType): Boolean } /** - * Use StringTypeBinary for expressions supporting only binary collation. + * Used for expressions supporting only binary collation. */ -case class StringTypeBinary(override val supportsTrimCollation: Boolean = false) +case class StringTypeBinary(supportsTrimCollation: Boolean) extends AbstractStringType(supportsTrimCollation) { - override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality && - canUseTrimCollation(other) + + override def acceptsStringType(other: StringType): Boolean = + other.supportsBinaryEquality } object StringTypeBinary extends StringTypeBinary(false) { @@ -49,13 +58,13 @@ object StringTypeBinary extends StringTypeBinary(false) { } /** - * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation. + * Used for expressions supporting only binary and lowercase collation. */ -case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false) +case class StringTypeBinaryLcase(supportsTrimCollation: Boolean) extends AbstractStringType(supportsTrimCollation) { - override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality || - other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other) + + override def acceptsStringType(other: StringType): Boolean = + other.supportsBinaryEquality || other.isUTF8LcaseCollation } object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) { @@ -65,31 +74,44 @@ object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) { } /** - * Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary - * and ICU) but limited to using case and accent sensitivity specifiers. + * Used for expressions supporting collation types with optional case, accent, and trim + * sensitivity specifiers. + * + * Case and accent sensitivity specifiers are supported by default. 
*/ -case class StringTypeWithCaseAccentSensitivity( - override val supportsTrimCollation: Boolean = false) +case class StringTypeWithCollation( + supportsTrimCollation: Boolean, + supportsCaseSpecifier: Boolean, + supportsAccentSpecifier: Boolean) extends AbstractStringType(supportsTrimCollation) { - override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && canUseTrimCollation(other) + + override def acceptsStringType(other: StringType): Boolean = { + (supportsCaseSpecifier || !other.isCaseInsensitive) && + (supportsAccentSpecifier || !other.isAccentInsensitive) + } } -object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) { - def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = { - new StringTypeWithCaseAccentSensitivity(supportsTrimCollation) +object StringTypeWithCollation extends StringTypeWithCollation(false, true, true) { + def apply( + supportsTrimCollation: Boolean = false, + supportsCaseSpecifier: Boolean = true, + supportsAccentSpecifier: Boolean = true): StringTypeWithCollation = { + new StringTypeWithCollation( + supportsTrimCollation, + supportsCaseSpecifier, + supportsAccentSpecifier) } } /** - * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except - * CS_AI collation types. + * Used for expressions supporting all possible collation types except those that are + * case-sensitive but accent insensitive (CS_AI). */ -case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false) +case class StringTypeNonCSAICollation(supportsTrimCollation: Boolean) extends AbstractStringType(supportsTrimCollation) { - override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI && - canUseTrimCollation(other) + + override def acceptsStringType(other: StringType): Boolean = + other.isCaseInsensitive || !other.isAccentInsensitive } object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 1c93c2ad550e9..1eb645e37c4aa 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -44,8 +44,11 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def supportsLowercaseEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsLowercaseEquality - private[sql] def isNonCSAI: Boolean = - !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) + private[sql] def isCaseInsensitive: Boolean = + CollationFactory.isCaseInsensitive(collationId) + + private[sql] def isAccentInsensitive: Boolean = + CollationFactory.isAccentInsensitive(collationId) private[sql] def usesTrimCollation: Boolean = CollationFactory.fetchCollation(collationId).supportsSpaceTrimming diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e0298b19931c7..0f89fcd287bb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -33,7 +33,7 @@ import 
org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, - StringTypeWithCaseAccentSensitivity} + StringTypeWithCollation} import org.apache.spark.sql.types._ import org.apache.spark.sql.types.UpCastRule.numericPrecedence @@ -439,7 +439,7 @@ abstract class TypeCoercionBase { } case aj @ ArrayJoin(arr, d, nr) - if !AbstractArrayType(StringTypeWithCaseAccentSensitivity).acceptsType(arr.dataType) && + if !AbstractArrayType(StringTypeWithCollation).acceptsType(arr.dataType) && ArrayType.acceptsType(arr.dataType) => val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull implicitCast(arr, ArrayType(StringType, containsNull)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 6aa11b6fd16df..d38ee01485288 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -84,7 +84,7 @@ case class CallMethodViaReflection( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("class"), - "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), + "inputType" -> toSQLType(StringTypeWithCollation), "inputExpr" -> toSQLExpr(children.head) ) ) @@ -97,7 +97,7 @@ case class CallMethodViaReflection( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("method"), - "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), + "inputType" -> toSQLType(StringTypeWithCollation), "inputExpr" -> toSQLExpr(children(1)) ) ) @@ -115,7 +115,7 @@ case class CallMethodViaReflection( "requiredType" -> toSQLType( TypeCollection(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - StringTypeWithCaseAccentSensitivity)), + StringTypeWithCollation)), "inputSql" -> toSQLExpr(e), "inputType" -> toSQLType(e.dataType)) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala index 81bafda54135f..5d2fd14eee298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.CollationFactory -import 
org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override def dataType: DataType = BinaryType final lazy val collationId: Int = expr.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 9c617b51df62f..e65a0200b064f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} -import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseAccentSensitivity} +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation} import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String @@ -61,7 +61,7 @@ object ExprUtils extends EvalHelper with QueryErrorsBase { def convertToMapData(exp: Expression): Map[String, String] = exp match { case m: CreateMap - if AbstractMapType(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + if AbstractMapType(StringTypeWithCollation, StringTypeWithCollation) .acceptsType(m.dataType) => val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala index a6448051a3996..cbc8a8f273e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -109,7 +109,7 @@ case class HllSketchAgg( TypeCollection( IntegerType, LongType, - StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true), + StringTypeWithCollation(supportsTrimCollation = true), BinaryType), IntegerType) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index effcdc4b038e5..c75bf30ad21f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ // scalastyle:off line.contains.tab @@ -78,7 +78,7 @@ case class Collate(child: Expression, collationName: String) private val collationId = CollationFactory.collationNameToId(collationName) override def dataType: DataType = StringType(collationId) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) + Seq(StringTypeWithCollation(supportsTrimCollation = true)) override protected def withNewChildInternal( newChild: Expression): Expression = copy(newChild) @@ -117,5 +117,5 @@ case class Collation(child: Expression) Literal.create(collationName, SQLConf.get.defaultStringType) } override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) + Seq(StringTypeWithCollation(supportsTrimCollation = true)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bb54749126860..0d563530bcbcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCollation} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SQLOpenHashSet import org.apache.spark.unsafe.UTF8StringBuilder @@ -1349,7 +1349,7 @@ case class Reverse(child: Expression) // Input types are utilized by type coercion in ImplicitTypeCasts. 
override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, ArrayType)) + Seq(TypeCollection(StringTypeWithCollation, ArrayType)) override def dataType: DataType = child.dataType @@ -2135,12 +2135,12 @@ case class ArrayJoin( this(array, delimiter, Some(nullReplacement)) override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { - Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity), - StringTypeWithCaseAccentSensitivity, - StringTypeWithCaseAccentSensitivity) + Seq(AbstractArrayType(StringTypeWithCollation), + StringTypeWithCollation, + StringTypeWithCollation) } else { - Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity), - StringTypeWithCaseAccentSensitivity) + Seq(AbstractArrayType(StringTypeWithCollation), + StringTypeWithCollation) } override def children: Seq[Expression] = if (nullReplacement.isDefined) { @@ -2861,7 +2861,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio with QueryErrorsBase { private def allowedTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, BinaryType, ArrayType) + Seq(StringTypeWithCollation, BinaryType, ArrayType) final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 2f4462c0664f8..cdad9938c5d03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -147,7 +147,7 @@ case class CsvToStructs( converter(parser.parse(csv)) } - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil override def prettyName: String = "from_csv" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 764637b97a100..de3501a671eb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -963,7 +963,7 @@ case class 
DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = - Seq(TimestampType, StringTypeWithCaseAccentSensitivity) + Seq(TimestampType, StringTypeWithCollation) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1272,9 +1272,9 @@ abstract class ToTimestamp override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection( - StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType + StringTypeWithCollation, DateType, TimestampType, TimestampNTZType ), - StringTypeWithCaseAccentSensitivity) + StringTypeWithCollation) override def dataType: DataType = LongType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true @@ -1446,7 +1446,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(LongType, StringTypeWithCaseAccentSensitivity) + Seq(LongType, StringTypeWithCollation) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1555,7 +1555,7 @@ case class NextDay( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) override def inputTypes: Seq[AbstractDataType] = - Seq(DateType, StringTypeWithCaseAccentSensitivity) + Seq(DateType, StringTypeWithCollation) override def dataType: DataType = DateType override def nullable: Boolean = true @@ -1767,7 +1767,7 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w val funcName: String override def inputTypes: Seq[AbstractDataType] = - Seq(TimestampType, StringTypeWithCaseAccentSensitivity) + Seq(TimestampType, StringTypeWithCollation) override def dataType: DataType = TimestampType override def nullSafeEval(time: Any, timezone: Any): Any = { @@ -2108,8 +2108,8 @@ case class ParseToDate( // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. TypeCollection( - StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) +: - format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq + StringTypeWithCollation, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringTypeWithCollation).toSeq } override protected def withNewChildrenInternal( @@ -2180,10 +2180,10 @@ case class ParseToTimestamp( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. 
- val types = Seq(StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) + val types = Seq(StringTypeWithCollation, DateType, TimestampType, TimestampNTZType) TypeCollection( (if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _* - ) +: format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq + ) +: format.map(_ => StringTypeWithCollation).toSeq } override protected def withNewChildrenInternal( @@ -2314,7 +2314,7 @@ case class TruncDate(date: Expression, format: Expression) override def right: Expression = format override def inputTypes: Seq[AbstractDataType] = - Seq(DateType, StringTypeWithCaseAccentSensitivity) + Seq(DateType, StringTypeWithCollation) override def dataType: DataType = DateType override def prettyName: String = "trunc" override val instant = date @@ -2384,7 +2384,7 @@ case class TruncTimestamp( override def right: Expression = timestamp override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, TimestampType) + Seq(StringTypeWithCollation, TimestampType) override def dataType: TimestampType = TimestampType override def prettyName: String = "date_trunc" override val instant = timestamp @@ -2685,7 +2685,7 @@ case class MakeTimestamp( // casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0). override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++ - timezone.map(_ => StringTypeWithCaseAccentSensitivity) + timezone.map(_ => StringTypeWithCollation) override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = @@ -3133,7 +3133,7 @@ case class ConvertTimezone( override def third: Expression = sourceTs override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, TimestampNTZType) + Seq(StringTypeWithCollation, StringTypeWithCollation, TimestampNTZType) override def dataType: DataType = TimestampNTZType override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 6eef3d6f9d7df..a553336015b88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePatt import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.Utils @@ -134,7 +134,7 @@ case class GetJsonObject(json: Expression, path: Expression) override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, StringTypeWithCollation) override def dataType: DataType = 
SQLConf.get.defaultStringType override def nullable: Boolean = true override def prettyName: String = "get_json_object" @@ -491,7 +491,7 @@ case class JsonTuple(children: Seq[Expression]) ) } else if ( children.forall( - child => StringTypeWithCaseAccentSensitivity.acceptsType(child.dataType))) { + child => StringTypeWithCollation.acceptsType(child.dataType))) { TypeCheckResult.TypeCheckSuccess } else { DataTypeMismatch( @@ -709,7 +709,7 @@ case class JsonToStructs( |""".stripMargin) } - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil override def sql: String = schema match { case _: MapType => "entries" @@ -949,7 +949,7 @@ case class LengthOfJsonArray(child: Expression) with ExpectsInputTypes with RuntimeReplaceable { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def dataType: DataType = IntegerType override def nullable: Boolean = true override def prettyName: String = "json_array_length" @@ -994,7 +994,7 @@ case class JsonObjectKeys(child: Expression) with ExpectsInputTypes with RuntimeReplaceable { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType) override def nullable: Boolean = true override def prettyName: String = "json_object_keys" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index cb62fa2cc3bd5..7be6df14194fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types.{AbstractDataType, DataType} import org.apache.spark.unsafe.types.UTF8String @@ -193,11 +193,11 @@ case class Mask( */ override def inputTypes: Seq[AbstractDataType] = Seq( - StringTypeWithCaseAccentSensitivity, - StringTypeWithCaseAccentSensitivity, - StringTypeWithCaseAccentSensitivity, - StringTypeWithCaseAccentSensitivity, - StringTypeWithCaseAccentSensitivity) + StringTypeWithCollation, + StringTypeWithCollation, + StringTypeWithCollation, + StringTypeWithCollation, + StringTypeWithCollation) override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e46acf467db22..71fd43a8d9423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -31,7 +31,7 @@ import 
org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -453,7 +453,7 @@ case class Conv( override def second: Expression = fromBaseExpr override def third: Expression = toBaseExpr override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, IntegerType, IntegerType) + Seq(StringTypeWithCollation, IntegerType, IntegerType) override def dataType: DataType = first.dataType override def nullable: Boolean = true @@ -1114,7 +1114,7 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringTypeWithCaseAccentSensitivity)) + Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation)) override def dataType: DataType = child.dataType match { case st: StringType => st @@ -1158,7 +1158,7 @@ case class Unhex(child: Expression, failOnError: Boolean = false) def this(expr: Expression) = this(expr, false) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def nullable: Boolean = true override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0315c12b9bb8c..bef3bac17ffd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, MapType(StringType, StringType)) + Seq(StringTypeWithCollation, MapType(StringType, StringType)) override def left: Expression = errorClass override def right: Expression = errorParms @@ -417,8 +417,8 @@ case class AesEncrypt( override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType, - StringTypeWithCaseAccentSensitivity, - StringTypeWithCaseAccentSensitivity, + StringTypeWithCollation, + StringTypeWithCollation, BinaryType, BinaryType) override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad) @@ -494,8 +494,8 @@ case class AesDecrypt( override def inputTypes: Seq[AbstractDataType] = { Seq(BinaryType, BinaryType, - StringTypeWithCaseAccentSensitivity, - 
StringTypeWithCaseAccentSensitivity, BinaryType) + StringTypeWithCollation, + StringTypeWithCollation, BinaryType) } override def prettyName: String = "aes_decrypt" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index eefd21b236b7f..f2fb735b163e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.util.ToNumberParser import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -50,7 +50,7 @@ abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Bo } override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, StringTypeWithCollation) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -285,7 +285,7 @@ case class ToCharacter(left: Expression, right: Expression) override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = - Seq(DecimalType, StringTypeWithCaseAccentSensitivity) + Seq(DecimalType, StringTypeWithCollation) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() if (inputTypeCheck.isSuccess) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index fdc3c27890469..ba4f145888cf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTR import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.types.{ - StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity} + StringTypeBinaryLcase, StringTypeWithCollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -47,7 +47,7 @@ abstract class StringRegexExpression extends BinaryExpression def matches(regex: Pattern, str: String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeBinaryLcase, StringTypeWithCollation) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId final lazy val collationRegexFlags: Int = CollationSupport.collationAwareRegexFlags(collationId) @@ -279,7 +279,7 @@ case class ILike( this(left, right, '\\') override def inputTypes: 
Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeBinaryLcase, StringTypeWithCollation) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { @@ -568,7 +568,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(str.dataType, containsNull = false) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeWithCollation, IntegerType) override def first: Expression = str override def second: Expression = regex override def third: Expression = limit @@ -713,7 +713,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def dataType: DataType = subject.dataType override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeBinaryLcase, - StringTypeWithCaseAccentSensitivity, StringTypeBinaryLcase, IntegerType) + StringTypeWithCollation, StringTypeBinaryLcase, IntegerType) final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId override def prettyName: String = "regexp_replace" @@ -801,7 +801,7 @@ abstract class RegExpExtractBase final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeWithCollation, IntegerType) override def first: Expression = subject override def second: Expression = regexp override def third: Expression = idx @@ -1054,7 +1054,7 @@ case class RegExpCount(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeBinaryLcase, StringTypeWithCollation) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpCount = @@ -1094,7 +1094,7 @@ case class RegExpSubStr(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeBinaryLcase, StringTypeWithCollation) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpSubStr = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c91c57ee1eb3e..4367920f939e4 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, CharsetProvider, Collation import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.{AbstractArrayType, - StringTypeNonCSAICollation, StringTypeWithCaseAccentSensitivity} + StringTypeNonCSAICollation, StringTypeWithCollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods @@ -82,10 +82,10 @@ case class 
ConcatWs(children: Seq[Expression]) /** The 1st child (separator) is str, and rest are either str or array of str. */ override def inputTypes: Seq[AbstractDataType] = { val arrayOrStr = - TypeCollection(AbstractArrayType(StringTypeWithCaseAccentSensitivity), - StringTypeWithCaseAccentSensitivity + TypeCollection(AbstractArrayType(StringTypeWithCollation), + StringTypeWithCollation ) - StringTypeWithCaseAccentSensitivity +: Seq.fill(children.size - 1)(arrayOrStr) + StringTypeWithCollation +: Seq.fill(children.size - 1)(arrayOrStr) } override def dataType: DataType = children.head.dataType @@ -436,7 +436,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) @@ -518,7 +518,7 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, StringTypeWithCollation) protected override def nullSafeEval(input1: Any, input2: Any): Any = compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) @@ -735,7 +735,7 @@ case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with Implic override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def nodeName: String = "is_valid_utf8" @@ -782,7 +782,7 @@ case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with Impl override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def nodeName: String = "make_valid_utf8" @@ -827,7 +827,7 @@ case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with Impli Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def nodeName: String = "validate_utf8" @@ -876,7 +876,7 @@ case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with Im Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def nodeName: String = "try_validate_utf8" @@ -1011,8 +1011,8 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: override def dataType: DataType = input.dataType override def inputTypes: Seq[AbstractDataType] = Seq( - TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), - TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType, IntegerType) + TypeCollection(StringTypeWithCollation, BinaryType), + TypeCollection(StringTypeWithCollation, BinaryType), IntegerType, IntegerType) override def checkInputDataTypes(): TypeCheckResult = { 
val inputTypeCheck = super.checkInputDataTypes() @@ -1216,7 +1216,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, StringTypeWithCollation) override protected def nullSafeEval(word: Any, set: Any): Any = { CollationSupport.FindInSet. @@ -1245,7 +1245,7 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = srcStr +: trimStr.toSeq override def dataType: DataType = srcStr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq.fill(children.size)(StringTypeWithCaseAccentSensitivity) + Seq.fill(children.size)(StringTypeWithCollation) final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId @@ -1850,7 +1850,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, IntegerType, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, IntegerType, StringTypeWithCollation) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1930,7 +1930,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, IntegerType, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, IntegerType, StringTypeWithCollation) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1975,7 +1975,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC override def dataType: DataType = children(0).dataType override def inputTypes: Seq[AbstractDataType] = - StringTypeWithCaseAccentSensitivity :: List.fill(children.size - 1)(AnyDataType) + StringTypeWithCollation :: List.fill(children.size - 1)(AnyDataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2086,7 +2086,7 @@ case class InitCap(child: Expression) // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. 
private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { @@ -2119,7 +2119,7 @@ case class StringRepeat(str: Expression, times: Expression) override def right: Expression = times override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, IntegerType) + Seq(StringTypeWithCollation, IntegerType) override def nullSafeEval(string: Any, n: Any): Any = { string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) @@ -2212,7 +2212,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType, IntegerType) + Seq(TypeCollection(StringTypeWithCollation, BinaryType), IntegerType, IntegerType) override def first: Expression = str override def second: Expression = pos @@ -2271,7 +2271,7 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable ) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, IntegerType) + Seq(StringTypeWithCollation, IntegerType) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2302,7 +2302,7 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType) + Seq(TypeCollection(StringTypeWithCollation, BinaryType), IntegerType) } override def left: Expression = str @@ -2338,7 +2338,7 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) + Seq(TypeCollection(StringTypeWithCollation, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numChars @@ -2373,7 +2373,7 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) + Seq(TypeCollection(StringTypeWithCollation, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes * 8 @@ -2412,7 +2412,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) + Seq(TypeCollection(StringTypeWithCollation, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes @@ -2473,8 
+2473,8 @@ case class Levenshtein( override def inputTypes: Seq[AbstractDataType] = threshold match { case Some(_) => - Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, IntegerType) - case _ => Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, StringTypeWithCollation, IntegerType) + case _ => Seq(StringTypeWithCollation, StringTypeWithCollation) } override def children: Seq[Expression] = threshold match { @@ -2599,7 +2599,7 @@ case class SoundEx(child: Expression) override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() @@ -2629,7 +2629,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) protected override def nullSafeEval(string: Any): Any = { // only pick the first character to reduce the `toString` cost @@ -2774,7 +2774,7 @@ case class UnBase64(child: Expression, failOnError: Boolean = false) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) def this(expr: Expression) = this(expr, false) @@ -2954,7 +2954,7 @@ case class StringDecode( override val dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[AbstractDataType] = - Seq(BinaryType, StringTypeWithCaseAccentSensitivity) + Seq(BinaryType, StringTypeWithCollation) override def prettyName: String = "decode" override def toString: String = s"$prettyName($bin, $charset)" @@ -2963,7 +2963,7 @@ case class StringDecode( SQLConf.get.defaultStringType, "decode", Seq(bin, charset, Literal(legacyCharsets), Literal(legacyErrorAction)), - Seq(BinaryType, StringTypeWithCaseAccentSensitivity, BooleanType, BooleanType)) + Seq(BinaryType, StringTypeWithCollation, BooleanType, BooleanType)) override def children: Seq[Expression] = Seq(bin, charset) override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = @@ -3020,7 +3020,7 @@ case class Encode( override def dataType: DataType = BinaryType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, StringTypeWithCollation) override lazy val replacement: Expression = StaticInvoke( classOf[Encode], @@ -3030,8 +3030,8 @@ case class Encode( str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType) ), Seq( - StringTypeWithCaseAccentSensitivity, - StringTypeWithCaseAccentSensitivity, + StringTypeWithCollation, + StringTypeWithCollation, BooleanType, BooleanType)) @@ -3118,7 +3118,7 @@ case class ToBinary( override def children: Seq[Expression] = expr +: format.toSeq override def inputTypes: Seq[AbstractDataType] = - children.map(_ => StringTypeWithCaseAccentSensitivity) + children.map(_ => StringTypeWithCollation) override def checkInputDataTypes(): 
TypeCheckResult = { def isValidFormat: Boolean = { @@ -3135,7 +3135,7 @@ case class ToBinary( messageParameters = Map( "inputName" -> "fmt", "requireType" -> - s"case-insensitive ${toSQLType(StringTypeWithCaseAccentSensitivity)}", + s"case-insensitive ${toSQLType(StringTypeWithCollation)}", "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", "inputValue" -> toSQLValue(fmt, f.dataType) ) @@ -3146,7 +3146,7 @@ case class ToBinary( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("fmt"), - "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), + "inputType" -> toSQLType(StringTypeWithCollation), "inputExpr" -> toSQLExpr(f) ) ) @@ -3156,7 +3156,7 @@ case class ToBinary( messageParameters = Map( "inputName" -> "fmt", "requireType" -> - s"case-insensitive ${toSQLType(StringTypeWithCaseAccentSensitivity)}", + s"case-insensitive ${toSQLType(StringTypeWithCollation)}", "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", "inputValue" -> toSQLValue(f.eval(), f.dataType) ) @@ -3205,7 +3205,7 @@ case class FormatNumber(x: Expression, d: Expression) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(NumericType, TypeCollection(IntegerType, StringTypeWithCaseAccentSensitivity)) + Seq(NumericType, TypeCollection(IntegerType, StringTypeWithCollation)) private val defaultFormat = "#,###,###,###,###,###,##0" @@ -3411,8 +3411,8 @@ case class Sentences( ArrayType(ArrayType(str.dataType, containsNull = false), containsNull = false) override def inputTypes: Seq[AbstractDataType] = Seq( - StringTypeWithCaseAccentSensitivity, - StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + StringTypeWithCollation, + StringTypeWithCollation, StringTypeWithCollation) override def first: Expression = str override def second: Expression = language override def third: Expression = country @@ -3560,7 +3560,7 @@ case class Luhncheck(input: Expression) extends RuntimeReplaceable with Implicit "isLuhnNumber", Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def prettyName: String = "luhn_check" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 09e91da65484f..95f22663eb59a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType} import org.apache.spark.unsafe.types.UTF8String @@ -59,13 +59,13 @@ case class UrlEncode(child: Expression) SQLConf.get.defaultStringType, "encode", Seq(child), - Seq(StringTypeWithCaseAccentSensitivity)) + Seq(StringTypeWithCollation)) override protected def withNewChildInternal(newChild: Expression): 
Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def prettyName: String = "url_encode" } @@ -98,13 +98,13 @@ case class UrlDecode(child: Expression, failOnError: Boolean = true) SQLConf.get.defaultStringType, "decode", Seq(child, Literal(failOnError)), - Seq(StringTypeWithCaseAccentSensitivity, BooleanType)) + Seq(StringTypeWithCollation, BooleanType)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation) override def prettyName: String = "url_decode" } @@ -191,7 +191,7 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq.fill(children.size)(StringTypeWithCaseAccentSensitivity) + Seq.fill(children.size)(StringTypeWithCollation) override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "parse_url" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 323f6e42f3e50..1a7c3b0ea59de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.types.variant._ import org.apache.spark.types.variant.VariantUtil.{IntervalFields, Type} @@ -66,7 +66,7 @@ case class ParseJson(child: Expression, failOnError: Boolean = true) inputTypes :+ BooleanType :+ BooleanType, returnNullable = !failOnError) - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil override def dataType: DataType = VariantType @@ -272,7 +272,7 @@ case class VariantGet( final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET) override def inputTypes: Seq[AbstractDataType] = - Seq(VariantType, StringTypeWithCaseAccentSensitivity) + Seq(VariantType, StringTypeWithCollation) override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 6c38bd88144b1..c694067e06abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -24,7 +24,7 @@ import 
org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +42,7 @@ abstract class XPathExtract override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + Seq(StringTypeWithCollation, StringTypeWithCollation) override def checkInputDataTypes(): TypeCheckResult = { if (!path.foldable) { @@ -50,7 +50,7 @@ abstract class XPathExtract errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("path"), - "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), + "inputType" -> toSQLType(StringTypeWithCollation), "inputExpr" -> toSQLExpr(path) ) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 6f1430b04ed67..f3f652b393f76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -124,7 +124,7 @@ case class XmlToStructs( defineCodeGen(ctx, ev, input => s"(InternalRow) $expr.nullSafeEval($input)") } - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil override def prettyName: String = "from_xml" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index 342dcbd8e6b6d..8cf7d78b510be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCollation} import org.apache.spark.sql.types._ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { @@ -1057,11 +1057,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { ArrayType(IntegerType)) shouldCast( ArrayType(StringType), - 
AbstractArrayType(StringTypeWithCaseAccentSensitivity), + AbstractArrayType(StringTypeWithCollation), ArrayType(StringType)) shouldCast( ArrayType(IntegerType), - AbstractArrayType(StringTypeWithCaseAccentSensitivity), + AbstractArrayType(StringTypeWithCollation), ArrayType(StringType)) shouldCast( ArrayType(StringType), @@ -1075,11 +1075,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { ArrayType(ArrayType(IntegerType))) shouldCast( ArrayType(ArrayType(StringType)), - AbstractArrayType(AbstractArrayType(StringTypeWithCaseAccentSensitivity)), + AbstractArrayType(AbstractArrayType(StringTypeWithCollation)), ArrayType(ArrayType(StringType))) shouldCast( ArrayType(ArrayType(IntegerType)), - AbstractArrayType(AbstractArrayType(StringTypeWithCaseAccentSensitivity)), + AbstractArrayType(AbstractArrayType(StringTypeWithCollation)), ArrayType(ArrayType(StringType))) shouldCast( ArrayType(ArrayType(StringType)), @@ -1088,16 +1088,16 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { // Invalid casts involving casting arrays into non-complex types. shouldNotCast(ArrayType(IntegerType), IntegerType) - shouldNotCast(ArrayType(StringType), StringTypeWithCaseAccentSensitivity) + shouldNotCast(ArrayType(StringType), StringTypeWithCollation) shouldNotCast(ArrayType(StringType), IntegerType) - shouldNotCast(ArrayType(IntegerType), StringTypeWithCaseAccentSensitivity) + shouldNotCast(ArrayType(IntegerType), StringTypeWithCollation) // Invalid casts involving casting arrays of arrays into arrays of non-complex types. shouldNotCast(ArrayType(ArrayType(IntegerType)), AbstractArrayType(IntegerType)) shouldNotCast(ArrayType(ArrayType(StringType)), - AbstractArrayType(StringTypeWithCaseAccentSensitivity)) + AbstractArrayType(StringTypeWithCollation)) shouldNotCast(ArrayType(ArrayType(StringType)), AbstractArrayType(IntegerType)) shouldNotCast(ArrayType(ArrayType(IntegerType)), - AbstractArrayType(StringTypeWithCaseAccentSensitivity)) + AbstractArrayType(StringTypeWithCollation)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 1aae2f10b7326..aa7eafeed485a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.catalyst.util.CharsetProvider import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity +import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1466,7 +1466,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("fmt"), - "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), + "inputType" -> toSQLType(StringTypeWithCollation), "inputExpr" -> toSQLExpr(wrongFmt) ) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 8600ec4f8787f..2b49b76ff8c7a 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -66,10 +66,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi collationType: CollationType): Any = inputEntry match { case e: Class[_] if e.isAssignableFrom(classOf[Expression]) => - generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) + generateLiterals(StringTypeWithCollation, collationType) case se: Class[_] if se.isAssignableFrom(classOf[Seq[Expression]]) => - CreateArray(Seq(generateLiterals(StringTypeWithCaseAccentSensitivity, collationType), - generateLiterals(StringTypeWithCaseAccentSensitivity, collationType))) + CreateArray(Seq(generateLiterals(StringTypeWithCollation, collationType), + generateLiterals(StringTypeWithCollation, collationType))) case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Any]]) => None case b: Class[_] if b.isAssignableFrom(classOf[Boolean]) => false case dt: Class[_] if dt.isAssignableFrom(classOf[DataType]) => StringType @@ -142,12 +142,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) ).head case ArrayType => - generateLiterals(StringTypeWithCaseAccentSensitivity, collationType).map( + generateLiterals(StringTypeWithCollation, collationType).map( lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) ).head case MapType => - val key = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) - val value = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) + val key = generateLiterals(StringTypeWithCollation, collationType) + val value = generateLiterals(StringTypeWithCollation, collationType) CreateMap(Seq(key, value)) case MapType(keyType, valueType, _) => val key = generateLiterals(keyType, collationType) @@ -160,8 +160,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case StructType => CreateNamedStruct( Seq(Literal("start"), - generateLiterals(StringTypeWithCaseAccentSensitivity, collationType), - Literal("end"), generateLiterals(StringTypeWithCaseAccentSensitivity, collationType))) + generateLiterals(StringTypeWithCollation, collationType), + Literal("end"), generateLiterals(StringTypeWithCollation, collationType))) } /** @@ -210,10 +210,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case ArrayType(elementType, _) => "array(" + generateInputAsString(elementType, collationType) + ")" case ArrayType => - "array(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" + "array(" + generateInputAsString(StringTypeWithCollation, collationType) + ")" case MapType => - "map(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", " + - generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" + "map(" + generateInputAsString(StringTypeWithCollation, collationType) + ", " + + generateInputAsString(StringTypeWithCollation, collationType) + ")" case MapType(keyType, valueType, _) => "map(" + generateInputAsString(keyType, collationType) + ", " + generateInputAsString(valueType, collationType) + ")" @@ -222,8 +222,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi generateInputAsString(valueType, collationType) + ")" case StructType => "named_struct( 'start', 
" + - generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", 'end', " + - generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" + generateInputAsString(StringTypeWithCollation, collationType) + ", 'end', " + + generateInputAsString(StringTypeWithCollation, collationType) + ")" case StructType(fields) => "named_struct(" + fields.map(f => "'" + f.name + "', " + generateInputAsString(f.dataType, collationType)).mkString(", ") + ")" @@ -269,12 +269,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case ArrayType(elementType, _) => "array<" + generateInputTypeAsStrings(elementType, collationType) + ">" case ArrayType => - "array<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + + "array<" + generateInputTypeAsStrings(StringTypeWithCollation, collationType) + ">" case MapType => - "map<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + + "map<" + generateInputTypeAsStrings(StringTypeWithCollation, collationType) + ", " + - generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">" + generateInputTypeAsStrings(StringTypeWithCollation, collationType) + ">" case MapType(keyType, valueType, _) => "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + generateInputTypeAsStrings(valueType, collationType) + ">" @@ -283,9 +283,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi generateInputTypeAsStrings(valueType, collationType) + ">" case StructType => "struct" + generateInputTypeAsStrings(StringTypeWithCollation, collationType) + ">" case StructType(fields) => "named_struct<" + fields.map(f => "'" + f.name + "', " + generateInputTypeAsStrings(f.dataType, collationType)).mkString(", ") + ">" @@ -298,7 +298,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi */ def hasStringType(inputType: AbstractDataType): Boolean = { inputType match { - case _: StringType | StringTypeWithCaseAccentSensitivity | StringTypeBinaryLcase | AnyDataType + case _: StringType | StringTypeWithCollation | StringTypeBinaryLcase | AnyDataType => true case ArrayType => true case MapType => true @@ -413,7 +413,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi var input: Seq[Expression] = Seq.empty var i = 0 for (_ <- 1 to 10) { - input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) + input = input :+ generateLiterals(StringTypeWithCollation, Utf8Binary) try { method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes] } @@ -503,7 +503,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi var input: Seq[Expression] = Seq.empty var result: Expression = null for (_ <- 1 to 10) { - input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) + input = input :+ generateLiterals(StringTypeWithCollation, Utf8Binary) try { val tempResult = method.invoke(null, f.getClassName, input) if (result == null) result = tempResult.asInstanceOf[Expression] @@ -614,7 +614,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi var input: Seq[Expression] = Seq.empty var result: Expression = null for (_ <- 1 to 10) { - input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) + input = input :+ generateLiterals(StringTypeWithCollation, Utf8Binary) try { val tempResult = method.invoke(null, 
f.getClassName, input) if (result == null) result = tempResult.asInstanceOf[Expression] From bdd34647c42a14b5856d74c7cb76c7b93d26079f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 17 Oct 2024 11:26:19 -0700 Subject: [PATCH 033/108] [SPARK-50014][INFRA] Use `grpcio*` 1.67.0 in Python 3.13 image ### What changes were proposed in this pull request? This PR aims to use `grpcio` and `grpcio-status` `1.67.0` for Python 3.13 tests in order to reveal the remaining test failures after installing the official `grpcio` in Python 3.13 environment. ### Why are the changes needed? `grpcio` added Python 3.13 support since 1.66.2. - https://pypi.org/project/grpcio/1.67.0/ - https://pypi.org/project/grpcio/1.66.2/ ### Does this PR introduce _any_ user-facing change? No, this is an infra change for test coverage. Currently, `pyspark-connect` module test fails due to the missing required package, `grpc`, like the following. - https://github.com/apache/spark/actions/runs/11372942311/job/31638495254 ``` ModuleNotFoundError: No module named 'grpc' ``` ### How was this patch tested? Manual check the generated image of this PR builder. ``` $ docker run -it --rm ghcr.io/dongjoon-hyun/apache-spark-ci-image:master-11389776259 python3.13 -m pip list | grep grpcio grpcio 1.67.0 grpcio-status 1.67.0 ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48522 from dongjoon-hyun/SPARK-50014. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/infra/base/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/infra/base/Dockerfile b/dev/infra/base/Dockerfile index 1edeed775880b..4313a61db5bf2 100644 --- a/dev/infra/base/Dockerfile +++ b/dev/infra/base/Dockerfile @@ -148,7 +148,7 @@ RUN apt-get update && apt-get install -y \ RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13 # TODO(SPARK-49862) Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13 image when it supports Python 3.13 RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this -RUN python3.13 -m pip install lxml numpy>=2.1 && \ +RUN python3.13 -m pip install grpcio==1.67.0 grpcio-status==1.67.0 lxml numpy>=2.1 && \ python3.13 -m pip cache purge # Remove unused installation packages to free up disk space From eb73143be6fc3c726a8d1dc17e1592a4f31bf16f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 Oct 2024 09:31:41 +0900 Subject: [PATCH 034/108] [SPARK-50019][INFRA] Install BASIC_PIP_PKGS except `pyarrow` in Python 3.13 image ### What changes were proposed in this pull request? This PR aims to install `BASIC_PIP_PKGS` except `pyarrow` in Python 3.13 image. ### Why are the changes needed? - https://github.com/apache/spark/actions/runs/11392144577/job/31698382766 ``` Traceback (most recent call last): File "/__w/spark/spark/python/pyspark/sql/pandas/utils.py", line 28, in require_minimum_pandas_version import pandas ModuleNotFoundError: No module named 'pandas' ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual check with the built image of this PR builder. ``` $ docker run -it --rm ghcr.io/dongjoon-hyun/apache-spark-ci-image:master-11392974455 python3.13 -m pip list | grep pandas pandas 2.2.3 ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48528 from dongjoon-hyun/SPARK-50019. 
Authored-by: Dongjoon Hyun
Signed-off-by: Hyukjin Kwon
---
 dev/infra/base/Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev/infra/base/Dockerfile b/dev/infra/base/Dockerfile
index 4313a61db5bf2..27474f5f12b6b 100644
--- a/dev/infra/base/Dockerfile
+++ b/dev/infra/base/Dockerfile
@@ -148,7 +148,7 @@ RUN apt-get update && apt-get install -y \
 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13
 # TODO(SPARK-49862) Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13 image when it supports Python 3.13
 RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
-RUN python3.13 -m pip install grpcio==1.67.0 grpcio-status==1.67.0 lxml numpy>=2.1 && \
+RUN python3.13 -m pip install numpy six==1.16.0 pandas==2.2.3 scipy coverage matplotlib openpyxl grpcio==1.67.0 grpcio-status==1.67.0 lxml numpy>=2.1 && \
     python3.13 -m pip cache purge

 # Remove unused installation packages to free up disk space

From 78308da2e1e3fd297e16d335a49b3571647e6493 Mon Sep 17 00:00:00 2001
From: Rui Wang
Date: Fri, 18 Oct 2024 08:53:05 +0800
Subject: [PATCH 035/108] [SPARK-48773][FOLLOW-UP] spark.conf.set should not
 fail when setting `spark.default.parallelism`

### What changes were proposed in this pull request?
spark.conf.set should not fail when setting `spark.default.parallelism`.

### Why are the changes needed?
This fixes a behavior change: before `SPARK-48773`, setting `spark.default.parallelism` through the Spark session did not fail and was a no-op.

### Does this PR introduce _any_ user-facing change?
Yes. Before `SPARK-48773`, spark.conf.set("spark.default.parallelism") does not fail and is a no-op. After `SPARK-48773`, spark.conf.set("spark.default.parallelism") fails with a `CANNOT_MODIFY_CONFIG` exception. With this follow-up, we restore the earlier behavior: spark.conf.set("spark.default.parallelism") does not fail and is a no-op.

### How was this patch tested?
Manual testing.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48526 from amaliujia/SPARK-48773.

Authored-by: Rui Wang
Signed-off-by: Wenchen Fan
---
 .../apache/spark/sql/internal/RuntimeConfigImpl.scala | 11 ++++++++++-
 .../org/apache/spark/sql/RuntimeConfigSuite.scala     |  7 +++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
index 0ef879387727a..1739b86c8dcb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._

 import org.apache.spark.SPARK_DOC_ROOT
 import org.apache.spark.annotation.Stable
-import org.apache.spark.internal.config.ConfigEntry
+import org.apache.spark.internal.config.{ConfigEntry, DEFAULT_PARALLELISM}
 import org.apache.spark.sql.RuntimeConfig
 import org.apache.spark.sql.errors.QueryCompilationErrors

@@ -85,6 +85,15 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends
   }

   private[sql] def requireNonStaticConf(key: String): Unit = {
+    // We documented `spark.default.parallelism` by SPARK-48773, however this config
+    // is actually a static config so now a spark.conf.set("spark.default.parallelism")
+    // will fail. Before SPARK-48773 it does not, then this becomes a behavior change.
+    // Technically the current behavior is correct, however it still forms a behavior change.
+    // To address the change, we need a check here and do not fail on default parallelism
+    // setting through spark session conf to maintain the same behavior.
+    if (key == DEFAULT_PARALLELISM.key) {
+      return
+    }
     if (SQLConf.isStaticConfigKey(key)) {
       throw QueryCompilationErrors.cannotModifyValueOfStaticConfigError(key)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala
index 009fe55664a2b..c80787c40c487 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.internal.config
+import org.apache.spark.internal.config.DEFAULT_PARALLELISM
 import org.apache.spark.sql.internal.{RuntimeConfigImpl, SQLConf}
 import org.apache.spark.sql.internal.SQLConf.CHECKPOINT_LOCATION
 import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE
@@ -101,4 +102,10 @@ class RuntimeConfigSuite extends SparkFunSuite {
     // Get the unset config entry, which should return its defaultValue again.
     assert(conf.get(key) == SQLConf.SESSION_LOCAL_TIMEZONE.defaultValue.get)
   }
+
+  test("SPARK-48773: set spark.default.parallelism does not fail") {
+    val conf = newConf()
+    // this set should not fail
+    conf.set(DEFAULT_PARALLELISM.key, "1")
+  }
 }

From 75b86667ee7607d3523d7ce75c1022752142a443 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Fri, 18 Oct 2024 11:03:32 +0900
Subject: [PATCH 036/108] [SPARK-49829][SS] Fix the bug on the optimization on
 adding input to state store in stream-stream join

### What changes were proposed in this pull request?
The PR proposes to revise the optimization on adding input to the state store in stream-stream join, to fix a correctness issue.

### Why are the changes needed?
Here is the logic of the optimization before this PR:

https://github.com/apache/spark/blob/039fd13eacb1cef835045e3a60cebf958589e1a2/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala#L671-L677

```
val isLeftSemiWithMatch =
  joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty
// Add to state store only if both removal predicates do not match,
// and the row is not matched for left side of left semi join.
val shouldAddToState =
  !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) &&
  !isLeftSemiWithMatch
```

The criterion `both removal predicates do not match` means the input is going to be evicted in this batch. I'm not sure about the coverage of this optimization, but there are two major issues with it:

1) Skipping the addition of a left-side input to the state store prevents right-side inputs from matching with that input. Even though the input is going to be evicted in this batch, there could still be inputs on the right side in this batch which can match with it.
2) Skipping the addition of an input to the state store prevents that input from producing unmatched (null-outer) output, since we produce unmatched output during the eviction of state.

Worth noting that `state watermark != watermark for eviction`, and the eviction mentioned above is based on the "state watermark". The state watermark could be either 1) equal to or earlier than the watermark for eviction, or 2) "later" than the watermark for eviction.
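For contrast with the snippet above, the revised criteria introduced by this PR can be summarized with the following condensed sketch. It reuses the same names; the exact hunk, including the full explanatory comments, is in the diff below, and the grouping of the right-outer/full-outer checks is simplified here.

```
// Condensed sketch of the revised rule in StreamingSymmetricHashJoinExec.
val isLeftSemiWithMatch = joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty
val shouldAddToState = if (isLeftSemiWithMatch) {
  // Left semi join with a match: the output is already emitted, no need to keep the row.
  false
} else if (joinSide == LeftSide) {
  // Left side: always keep the row; it has not been evaluated against this batch's
  // right-side input yet.
  true
} else {
  // Right side: keep the row if it is not evicted in this batch, or if it still has to
  // produce an unmatched (null-outer) row during eviction.
  val isNotEvictingInThisBatch =
    !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
  isNotEvictingInThisBatch ||
    ((joinType == RightOuter || joinType == FullOuter) && !iteratorNotEmpty)
}
```

In short, "the row will be evicted in this batch" is no longer, on its own, a reason to skip persisting a left-side row.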
### Does this PR introduce _any_ user-facing change? Yes, there are correctness issues among stream-stream join, especially when the output of the stateful operator is provided as input of stream-stream join. The correctness issue is fixed with the PR. ### How was this patch tested? New UTs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48297 from HeartSaVioR/SPARK-49829. Lead-authored-by: Jungtaek Lim Co-authored-by: Andrzej Zera Signed-off-by: Jungtaek Lim --- .../StreamingSymmetricHashJoinExec.scala | 39 +++- .../MultiStatefulOperatorsSuite.scala | 54 +++++ .../sql/streaming/StreamingJoinSuite.scala | 205 ++++++++++++++++++ 3 files changed, 291 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index c54917bdb7873..f6213eb27efdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -668,13 +668,38 @@ case class StreamingSymmetricHashJoinExec( private val iteratorNotEmpty: Boolean = super.hasNext override def completion(): Unit = { - val isLeftSemiWithMatch = - joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty - // Add to state store only if both removal predicates do not match, - // and the row is not matched for left side of left semi join. - val shouldAddToState = - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) && - !isLeftSemiWithMatch + // The criteria of whether the input has to be added into state store or not: + // - Left side: input can be skipped to be added to the state store if it's already matched + // and the join type is left semi. + // For other cases, the input should be added, including the case it's going to be evicted + // in this batch. It hasn't yet evaluated with inputs from right side for this batch. + // Refer to the classdoc of SteramingSymmetricHashJoinExec about how stream-stream join + // works. + // - Right side: for this side, the evaluation with inputs from left side for this batch + // is done at this point. That said, input can be skipped to be added to the state store + // if input is going to be evicted in this batch. Though, input should be added to the + // state store if it's right outer join or full outer join, as unmatched output is + // handled during state eviction. 
+ val isLeftSemiWithMatch = joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty + val shouldAddToState = if (isLeftSemiWithMatch) { + false + } else if (joinSide == LeftSide) { + true + } else { + // joinSide == RightSide + + // if the input is not evicted in this batch (hence need to be persisted) + val isNotEvictingInThisBatch = + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + + isNotEvictingInThisBatch || + // if the input is producing "unmatched row" in this batch + ( + (joinType == RightOuter && !iteratorNotEmpty) || + (joinType == FullOuter && !iteratorNotEmpty) + ) + } + if (shouldAddToState) { joinStateManager.append(key, thisRow, matched = iteratorNotEmpty) updatedStateRowsCount += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala index e8ee1e3c33015..980f9f48dcb05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -878,6 +878,60 @@ class MultiStatefulOperatorsSuite testOutputWatermarkInJoin(join3, input1, -40L * 1000 - 1) } + test("SPARK-49829 time window agg per each source followed by stream-stream join") { + val inputStream1 = MemoryStream[Long] + val inputStream2 = MemoryStream[Long] + + val df1 = inputStream1.toDF() + .selectExpr("value", "timestamp_seconds(value) AS ts") + .withWatermark("ts", "5 seconds") + + val df2 = inputStream2.toDF() + .selectExpr("value", "timestamp_seconds(value) AS ts") + .withWatermark("ts", "5 seconds") + + val df1Window = df1.groupBy( + window($"ts", "10 seconds") + ).agg(sum("value").as("sum_df1")) + + val df2Window = df2.groupBy( + window($"ts", "10 seconds") + ).agg(sum("value").as("sum_df2")) + + val joined = df1Window.join(df2Window, "window", "inner") + .selectExpr("CAST(window.end AS long) AS window_end", "sum_df1", "sum_df2") + + // The test verifies the case where both sides produce input as time window (append mode) + // for stream-stream join having join condition for equality of time window. + // Inputs are produced into stream-stream join when the time windows are completed, meaning + // they will be evicted in this batch for stream-stream join as well. (NOTE: join condition + // does not delay the state watermark in stream-stream join). + // Before SPARK-49829, left side does not add the input to state store if it's going to evict + // in this batch, which breaks the match between input from left side and input from right + // side for this batch. + testStream(joined)( + MultiAddData( + (inputStream1, Seq(1L, 2L, 3L, 4L, 5L)), + (inputStream2, Seq(5L, 6L, 7L, 8L, 9L)) + ), + // watermark: 5 - 5 = 0 + CheckNewAnswer(), + MultiAddData( + (inputStream1, Seq(11L, 12L, 13L, 14L, 15L)), + (inputStream2, Seq(15L, 16L, 17L, 18L, 19L)) + ), + // watermark: 15 - 5 = 10 (windows for [0, 10) are completed) + // Before SPARK-49829, the test fails because this row is not produced. 
+ CheckNewAnswer((10L, 15L, 35L)), + MultiAddData( + (inputStream1, Seq(100L)), + (inputStream2, Seq(101L)) + ), + // watermark: 100 - 5 = 95 (windows for [0, 20) are completed) + CheckNewAnswer((20L, 65L, 85L)) + ) + } + private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = AssertOnQuery { q => q.processAllAvailable() val progressWithData = q.recentProgress.lastOption.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index a733d54d275d2..20b627fbb42ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId} @@ -223,6 +224,26 @@ abstract class StreamingJoinSuite (inputStream, select) } + + protected def assertStateStoreRows( + opId: Long, + joinSide: String, + expectedRows: Seq[Row])(projFn: DataFrame => DataFrame): AssertOnQuery = Execute { q => + val checkpointLoc = q.resolvedCheckpointRoot + + // just make sure the query have no leftover data + q.processAllAvailable() + + // By default, it reads the state store from latest committed batch. + val stateStoreDf = spark.read.format("statestore") + .option(StateSourceOptions.JOIN_SIDE, joinSide) + .option(StateSourceOptions.PATH, checkpointLoc) + .option(StateSourceOptions.OPERATOR_ID, opId) + .load() + + val projectedDf = projFn(stateStoreDf) + checkAnswer(projectedDf, expectedRows) + } } @SlowSQLTest @@ -1559,6 +1580,66 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { ) } } + + test("SPARK-49829 left-outer join, input being unmatched is between WM for late event and " + + "WM for eviction") { + + withTempDir { checkpoint => + // This config needs to be set, otherwise no-data batch will be triggered and after + // no-data batch, WM for late event and WM for eviction would be same. 
+ withSQLConf(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false") { + val memoryStream1 = MemoryStream[(String, Int)] + val memoryStream2 = MemoryStream[(String, Int)] + + val data1 = memoryStream1.toDF() + .selectExpr("_1 AS key", "timestamp_seconds(_2) AS eventTime") + .withWatermark("eventTime", "0 seconds") + val data2 = memoryStream2.toDF() + .selectExpr("_1 AS key", "timestamp_seconds(_2) AS eventTime") + .withWatermark("eventTime", "0 seconds") + + val joinedDf = data1.join(data2, Seq("key", "eventTime"), "leftOuter") + .selectExpr("key", "CAST(eventTime AS long) AS eventTime") + + def assertLeftRows(expected: Seq[Row]): AssertOnQuery = { + assertStateStoreRows(0L, "left", expected) { df => + df.selectExpr("value.key", "CAST(value.eventTime AS long)") + } + } + + def assertRightRows(expected: Seq[Row]): AssertOnQuery = { + assertStateStoreRows(0L, "right", expected) { df => + df.selectExpr("value.key", "CAST(value.eventTime AS long)") + } + } + + testStream(joinedDf)( + StartStream(checkpointLocation = checkpoint.getCanonicalPath), + // batch 0 + // WM: late record = 0, eviction = 0 + MultiAddData( + (memoryStream1, Seq(("a", 1), ("b", 2))), + (memoryStream2, Seq(("b", 2), ("c", 1))) + ), + CheckNewAnswer(("b", 2)), + assertLeftRows(Seq(Row("a", 1), Row("b", 2))), + assertRightRows(Seq(Row("b", 2), Row("c", 1))), + // batch 1 + // WM: late record = 0, eviction = 2 + // Before Spark introduces multiple stateful operator, WM for late record was same as + // WM for eviction, hence ("d", 1) was treated as late record. + // With the multiple state operator, ("d", 1) is added in batch 1 but also evicted in + // batch 1. Note that the eviction is happening with state watermark: for this join, + // state watermark = state eviction under join condition. Before SPARK-49829, this + // wasn't producing unmatched row, and it is fixed. 
+ AddData(memoryStream1, ("d", 1)), + CheckNewAnswer(("a", 1), ("d", 1)), + assertLeftRows(Seq()), + assertRightRows(Seq()) + ) + } + } + } } @SlowSQLTest @@ -1966,4 +2047,128 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { assertNumStateRows(total = 9, updated = 4) ) } + + test("SPARK-49829 two chained stream-stream left outer joins among three input streams") { + withSQLConf(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false") { + val memoryStream1 = MemoryStream[(Long, Int)] + val memoryStream2 = MemoryStream[(Long, Int)] + val memoryStream3 = MemoryStream[(Long, Int)] + + val data1 = memoryStream1.toDF() + .selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v1") + .withWatermark("eventTime", "0 seconds") + val data2 = memoryStream2.toDF() + .selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v2") + .withWatermark("eventTime", "0 seconds") + val data3 = memoryStream3.toDF() + .selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v3") + .withWatermark("eventTime", "0 seconds") + + val join = data1 + .join(data2, Seq("eventTime"), "leftOuter") + .join(data3, Seq("eventTime"), "leftOuter") + .selectExpr("CAST(eventTime AS long) AS eventTime", "v1", "v2", "v3") + + def assertLeftRowsFor1stJoin(expected: Seq[Row]): AssertOnQuery = { + assertStateStoreRows(1L, "left", expected) { df => + df.selectExpr("CAST(value.eventTime AS long)", "value.v1") + } + } + + def assertRightRowsFor1stJoin(expected: Seq[Row]): AssertOnQuery = { + assertStateStoreRows(1L, "right", expected) { df => + df.selectExpr("CAST(value.eventTime AS long)", "value.v2") + } + } + + def assertLeftRowsFor2ndJoin(expected: Seq[Row]): AssertOnQuery = { + assertStateStoreRows(0L, "left", expected) { df => + df.selectExpr("CAST(value.eventTime AS long)", "value.v1", "value.v2") + } + } + + def assertRightRowsFor2ndJoin(expected: Seq[Row]): AssertOnQuery = { + assertStateStoreRows(0L, "right", expected) { df => + df.selectExpr("CAST(value.eventTime AS long)", "value.v3") + } + } + + testStream(join)( + // batch 0 + // WM: late event = 0, eviction = 0 + MultiAddData( + (memoryStream1, Seq((20L, 1))), + (memoryStream2, Seq((20L, 1))), + (memoryStream3, Seq((20L, 1))) + ), + CheckNewAnswer((20, 1, 1, 1)), + assertLeftRowsFor1stJoin(Seq(Row(20, 1))), + assertRightRowsFor1stJoin(Seq(Row(20, 1))), + assertLeftRowsFor2ndJoin(Seq(Row(20, 1, 1))), + assertRightRowsFor2ndJoin(Seq(Row(20, 1))), + // batch 1 + // WM: late event = 0, eviction = 20 + MultiAddData( + (memoryStream1, Seq((21L, 2))), + (memoryStream2, Seq((21L, 2))) + ), + CheckNewAnswer(), + assertLeftRowsFor1stJoin(Seq(Row(21, 2))), + assertRightRowsFor1stJoin(Seq(Row(21, 2))), + assertLeftRowsFor2ndJoin(Seq(Row(21, 2, 2))), + assertRightRowsFor2ndJoin(Seq()), + // batch 2 + // WM: late event = 20, eviction = 20 (slowest: inputStream3) + MultiAddData( + (memoryStream1, Seq((22L, 3))), + (memoryStream3, Seq((22L, 3))) + ), + CheckNewAnswer(), + assertLeftRowsFor1stJoin(Seq(Row(21, 2), Row(22, 3))), + assertRightRowsFor1stJoin(Seq(Row(21, 2))), + assertLeftRowsFor2ndJoin(Seq(Row(21, 2, 2))), + assertRightRowsFor2ndJoin(Seq(Row(22, 3))), + // batch 3 + // WM: late event = 20, eviction = 21 (slowest: inputStream2) + AddData(memoryStream1, (23L, 4)), + CheckNewAnswer(Row(21, 2, 2, null)), + assertLeftRowsFor1stJoin(Seq(Row(22, 3), Row(23, 4))), + assertRightRowsFor1stJoin(Seq()), + assertLeftRowsFor2ndJoin(Seq()), + assertRightRowsFor2ndJoin(Seq(Row(22, 3))), + // batch 4 + // WM: late event = 21, eviction = 21 (slowest: inputStream2) + 
MultiAddData( + (memoryStream1, Seq((24L, 5))), + (memoryStream2, Seq((24L, 5))), + (memoryStream3, Seq((24L, 5))) + ), + CheckNewAnswer(Row(24, 5, 5, 5)), + assertLeftRowsFor1stJoin(Seq(Row(22, 3), Row(23, 4), Row(24, 5))), + assertRightRowsFor1stJoin(Seq(Row(24, 5))), + assertLeftRowsFor2ndJoin(Seq(Row(24, 5, 5))), + assertRightRowsFor2ndJoin(Seq(Row(22, 3), Row(24, 5))), + // batch 5 + // WM: late event = 21, eviction = 24 + // just trigger a new batch with arbitrary data as the original test relies on no-data + // batch, and we need to check with remaining unmatched outputs + AddData(memoryStream1, (100L, 6)), + // Before SPARK-49829, the test fails because (23, 4, null, null) wasn't produced. + // (The assertion of state for left inputs & right inputs weren't included on the test + // before SPARK-49829.) + CheckNewAnswer(Row(22, 3, null, 3), Row(23, 4, null, null)) + ) + + /* + // The collection of the above new answers is the same with below in original test: + val expected = Array( + Row(Timestamp.valueOf("2024-02-10 10:20:00"), 1, 1, 1), + Row(Timestamp.valueOf("2024-02-10 10:21:00"), 2, 2, null), + Row(Timestamp.valueOf("2024-02-10 10:22:00"), 3, null, 3), + Row(Timestamp.valueOf("2024-02-10 10:23:00"), 4, null, null), + Row(Timestamp.valueOf("2024-02-10 10:24:00"), 5, 5, 5), + ) + */ + } + } } From d90145db7b8452deafb73c268774c75644e2cc4b Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 17 Oct 2024 19:52:27 -0700 Subject: [PATCH 037/108] [SPARK-49991][SQL] Make HadoopMapReduceCommitProtocol respect 'mapreduce.output.basename' to generate file names ### What changes were proposed in this pull request? In 'HadoopMapReduceCommitProtocol', task output files are generated ahead instead of calling `org.apache.hadoop.mapreduce.lib.output.FileOutputFormat#getDefaultWorkFile`, which uses the `mapreduce.output.basename` as the prefix of output files. In this pull request, we modify the `HadoopMapReduceCommitProtocol.getFilename` method to also look up this config instead of using the hardcoded 'part'. ### Why are the changes needed? Given a custom file name is a useful feature for users. They can use it to distinguish files added by different engines, on different days, etc. We can also align the usage scenario with other SQL on Hadoop engines for better Hadoop compatibility. ### Does this PR introduce _any_ user-facing change? Yes, a Hadoop configuration 'mapreduce.output.basename' can be used in file datasource output files ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no` Closes #48494 from yaooqinn/SPARK-49991. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .../internal/io/HadoopMapReduceCommitProtocol.scala | 3 ++- .../datasources/parquet/ParquetIOSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index f245d2d4e4074..476cddc643954 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -168,7 +168,8 @@ class HadoopMapReduceCommitProtocol( // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, // the file name is fine and won't overflow. 
val split = taskContext.getTaskAttemptID.getTaskID.getId
- f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}"
+ val basename = taskContext.getConfiguration.get("mapreduce.output.basename", "part")
+ f"${spec.prefix}$basename-$split%05d-$jobId${spec.suffix}"
 }
 override def setupJob(jobContext: JobContext): Unit = {

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 95fb178154929..22839d3f0d251 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -1585,6 +1585,17 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
 }
 }
 }
+
+ test("SPARK-49991: Respect 'mapreduce.output.basename' to generate file names") {
+ withTempPath { dir =>
+ withSQLConf("mapreduce.output.basename" -> "apachespark") {
+ spark.range(1).coalesce(1).write.parquet(dir.getCanonicalPath)
+ val df = spark.read.parquet(dir.getCanonicalPath)
+ assert(df.inputFiles.head.contains("apachespark"))
+ checkAnswer(spark.read.parquet(dir.getCanonicalPath), Row(0))
+ }
+ }
+ }
 }
 class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)

From 82802494fa350ef83fafac338ebbc7d66a8b0f12 Mon Sep 17 00:00:00 2001
From: Anish Shrigondekar
Date: Fri, 18 Oct 2024 12:47:41 +0900
Subject: [PATCH 038/108] [SPARK-49802][SS] Add support for read change feed for map and list types used in stateful processors

### What changes were proposed in this pull request?
Add support for reading the change feed for map and list types used in stateful processors.

### Why are the changes needed?
Without this change, reading the change feed for map and list types is not supported.

### Does this PR introduce _any_ user-facing change?
Yes

Users can query state using the following query:
```
val stateReaderDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
  .option(StateSourceOptions.STATE_VAR_NAME, "mapState")
  .option(StateSourceOptions.READ_CHANGE_FEED, true)
  .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
  .load()
```

### How was this patch tested?
Added unit tests
```
[info] Run completed in 24 seconds, 422 milliseconds.
[info] Total number of tests run: 4
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48274 from anishshri-db/task/SPARK-49802.
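For list-type state variables, a minimal usage sketch of the same change feed read, projecting the flattened columns this change adds; the checkpoint path `checkpointDir` and the state variable name `groupsList` are assumptions for illustration, not part of the patch:
```scala
// Sketch only: read the change feed of a list state variable and project the
// flattened columns (batch_id, change_type, key, list_element, partition_id).
val listChangeFeedDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, checkpointDir)           // assumed checkpoint path
  .option(StateSourceOptions.STATE_VAR_NAME, "groupsList")  // assumed list state variable
  .option(StateSourceOptions.READ_CHANGE_FEED, true)
  .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
  .load()
  .selectExpr("batch_id", "change_type", "key.value AS groupingKey",
    "list_element.value AS element", "partition_id")
```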
Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../v2/state/StateDataSource.scala | 9 +- .../v2/state/StatePartitionReader.scala | 34 ++- .../v2/state/utils/SchemaUtil.scala | 66 ++-- .../streaming/state/StateStoreChangelog.scala | 1 + ...ateDataSourceTransformWithStateSuite.scala | 286 +++++++++++++++++- 5 files changed, 359 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index edddfbd6ccaef..2a9abfa5d6a50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{J import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil -import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, StateVariableType, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} @@ -175,13 +175,6 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, s"State variable $stateVarName is not defined for the transformWithState operator.") } - - // TODO: add support for list and map type - if (sourceOptions.readChangeFeed && - stateVarInfo.head.stateVariableType != StateVariableType.ValueState) { - throw StateDataSourceErrors.conflictOptions(Seq(StateSourceOptions.READ_CHANGE_FEED, - StateSourceOptions.STATE_VAR_NAME)) - } } else { // if the operator is transformWithState, then a state variable argument is mandatory if (stateStoreMetadata.size == 1 && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index b925aee5b627a..9e993fbedc304 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -238,7 +238,22 @@ class StateStoreChangeDataPartitionReader( } override lazy val iter: Iterator[InternalRow] = { - changeDataReader.iterator.map(unifyStateChangeDataRow) + if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { + val groupingKeySchema 
= SchemaUtil.getSchemaAsDataType( + keySchema, "key" + ).asInstanceOf[StructType] + val userKeySchema = SchemaUtil.getSchemaAsDataType( + keySchema, "userKey" + ).asInstanceOf[StructType] + changeDataReader.iterator.map { entry => + val groupingKey = entry._2.get(0, groupingKeySchema).asInstanceOf[UnsafeRow] + val userMapKey = entry._2.get(1, userKeySchema).asInstanceOf[UnsafeRow] + createFlattenedRowForMapState(entry._4, entry._1, + groupingKey, userMapKey, entry._3, partition.partition) + } + } else { + changeDataReader.iterator.map(unifyStateChangeDataRow) + } } override def close(): Unit = { @@ -256,4 +271,21 @@ class StateStoreChangeDataPartitionReader( result.update(4, partition.partition) result } + + private def createFlattenedRowForMapState( + batchId: Long, + recordType: RecordType, + groupingKey: UnsafeRow, + userKey: UnsafeRow, + userValue: UnsafeRow, + partition: Int): InternalRow = { + val result = new GenericInternalRow(6) + result.update(0, batchId) + result.update(1, UTF8String.fromString(getRecordTypeAsString(recordType))) + result.update(2, groupingKey) + result.update(3, userKey) + result.update(4, userValue) + result.update(5, partition) + result + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index c337d548fa42b..84eab3356c204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -49,17 +49,17 @@ object SchemaUtil { valueSchema: StructType, transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]): StructType = { - if (sourceOptions.readChangeFeed) { + if (transformWithStateVariableInfoOpt.isDefined) { + require(stateStoreColFamilySchemaOpt.isDefined) + generateSchemaForStateVar(transformWithStateVariableInfoOpt.get, + stateStoreColFamilySchemaOpt.get, sourceOptions) + } else if (sourceOptions.readChangeFeed) { new StructType() .add("batch_id", LongType) .add("change_type", StringType) .add("key", keySchema) .add("value", valueSchema) .add("partition_id", IntegerType) - } else if (transformWithStateVariableInfoOpt.isDefined) { - require(stateStoreColFamilySchemaOpt.isDefined) - generateSchemaForStateVar(transformWithStateVariableInfoOpt.get, - stateStoreColFamilySchemaOpt.get, sourceOptions) } else { new StructType() .add("key", keySchema) @@ -233,25 +233,31 @@ object SchemaUtil { "expiration_timestamp_ms" -> classOf[LongType], "partition_id" -> classOf[IntegerType]) - val expectedFieldNames = if (sourceOptions.readChangeFeed) { - Seq("batch_id", "change_type", "key", "value", "partition_id") - } else if (transformWithStateVariableInfoOpt.isDefined) { + val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) { val stateVarInfo = transformWithStateVariableInfoOpt.get val stateVarType = stateVarInfo.stateVariableType stateVarType match { case ValueState => - Seq("key", "value", "partition_id") + if (sourceOptions.readChangeFeed) { + Seq("batch_id", "change_type", "key", "value", "partition_id") + } else { + Seq("key", "value", "partition_id") + } case ListState => - if (sourceOptions.flattenCollectionTypes) { + if (sourceOptions.readChangeFeed) { + Seq("batch_id", "change_type", "key", "list_element", "partition_id") + } else if 
(sourceOptions.flattenCollectionTypes) { Seq("key", "list_element", "partition_id") } else { Seq("key", "list_value", "partition_id") } case MapState => - if (sourceOptions.flattenCollectionTypes) { + if (sourceOptions.readChangeFeed) { + Seq("batch_id", "change_type", "key", "user_map_key", "user_map_value", "partition_id") + } else if (sourceOptions.flattenCollectionTypes) { Seq("key", "user_map_key", "user_map_value", "partition_id") } else { Seq("key", "map_value", "partition_id") @@ -264,6 +270,8 @@ object SchemaUtil { throw StateDataSourceErrors .internalError(s"Unsupported state variable type $stateVarType") } + } else if (sourceOptions.readChangeFeed) { + Seq("batch_id", "change_type", "key", "value", "partition_id") } else { Seq("key", "value", "partition_id") } @@ -286,13 +294,29 @@ object SchemaUtil { stateVarType match { case ValueState => - new StructType() - .add("key", stateStoreColFamilySchema.keySchema) - .add("value", stateStoreColFamilySchema.valueSchema) - .add("partition_id", IntegerType) + if (stateSourceOptions.readChangeFeed) { + new StructType() + .add("batch_id", LongType) + .add("change_type", StringType) + .add("key", stateStoreColFamilySchema.keySchema) + .add("value", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } else { + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("value", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } case ListState => - if (stateSourceOptions.flattenCollectionTypes) { + if (stateSourceOptions.readChangeFeed) { + new StructType() + .add("batch_id", LongType) + .add("change_type", StringType) + .add("key", stateStoreColFamilySchema.keySchema) + .add("list_element", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } else if (stateSourceOptions.flattenCollectionTypes) { new StructType() .add("key", stateStoreColFamilySchema.keySchema) .add("list_element", stateStoreColFamilySchema.valueSchema) @@ -313,7 +337,15 @@ object SchemaUtil { valueType = stateStoreColFamilySchema.valueSchema ) - if (stateSourceOptions.flattenCollectionTypes) { + if (stateSourceOptions.readChangeFeed) { + new StructType() + .add("batch_id", LongType) + .add("change_type", StringType) + .add("key", groupingKeySchema) + .add("user_map_key", userKeySchema) + .add("user_map_value", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } else if (stateSourceOptions.flattenCollectionTypes) { new StructType() .add("key", groupingKeySchema) .add("user_map_key", userKeySchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index e89550da37e03..203af9d10217e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -60,6 +60,7 @@ object RecordType extends Enumeration { recordType match { case PUT_RECORD => "update" case DELETE_RECORD => "delete" + case MERGE_RECORD => "append" case _ => throw StateStoreErrors.unsupportedOperationException( "getRecordTypeAsString", recordType.toString) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index 84c6eb54681a1..0aa748f7af93d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -309,14 +309,20 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest TimeMode.ProcessingTime(), OutputMode.Update()) - testStream(result, OutputMode.Update())( - StartStream(checkpointLocation = tempDir.getAbsolutePath), + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = tempDir.getCanonicalPath), AddData(inputData, "a"), AddData(inputData, "b"), - Execute { _ => - // wait for the batch to run since we are using processing time - Thread.sleep(5000) - }, + AdvanceManualClock(5 * 1000), + CheckNewAnswer(("a", "1"), ("b", "1")), + AddData(inputData, "c"), + AdvanceManualClock(30 * 1000), + CheckNewAnswer(("c", "1")), + AddData(inputData, "d"), + AdvanceManualClock(30 * 1000), + CheckNewAnswer(("d", "1")), StopStream ) @@ -333,19 +339,27 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest var count = 0L resultDf.collect().foreach { row => - count = count + 1 - assert(row.getLong(2) > 0) + if (!row.anyNull) { + count = count + 1 + assert(row.getLong(2) > 0) + } } - // verify that 2 state rows are present - assert(count === 2) + // verify that 4 state rows are present + assert(count === 4) val answerDf = stateReaderDf.selectExpr( "change_type", "key.value AS groupingKey", "value.value.value AS valueId", "partition_id") checkAnswer(answerDf, - Seq(Row("update", "a", 1L, 0), Row("update", "b", 1L, 1))) + Seq(Row("update", "a", 1L, 0), + Row("update", "b", 1L, 1), + Row("update", "c", 1L, 2), + Row("delete", "a", null, 0), + Row("delete", "b", null, 1), + Row("update", "d", 1L, 4), + Row("delete", "c", null, 2))) } } } @@ -410,6 +424,53 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } + testWithChangelogCheckpointingEnabled("state data source cdf integration - list state") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + + val inputData = MemoryStream[(String, String)] + val result = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new SessionGroupsStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, ("session1", "group2")), + AddData(inputData, ("session1", "group1")), + AddData(inputData, ("session2", "group1")), + CheckNewAnswer(), + AddData(inputData, ("session3", "group7")), + AddData(inputData, ("session1", "group4")), + CheckNewAnswer(), + StopStream + ) + + val flattenedReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "groupsList") + .option(StateSourceOptions.READ_CHANGE_FEED, true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load() + + val resultDf = flattenedReaderDf.selectExpr( + "change_type", + "key.value AS groupingKey", + "list_element.value AS valueList", + "partition_id") + checkAnswer(resultDf, + Seq(Row("append", "session1", "group1", 
0), + Row("append", "session1", "group2", 0), + Row("append", "session1", "group4", 0), + Row("append", "session2", "group1", 0), + Row("append", "session3", "group7", 3))) + } + } + } + test("state data source integration - list state and TTL") { withTempDir { tempDir => withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> @@ -498,6 +559,76 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } + testWithChangelogCheckpointingEnabled("state data source cdf integration - list state and TTL") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[(String, String)] + val result = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new SessionGroupsStatefulProcessorWithTTL(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = tempDir.getCanonicalPath), + AddData(inputData, ("session1", "group2")), + AddData(inputData, ("session1", "group1")), + AddData(inputData, ("session2", "group1")), + AdvanceManualClock(5 * 1000), + CheckNewAnswer(), + AddData(inputData, ("session3", "group7")), + AddData(inputData, ("session1", "group4")), + AdvanceManualClock(30 * 1000), + CheckNewAnswer(), + StopStream + ) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL") + .option(StateSourceOptions.READ_CHANGE_FEED, true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load() + + val flattenedResultDf = flattenedStateReaderDf + .selectExpr("list_element.ttlExpirationMs AS ttlExpirationMs") + var flattenedCount = 0L + flattenedResultDf.collect().foreach { row => + if (!row.anyNull) { + flattenedCount = flattenedCount + 1 + assert(row.getLong(0) > 0) + } + } + + // verify that 6 state rows are present + assert(flattenedCount === 6) + + val outputDf = flattenedStateReaderDf + .selectExpr( + "change_type", + "key.value AS groupingKey", + "list_element.value.value AS groupId", + "partition_id") + + checkAnswer(outputDf, + Seq(Row("append", "session1", "group1", 0), + Row("append", "session1", "group2", 0), + Row("append", "session1", "group4", 0), + Row("append", "session2", "group1", 0), + Row("append", "session3", "group7", 3), + Row("delete", "session1", null, 0), + Row("delete", "session2", null, 0), + Row("update", "session1", "group4", 0))) + } + } + } + test("state data source integration - map state with single variable") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -563,6 +694,58 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } + testWithChangelogCheckpointingEnabled("state data source cdf integration - " + + "map state with single variable") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { tempDir => + val inputData = MemoryStream[InputMapRow] + val result = inputData.toDS() + .groupByKey(x => x.key) + .transformWithState(new TestMapStateProcessor(), + TimeMode.None(), + OutputMode.Append()) + 
testStream(result, OutputMode.Append())( + StartStream(checkpointLocation = tempDir.getCanonicalPath), + AddData(inputData, InputMapRow("k1", "updateValue", ("v1", "10"))), + AddData(inputData, InputMapRow("k1", "exists", ("", ""))), + AddData(inputData, InputMapRow("k2", "exists", ("", ""))), + CheckNewAnswer(("k1", "exists", "true"), ("k2", "exists", "false")), + + AddData(inputData, InputMapRow("k1", "updateValue", ("v2", "5"))), + AddData(inputData, InputMapRow("k2", "updateValue", ("v2", "3"))), + ProcessAllAvailable(), + StopStream + ) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "sessionState") + .option(StateSourceOptions.READ_CHANGE_FEED, true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load() + + val outputDf = flattenedStateReaderDf + .selectExpr( + "change_type", + "key.value AS groupingKey", + "user_map_key.value AS mapKey", + "user_map_value.value AS mapValue", + "partition_id") + + checkAnswer(outputDf, + Seq( + Row("update", "k1", "v1", "10", 4L), + Row("update", "k1", "v2", "5", 4L), + Row("update", "k2", "v2", "3", 2L)) + ) + } + } + } + test("state data source integration - map state TTL with single variable") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -648,6 +831,87 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } + testWithChangelogCheckpointingEnabled("state data source cdf integration - " + + "map state TTL with single variable") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { tempDir => + val inputStream = MemoryStream[MapInputEvent] + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new MapStateTTLProcessor(ttlConfig), + TimeMode.ProcessingTime(), + OutputMode.Append()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = tempDir.getCanonicalPath), + AddData(inputStream, + MapInputEvent("k1", "key1", "put", 1), + MapInputEvent("k1", "key2", "put", 2) + ), + AdvanceManualClock(1 * 1000), // batch timestamp: 1000 + CheckNewAnswer(), + AddData(inputStream, + MapInputEvent("k1", "key1", "get", -1), + MapInputEvent("k1", "key2", "get", -1) + ), + AdvanceManualClock(30 * 1000), // batch timestamp: 31000 + CheckNewAnswer( + MapOutputEvent("k1", "key1", 1, isTTLValue = false, -1), + MapOutputEvent("k1", "key2", 2, isTTLValue = false, -1) + ), + // get values from ttl state + AddData(inputStream, + MapInputEvent("k1", "", "get_values_in_ttl_state", -1) + ), + AdvanceManualClock(1 * 1000), // batch timestamp: 32000 + CheckNewAnswer( + MapOutputEvent("k1", "key1", -1, isTTLValue = true, 61000), + MapOutputEvent("k1", "key2", -1, isTTLValue = true, 61000) + ), + AddData(inputStream, + MapInputEvent("k2", "key3", "put", 3) + ), + AdvanceManualClock(30 * 1000), // batch timestamp: 62000 + CheckNewAnswer(), + StopStream + ) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "mapState") + .option(StateSourceOptions.READ_CHANGE_FEED, true) + 
.option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
+ .load()
+
+ val outputDf = flattenedStateReaderDf
+ .selectExpr(
+ "change_type",
+ "key.value AS groupingKey",
+ "user_map_key.value AS mapKey",
+ "user_map_value.value.value AS mapValue",
+ "user_map_value.ttlExpirationMs AS ttlTimestamp",
+ "partition_id")
+
+ checkAnswer(outputDf,
+ Seq(
+ Row("update", "k1", "key1", 1, 61000L, 4L),
+ Row("update", "k1", "key2", 2, 61000L, 4L),
+ Row("delete", "k1", "key1", null, null, 4L),
+ Row("delete", "k1", "key2", null, null, 4L),
+ Row("update", "k2", "key3", 3, 122000L, 2L))
+ )
+ }
+ }
+ }
+
 test("state data source - processing-time timers integration") {
 withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
 classOf[RocksDBStateStoreProvider].getName,

From 5697df70b3b7ffdd5e38c2099515cb73a57a79bc Mon Sep 17 00:00:00 2001
From: Siying Dong
Date: Fri, 18 Oct 2024 13:53:10 +0900
Subject: [PATCH 039/108] [SPARK-49411][SS] Communicate State Store Checkpoint ID between driver and stateful operators

### What changes were proposed in this pull request?
This is an incremental step to implement RocksDB state store checkpoint format V2. Once the conf STATE_STORE_CHECKPOINT_FORMAT_VERSION is set to version 2 or higher, the executor returns the checkpointID to the driver (only done for RocksDB). The driver stores it locally. For the next batch, the State Store Checkpoint ID is sent to the executor to be used to load the state store. If the executor's local version doesn't match the uniqueID, it will reload from the checkpoint. There is no behavior change if the default checkpoint format is used.

### Why are the changes needed?
This is an incremental step in the project to introduce a new RocksDB State Store checkpoint format. The new format simplifies the checkpoint mechanism to make it less bug prone, and fixes some unexpected query results in rare queries.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
A new unit test is added to cover the format version, and another unit test is added to validate that the uniqueID is passed back and forth as expected.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47895 from siying/unique_id2.
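A minimal sketch of opting into the new checkpoint format, which is still internal and in development; the query shape and checkpoint path below are illustrative assumptions, and the RocksDB provider is used because only RocksDB reports checkpoint IDs:
```scala
// Sketch only: enable state store checkpoint format V2 (internal conf added by this patch)
// and the RocksDB state store provider before starting a stateful query.
spark.conf.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2")
spark.conf.set("spark.sql.streaming.stateStore.providerClass",
  "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

val query = spark.readStream
  .format("rate")
  .load()
  .selectExpr("value % 10 AS bucket")
  .groupBy("bucket")
  .count()
  .writeStream
  .outputMode("update")
  .format("console")
  .option("checkpointLocation", "/tmp/ckpt-spark-49411") // assumed path
  .start()
```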
Authored-by: Siying Dong Signed-off-by: Jungtaek Lim --- .../apache/spark/sql/internal/SQLConf.scala | 11 + ...StreamStreamJoinStatePartitionReader.scala | 9 +- .../FlatMapGroupsWithStateExec.scala | 1 + .../streaming/IncrementalExecution.scala | 20 +- .../streaming/MicroBatchExecution.scala | 59 +- .../execution/streaming/StreamExecution.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 34 +- .../StreamingSymmetricHashJoinHelper.scala | 19 +- .../streaming/TransformWithStateExec.scala | 6 +- .../state/HDFSBackedStateStoreProvider.scala | 18 +- .../execution/streaming/state/RocksDB.scala | 60 +- .../state/RocksDBStateStoreProvider.scala | 20 +- .../streaming/state/StateStore.scala | 53 +- .../streaming/state/StateStoreConf.scala | 9 +- .../streaming/state/StateStoreRDD.scala | 4 + .../state/SymmetricHashJoinStateManager.scala | 117 +++- .../execution/streaming/state/package.scala | 2 + .../streaming/statefulOperators.scala | 128 ++++- ...tWithSessionWindowStateIteratorSuite.scala | 5 +- .../streaming/state/MemoryStateStore.scala | 4 + ...sDBStateStoreCheckpointFormatV2Suite.scala | 544 ++++++++++++++++++ .../state/RocksDBStateStoreSuite.scala | 2 +- .../streaming/state/StateStoreRDDSuite.scala | 2 +- .../streaming/state/StateStoreSuite.scala | 36 +- ...eamingSessionWindowStateManagerSuite.scala | 5 +- .../SymmetricHashJoinStateManagerSuite.scala | 5 +- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../sql/streaming/StreamingJoinSuite.scala | 2 +- 28 files changed, 1106 insertions(+), 73 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 08002887135ce..1c9f5e85d1a06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2204,6 +2204,15 @@ object SQLConf { .intConf .createWithDefault(3) + // The feature is still in development, so it is still internal. 
+ val STATE_STORE_CHECKPOINT_FORMAT_VERSION = + buildConf("spark.sql.streaming.stateStore.checkpointFormatVersion") + .internal() + .doc("The version of the approach of doing state store checkpoint") + .version("4.0.0") + .intConf + .createWithDefault(1) + val STATE_STORE_COMPRESSION_CODEC = buildConf("spark.sql.streaming.stateStore.compression.codec") .internal() @@ -5524,6 +5533,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreCompressionCodec: String = getConf(STATE_STORE_COMPRESSION_CODEC) + def stateStoreCheckpointFormatVersion: Int = getConf(STATE_STORE_CHECKPOINT_FORMAT_VERSION) + def checkpointRenamedFileCheck: Boolean = getConf(CHECKPOINT_RENAMEDFILE_CHECK_ENABLED) def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index 673ec3414c237..b9adb379e38c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -102,10 +102,15 @@ class StreamStreamJoinStatePartitionReader( private lazy val iter = { if (joinStateManager == null) { + // Here we don't know the StateStoreCheckpointID, so we set it to None in both stateInfo + // and `keyToNumValuesStateStoreCkptId` and `keyWithIndexToValueStateStoreCkptId` passed + // into SymmetricHashJoinStateManager. + // TODO after we persistent the StateStoreCheckpointID to the commit log, we can get it from + // there and pass it in. 
val stateInfo = StatefulOperatorStateInfo( partition.sourceOptions.stateCheckpointLocation.toString, partition.queryId, partition.sourceOptions.operatorId, - partition.sourceOptions.batchId + 1, -1) + partition.sourceOptions.batchId + 1, -1, None) joinStateManager = new SymmetricHashJoinStateManager( joinSide, inputAttributes, @@ -114,6 +119,8 @@ class StreamStreamJoinStatePartitionReader( storeConf = storeConf, hadoopConf = hadoopConf.value, partitionId = partition.partition, + keyToNumValuesStateStoreCkptId = None, + keyWithIndexToValueStateStoreCkptId = None, formatVersion, skippedNullValueCount = None, useStateStoreCoordinator = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 766caaab2285e..941e3a9949cf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -243,6 +243,7 @@ trait FlatMapGroupsWithStateExecBase stateManager.stateSchema, NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType), stateInfo.get.storeVersion, + stateInfo.get.getStateStoreCkptId(partitionId).map(_.head), useColumnFamilies = false, storeConf, hadoopConfBroadcast.value.value) val processor = createInputProcessor(store) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index b24b23c61f4d3..4b8bc72b2ed7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.{Map => MutableMap} + import org.apache.hadoop.fs.Path import org.apache.spark.internal.{Logging, MDC} @@ -45,6 +47,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} /** * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] * plan incrementally. Possibly preserving state in between each execution. + * @param currentStateStoreCkptId checkpoint ID for the latest committed version. It is + * operatorID -> array of checkpointIDs. Array index n + * represents checkpoint ID for the nth shuffle partition. */ class IncrementalExecution( sparkSession: SparkSession, @@ -57,7 +62,9 @@ class IncrementalExecution( val prevOffsetSeqMetadata: Option[OffsetSeqMetadata], val offsetSeqMetadata: OffsetSeqMetadata, val watermarkPropagator: WatermarkPropagator, - val isFirstBatch: Boolean) + val isFirstBatch: Boolean, + val currentStateStoreCkptId: + MutableMap[Long, Array[Array[String]]] = MutableMap[Long, Array[Array[String]]]()) extends QueryExecution(sparkSession, logicalPlan) with Logging { // Modified planner with stateful operations. 
@@ -126,12 +133,17 @@ class IncrementalExecution( /** Get the state info of the next stateful operator */ private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { - StatefulOperatorStateInfo( + val operatorId = statefulOperatorId.getAndIncrement() + // TODO When state store checkpoint format V2 is used, after state store checkpoint ID is + // stored to the commit logs, we should assert the ID is not empty if it is not batch 0 + val ret = StatefulOperatorStateInfo( checkpointLocation, runId, - statefulOperatorId.getAndIncrement(), + operatorId, currentBatchId, - numStateStores) + numStateStores, + currentStateStoreCkptId.get(operatorId)) + ret } sealed trait SparkPlanPartialRule { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 053aef6ced3a6..dc141b21780e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsTriggerAvailableNow} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1} @@ -130,6 +130,11 @@ class MicroBatchExecution( protected var watermarkTracker: WatermarkTracker = _ + // Store checkpointIDs for state store checkpoints to be committed or have been committed to + // the commit log. + // operatorID -> (partitionID -> array of uniqueID) + private val currentStateStoreCkptId = MutableMap[Long, Array[Array[String]]]() + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in QueryExecutionThread " + @@ -836,7 +841,8 @@ class MicroBatchExecution( offsetLog.offsetSeqMetadataForBatchId(execCtx.batchId - 1), execCtx.offsetSeqMetadata, watermarkPropagator, - execCtx.previousContext.isEmpty) + execCtx.previousContext.isEmpty, + currentStateStoreCkptId) execCtx.executionPlan.executedPlan // Force the lazy generation of execution plan } @@ -900,12 +906,59 @@ class MicroBatchExecution( */ protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {} + /** + * Store the state store checkpoint id for a finishing batch to `currentStateStoreCkptId`, + * which will be retrieved later when the next batch starts. 
+ */ + private def updateStateStoreCkptIdForOperator( + execCtx: MicroBatchExecutionContext, + opId: Long, + checkpointInfo: Array[StatefulOpStateStoreCheckpointInfo]): Unit = { + // TODO validate baseStateStoreCkptId + checkpointInfo.map(_.batchVersion).foreach { v => + assert( + execCtx.batchId == -1 || v == execCtx.batchId + 1, + s"Batch version ${execCtx.batchId} should generate state store checkpoint " + + s"version ${execCtx.batchId + 1} but we see ${v}") + } + val ckptIds = checkpointInfo.map { info => + assert(info.stateStoreCkptId.isDefined) + info.stateStoreCkptId.get + } + currentStateStoreCkptId.put(opId, ckptIds) + } + + /** + * Walk the query plan `latestExecPlan` to find out a StateStoreWriter operator. Retrieve + * the state store checkpoint id from the operator and update it to `currentStateStoreCkptId`. + * @param execCtx information is needed to do some validation. + * @param latestExecPlan the query plan that contains stateful operators where we would + * extract the state store checkpoint id. + */ + private def updateStateStoreCkptId( + execCtx: MicroBatchExecutionContext, + latestExecPlan: SparkPlan): Unit = { + latestExecPlan.collect { + case e: StateStoreWriter => + assert(e.stateInfo.isDefined, "StateInfo should not be empty in StateStoreWriter") + updateStateStoreCkptIdForOperator( + execCtx, + e.stateInfo.get.operatorId, + e.getStateStoreCheckpointInfo()) + } + } + /** * Called after the microbatch has completed execution. It takes care of committing the offset * to commit log and other bookkeeping. */ protected def markMicroBatchEnd(execCtx: MicroBatchExecutionContext): Unit = { - watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan) + val latestExecPlan = execCtx.executionPlan.executedPlan + watermarkTracker.updateWatermark(latestExecPlan) + if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds( + sparkSessionForStream.sessionState.conf)) { + updateStateStoreCkptId(execCtx, latestExecPlan) + } execCtx.reportTimeTaken("commitOffsets") { if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 14adf951f07e8..d8f32a2cb9225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -210,7 +210,7 @@ abstract class StreamExecution( this, s"spark.streaming.${Option(name).getOrElse(id)}") /** Isolated spark session to run the batches with. */ - private val sparkSessionForStream = sparkSession.cloneSession() + protected val sparkSessionForStream = sparkSession.cloneSession() /** * The thread that runs the micro-batches of this stream. 
Note that this thread must be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index f6213eb27efdd..4bd531b618e68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -303,6 +303,9 @@ case class StreamingSymmetricHashJoinExec( val updateStartTimeNs = System.nanoTime val joinedRow = new JoinedRow + assert(stateInfo.isDefined, "State info not defined") + val checkpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds( + partitionId, stateInfo.get) val inputSchema = left.output ++ right.output val postJoinFilter = @@ -310,11 +313,11 @@ case class StreamingSymmetricHashJoinExec( val leftSideJoiner = new OneSideHashJoiner( LeftSide, left.output, leftKeys, leftInputIter, condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left, partitionId, - skippedNullValueCount) + checkpointIds.left.keyToNumValues, checkpointIds.left.valueToNumKeys, skippedNullValueCount) val rightSideJoiner = new OneSideHashJoiner( RightSide, right.output, rightKeys, rightInputIter, condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right, partitionId, - skippedNullValueCount) + checkpointIds.right.keyToNumValues, checkpointIds.right.valueToNumKeys, skippedNullValueCount) // Join one side input using the other side's buffered/state rows. Here is how it is done. // @@ -507,6 +510,12 @@ case class StreamingSymmetricHashJoinExec( val rightSideMetrics = rightSideJoiner.commitStateAndGetMetrics() val combinedMetrics = StateStoreMetrics.combine(Seq(leftSideMetrics, rightSideMetrics)) + val checkpointInfo = SymmetricHashJoinStateManager.mergeStateStoreCheckpointInfo( + JoinStateStoreCkptInfo( + leftSideJoiner.getLatestCheckpointInfo(), + rightSideJoiner.getLatestCheckpointInfo())) + setStateStoreCheckpointInfo(checkpointInfo) + // Update SQL metrics numUpdatedStateRows += (leftSideJoiner.numUpdatedStateRows + rightSideJoiner.numUpdatedStateRows) @@ -544,6 +553,7 @@ case class StreamingSymmetricHashJoinExec( * @param stateWatermarkPredicate The state watermark predicate. See * [[StreamingSymmetricHashJoinExec]] for further description of * state watermarks. + * @param oneSideStateInfo Reconstructed state info for this side * @param partitionId A partition ID of source RDD. */ private class OneSideHashJoiner( @@ -555,6 +565,8 @@ case class StreamingSymmetricHashJoinExec( postJoinFilter: (InternalRow) => Boolean, stateWatermarkPredicate: Option[JoinStateWatermarkPredicate], partitionId: Int, + keyToNumValuesStateStoreCkptId: Option[String], + keyWithIndexToValueStateStoreCkptId: Option[String], skippedNullValueCount: Option[SQLMetric]) { // Filter the joined rows based on the given condition. 
@@ -562,8 +574,18 @@ case class StreamingSymmetricHashJoinExec( Predicate.create(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ private val joinStateManager = new SymmetricHashJoinStateManager( - joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value, - partitionId, stateFormatVersion, skippedNullValueCount) + joinSide = joinSide, + inputValueAttributes = inputAttributes, + joinKeys = joinKeys, + stateInfo = stateInfo, + storeConf = storeConf, + hadoopConf = hadoopConfBcast.value.value, + partitionId = partitionId, + keyToNumValuesStateStoreCkptId = keyToNumValuesStateStoreCkptId, + keyWithIndexToValueStateStoreCkptId = keyWithIndexToValueStateStoreCkptId, + stateFormatVersion = stateFormatVersion, + skippedNullValueCount = skippedNullValueCount) + private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match { @@ -742,6 +764,10 @@ case class StreamingSymmetricHashJoinExec( joinStateManager.metrics } + def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo = { + joinStateManager.getLatestCheckpointInfo() + } + def numUpdatedStateRows: Long = updatedStateRowsCount } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 49e1f5e8ba12a..497e71070a09a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression -import org.apache.spark.sql.execution.streaming.state.{StateStoreCoordinatorRef, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.{StateStoreCheckpointInfo, StateStoreCoordinatorRef, StateStoreProviderId} /** @@ -320,4 +320,21 @@ object StreamingSymmetricHashJoinHelper extends Logging { dataRDD.sparkContext, f, dataRDD, dataRDD2, stateInfo, storeNames, Some(storeCoordinator)) } } + + case class JoinerStateStoreCkptInfo( + keyToNumValues: StateStoreCheckpointInfo, + valueToNumKeys: StateStoreCheckpointInfo) + + case class JoinStateStoreCkptInfo( + left: JoinerStateStoreCkptInfo, + right: JoinerStateStoreCkptInfo) + + case class JoinerStateStoreCheckpointId( + keyToNumValues: Option[String], + valueToNumKeys: Option[String]) + + case class JoinStateStoreCheckpointId( + left: JoinerStateStoreCheckpointId, + right: JoinerStateStoreCheckpointId) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 648a4c3a68b03..cd567f4c74d70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -548,6 +548,7 @@ case class TransformWithStateExec( DUMMY_VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(keyEncoder.schema), version = stateInfo.get.storeVersion, + stateStoreCkptId = 
stateInfo.get.getStateStoreCkptId(partitionId).map(_.head), useColumnFamilies = true, storeConf = storeConf, hadoopConf = hadoopConfBroadcast.value.value @@ -622,7 +623,7 @@ case class TransformWithStateExec( hadoopConf = hadoopConfBroadcast.value.value, useMultipleValuesPerKey = true) - val store = stateStoreProvider.getStore(0) + val store = stateStoreProvider.getStore(0, None) val outputIterator = f(store) CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator.iterator, { stateStoreProvider.close() @@ -719,7 +720,8 @@ object TransformWithStateExec { queryRunId = UUID.randomUUID(), operatorId = 0, storeVersion = 0, - numPartitions = shufflePartitions + numPartitions = shufflePartitions, + stateStoreCkptIds = None ) new TransformWithStateExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3df63c41dbf97..2f77b2c14b009 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -222,6 +222,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with StateStoreMetrics(mapToUpdate.size(), metricsFromProvider("memoryUsedBytes"), customMetrics) } + override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = { + StateStoreCheckpointInfo(id.partitionId, newVersion, None, None) + } + /** * Whether all updates have been committed */ @@ -255,7 +259,12 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with } /** Get the state store for making updates to create a new `version` of the store. */ - override def getStore(version: Long): StateStore = { + override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { + if (uniqueId.isDefined) { + throw QueryExecutionErrors.cannotLoadStore(new SparkException( + "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " + + "but a state store checkpointID is passed in")) + } val newMap = getLoadedMapForStore(version) logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, version)} " + log"of ${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for update") @@ -263,7 +272,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with } /** Get the state store for reading to specific `version` of the store. 
*/ - override def getReadStore(version: Long): ReadStateStore = { + override def getReadStore(version: Long, stateStoreCkptId: Option[String]): ReadStateStore = { val newMap = getLoadedMapForStore(version) logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, version)} of " + log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for readonly") @@ -330,6 +339,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with storeConf: StateStoreConf, hadoopConf: Configuration, useMultipleValuesPerKey: Boolean = false): Unit = { + assert( + !storeConf.enableStateStoreCheckpointIds, + "HDFS State Store Provider doesn't support checkpointFormatVersion >= 2 " + + s"checkpointFormatVersion ${storeConf.sqlConf.stateStoreCheckpointFormatVersion}") + this.stateStoreId_ = stateStoreId this.keySchema = keySchema this.valueSchema = valueSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 99f8e7b8f36e6..aeac5ea71a2eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -72,7 +72,8 @@ class RocksDB( localRootDir: File = Utils.createTempDir(), hadoopConf: Configuration = new Configuration, loggingId: String = "", - useColumnFamilies: Boolean = false) extends Logging { + useColumnFamilies: Boolean = false, + enableStateStoreCheckpointIds: Boolean = false) extends Logging { import RocksDB._ @@ -158,6 +159,22 @@ class RocksDB( @volatile private var changelogWriter: Option[StateStoreChangelogWriter] = None private val enableChangelogCheckpointing: Boolean = conf.enableChangelogCheckpointing @volatile private var loadedVersion = -1L // -1 = nothing valid is loaded + + // variables to manage checkpoint ID. Once a checkpointing finishes, it needs to return + // `lastCommittedStateStoreCkptId` as the committed checkpointID, as well as + // `lastCommitBasedStateStoreCkptId` as the checkpontID of the previous version that is based on. + // `loadedStateStoreCkptId` is the checkpointID for the current live DB. After the batch finishes + // and checkpoint finishes, it will turn into `lastCommitBasedStateStoreCkptId`. + // `sessionStateStoreCkptId` store an ID to be used for future checkpoints. It will be used as + // `lastCommittedStateStoreCkptId` after the checkpoint is committed. It will be reused until + // we have to use a new one. We have to update `sessionStateStoreCkptId` if we reload a previous + // batch version, as we would have to use a new checkpointID for re-committing a version. + // The reusing is to help debugging but is not required for the algorithm to work. + private var lastCommitBasedStateStoreCkptId: Option[String] = None + private var lastCommittedStateStoreCkptId: Option[String] = None + private var loadedStateStoreCkptId: Option[String] = None + private var sessionStateStoreCkptId: Option[String] = None + @volatile private var numKeysOnLoadedVersion = 0L @volatile private var numKeysOnWritingVersion = 0L @volatile private var fileManagerMetrics = RocksDBFileManagerMetrics.EMPTY_METRICS @@ -266,13 +283,18 @@ class RocksDB( * Note that this will copy all the necessary file from DFS to local disk as needed, * and possibly restart the native RocksDB instance. 
*/ - def load(version: Long, readOnly: Boolean = false): RocksDB = { + def load( + version: Long, + stateStoreCkptId: Option[String] = None, + readOnly: Boolean = false): RocksDB = { assert(version >= 0) acquire(LoadStore) recordedMetrics = None logInfo(log"Loading ${MDC(LogKeys.VERSION_NUM, version)}") try { - if (loadedVersion != version) { + if (loadedVersion != version || + (enableStateStoreCheckpointIds && stateStoreCkptId.isDefined && + (loadedStateStoreCkptId.isEmpty || stateStoreCkptId.get != loadedStateStoreCkptId.get))) { closeDB(ignoreException = false) val latestSnapshotVersion = fileManager.getLatestSnapshotVersion(version) rocksDBFileMapping.currentVersion = latestSnapshotVersion @@ -311,6 +333,12 @@ class RocksDB( numKeysOnLoadedVersion = numKeysOnWritingVersion fileManagerMetrics = fileManager.latestLoadCheckpointMetrics } + if (enableStateStoreCheckpointIds) { + lastCommitBasedStateStoreCkptId = None + loadedStateStoreCkptId = stateStoreCkptId + sessionStateStoreCkptId = Some(java.util.UUID.randomUUID.toString) + } + lastCommittedStateStoreCkptId = None if (conf.resetStatsOnLoad) { nativeStats.reset } @@ -318,6 +346,10 @@ class RocksDB( } catch { case t: Throwable => loadedVersion = -1 // invalidate loaded data + lastCommitBasedStateStoreCkptId = None + lastCommittedStateStoreCkptId = None + loadedStateStoreCkptId = None + sessionStateStoreCkptId = None throw t } if (enableChangelogCheckpointing && !readOnly) { @@ -672,6 +704,11 @@ class RocksDB( numKeysOnLoadedVersion = numKeysOnWritingVersion loadedVersion = newVersion + if (enableStateStoreCheckpointIds) { + lastCommitBasedStateStoreCkptId = loadedStateStoreCkptId + lastCommittedStateStoreCkptId = sessionStateStoreCkptId + loadedStateStoreCkptId = sessionStateStoreCkptId + } commitLatencyMs ++= Map( "flush" -> flushTimeMs, "compact" -> compactTimeMs, @@ -708,6 +745,10 @@ class RocksDB( acquire(RollbackStore) numKeysOnWritingVersion = numKeysOnLoadedVersion loadedVersion = -1L + lastCommitBasedStateStoreCkptId = None + lastCommittedStateStoreCkptId = None + loadedStateStoreCkptId = None + sessionStateStoreCkptId = None changelogWriter.foreach(_.abort()) // Make sure changelogWriter gets recreated next time. changelogWriter = None @@ -778,6 +819,19 @@ class RocksDB( /** Get the write buffer manager and cache */ def getWriteBufferManagerAndCache(): (WriteBufferManager, Cache) = (writeBufferManager, lruCache) + /** + * Called by RocksDBStateStoreProvider to retrieve the checkpoint information to be + * passed back to the stateful operator. It will return the information for the latest + * state store checkpointing. 
+ */ + def getLatestCheckpointInfo(partitionId: Int): StateStoreCheckpointInfo = { + StateStoreCheckpointInfo( + partitionId, + loadedVersion, + lastCommittedStateStoreCkptId, + lastCommitBasedStateStoreCkptId) + } + /** Get current instantaneous statistics */ private def metrics: RocksDBMetrics = { import HistogramType._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 870ed79ec1747..1fc6ab5910c6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -305,6 +305,11 @@ private[sql] class RocksDBStateStoreProvider } } + override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = { + val checkpointInfo = rocksDB.getLatestCheckpointInfo(id.partitionId) + checkpointInfo + } + override def hasCommitted: Boolean = state == COMMITTED override def toString: String = { @@ -380,12 +385,14 @@ private[sql] class RocksDBStateStoreProvider override def stateStoreId: StateStoreId = stateStoreId_ - override def getStore(version: Long): StateStore = { + override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { try { if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } - rocksDB.load(version) + rocksDB.load( + version, + stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None) new RocksDBStateStore(version) } catch { @@ -400,12 +407,15 @@ private[sql] class RocksDBStateStoreProvider } } - override def getReadStore(version: Long): StateStore = { + override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = { try { if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } - rocksDB.load(version, true) + rocksDB.load( + version, + stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, + readOnly = true) new RocksDBStateStore(version) } catch { @@ -456,7 +466,7 @@ private[sql] class RocksDBStateStoreProvider val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr, - useColumnFamilies) + useColumnFamilies, storeConf.enableStateStoreCheckpointIds) } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6e616cc71a80c..72bc3ca33054d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -190,6 +190,17 @@ trait StateStore extends ReadStateStore { /** Current metrics of the state store */ def metrics: StateStoreMetrics + /** + * Return information on recently generated checkpoints + * The information should only be usable when checkpoint format version 2 is used and + * underlying state store supports it. + * If it is not the case, the method can return a dummy result. 
The result eventually won't + * be sent to the driver, but not all stateful operators are able to figure out whether + * the function should be called or not. They call it anyway and pass it to + * StatefulOperator.setStateStoreCheckpointInfo(), where it will be ignored. + * */ + def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo + /** * Whether all updates have been committed */ @@ -233,6 +244,21 @@ case class StateStoreMetrics( memoryUsedBytes: Long, customMetrics: Map[StateStoreCustomMetric, Long]) +/** + * State store checkpoint information, used to pass checkpointing information from executors + * to the driver after execution. + * @param stateStoreCkptId The checkpoint ID for a checkpoint at `batchVersion`. This is used to + * identify the checkpoint. + * @param baseStateStoreCkptId The checkpoint ID for `batchVersion` - 1, that is used to finish this + * batch. This is used to validate the batch is processed based on the + * correct checkpoint. + */ +case class StateStoreCheckpointInfo( + partitionId: Int, + batchVersion: Long, + stateStoreCkptId: Option[String], + baseStateStoreCkptId: Option[String]) + object StateStoreMetrics { def combine(allMetrics: Seq[StateStoreMetrics]): StateStoreMetrics = { val distinctCustomMetrics = allMetrics.flatMap(_.customMetrics.keys).distinct @@ -353,6 +379,11 @@ case class RangeKeyScanStateEncoderSpec( * version of the data can be accessed. It is the responsible of the provider to populate * this store with context information like the schema of keys and values, etc. * + * If the checkpoint format version 2 is used, an additional argument `checkpointID` may be + * provided as part of `getStore(version, checkpointID)`. The provider needs to guarantee + * that the loaded version is of this unique ID. It needs to load the version for this specific + * ID from the checkpoint if needed. + * * - After the streaming query is stopped, the created provider instances are lazily disposed off. */ trait StateStoreProvider { @@ -394,17 +425,23 @@ trait StateStoreProvider { /** Called when the provider instance is unloaded from the executor */ def close(): Unit - /** Return an instance of [[StateStore]] representing state data of the given version */ - def getStore(version: Long): StateStore + /** + * Return an instance of [[StateStore]] representing state data of the given version. + * If `stateStoreCkptId` is provided, the instance also needs to match the ID. + * */ + def getStore( + version: Long, + stateStoreCkptId: Option[String] = None): StateStore /** - * Return an instance of [[ReadStateStore]] representing state data of the given version. + * Return an instance of [[ReadStateStore]] representing state data of the given version + * and uniqueID if provided. * By default it will return the same instance as getStore(version) but wrapped to prevent * modification. Providers can override and return optimized version of [[ReadStateStore]] * based on the fact the instance will be only used for reading. */ - def getReadStore(version: Long): ReadStateStore = - new WrappedReadStateStore(getStore(version)) + def getReadStore(version: Long, uniqueId: Option[String] = None): ReadStateStore = + new WrappedReadStateStore(getStore(version, uniqueId)) /** Optional method for providers to allow for background maintenance (e.g.
compactions) */ def doMaintenance(): Unit = { } @@ -704,6 +741,7 @@ object StateStore extends Logging { valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, version: Long, + stateStoreCkptId: Option[String], useColumnFamilies: Boolean, storeConf: StateStoreConf, hadoopConf: Configuration, @@ -713,7 +751,7 @@ object StateStore extends Logging { } val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey) - storeProvider.getReadStore(version) + storeProvider.getReadStore(version, stateStoreCkptId) } /** Get or create a store associated with the id. */ @@ -723,6 +761,7 @@ object StateStore extends Logging { valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, version: Long, + stateStoreCkptId: Option[String], useColumnFamilies: Boolean, storeConf: StateStoreConf, hadoopConf: Configuration, @@ -732,7 +771,7 @@ object StateStore extends Logging { } val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey) - storeProvider.getStore(version) + storeProvider.getStore(version, stateStoreCkptId) } private def getStateStoreProvider( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index e199e1a4765e0..c8af395e996d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.streaming.state +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. */ class StateStoreConf( - @transient private val sqlConf: SQLConf, + @transient private[state] val sqlConf: SQLConf, val extraOptions: Map[String, String] = Map.empty) extends Serializable { @@ -82,6 +83,12 @@ class StateStoreConf( /** The interval of maintenance tasks. */ val maintenanceInterval = sqlConf.streamingMaintenanceInterval + /** + * When creating new state store checkpoint, which format version to use. + */ + val enableStateStoreCheckpointIds = + StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sqlConf) + /** * Additional configurations related to state store. 
This will capture all configs in * SQLConf that start with `spark.sql.streaming.stateStore.` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 133b4ab1cce3c..67c889283a1b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -72,6 +72,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag]( queryRunId: UUID, operatorId: Long, storeVersion: Long, + stateStoreCkptIds: Option[Array[Array[String]]], keySchema: StructType, valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, @@ -90,6 +91,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag]( val inputIter = dataRDD.iterator(partition, ctxt) val store = StateStore.getReadOnly( storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, + stateStoreCkptIds.map(_.apply(partition.index).head), useColumnFamilies, storeConf, hadoopConfBroadcast.value.value) storeReadFunction(store, inputIter) } @@ -107,6 +109,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( queryRunId: UUID, operatorId: Long, storeVersion: Long, + uniqueId: Option[Array[Array[String]]], keySchema: StructType, valueSchema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec, @@ -126,6 +129,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( val inputIter = dataRDD.iterator(partition, ctxt) val store = StateStore.get( storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion, + uniqueId.map(_.apply(partition.index).head), useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, useMultipleValuesPerKey) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 4de3170f5db33..6b00418f8fb53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.StatefulOpStateStoreCheckpointInfo import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.types.{BooleanType, LongType, StructField, StructType} import org.apache.spark.util.NextIterator @@ -85,6 +86,8 @@ class SymmetricHashJoinStateManager( storeConf: StateStoreConf, hadoopConf: Configuration, partitionId: Int, + keyToNumValuesStateStoreCkptId: Option[String], + keyWithIndexToValueStateStoreCkptId: Option[String], stateFormatVersion: Int, skippedNullValueCount: Option[SQLMetric] = None, useStateStoreCoordinator: Boolean = true, @@ -411,6 +414,28 @@ class SymmetricHashJoinStateManager( keyWithIndexToValue.abortIfNeeded() } + /** + * Get state store checkpoint information of the two state stores for this joiner, after + * they finished data processing. 
+ */ + def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo = { + val keyToNumValuesCkptInfo = keyToNumValues.getLatestCheckpointInfo() + val keyWithIndexToValueCkptInfo = keyWithIndexToValue.getLatestCheckpointInfo() + + assert( + keyToNumValuesCkptInfo.partitionId == keyWithIndexToValueCkptInfo.partitionId, + "two state stores in a stream-stream joiner don't return the same partition ID") + assert( + keyToNumValuesCkptInfo.batchVersion == keyWithIndexToValueCkptInfo.batchVersion, + "two state stores in a stream-stream joiner don't return the same batch version") + assert( + keyToNumValuesCkptInfo.stateStoreCkptId.isDefined == + keyWithIndexToValueCkptInfo.stateStoreCkptId.isDefined, + "two state stores in a stream-stream joiner should both return checkpoint ID or not") + + JoinerStateStoreCkptInfo(keyToNumValuesCkptInfo, keyWithIndexToValueCkptInfo) + } + /** Get the combined metrics of all the state stores */ def metrics: StateStoreMetrics = { val keyToNumValuesMetrics = keyToNumValues.metrics @@ -451,7 +476,9 @@ class SymmetricHashJoinStateManager( Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } /** Helper trait for invoking common functionalities of a state store. */ - private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging { + private abstract class StateStoreHandler( + stateStoreType: StateStoreType, + stateStoreCkptId: Option[String]) extends Logging { private var stateStoreProvider: StateStoreProvider = _ /** StateStore that the subclasses of this class is going to operate on */ @@ -476,6 +503,10 @@ class SymmetricHashJoinStateManager( def metrics: StateStoreMetrics = stateStore.metrics + def getLatestCheckpointInfo(): StateStoreCheckpointInfo = { + stateStore.getStateStoreCheckpointInfo() + } + /** Get the StateStore with the given schema */ protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { val storeProviderId = StateStoreProviderId( @@ -485,7 +516,8 @@ class SymmetricHashJoinStateManager( "when reading state as data source.") StateStore.get( storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - stateInfo.get.storeVersion, useColumnFamilies = false, storeConf, hadoopConf) + stateInfo.get.storeVersion, stateStoreCkptId, useColumnFamilies = false, + storeConf, hadoopConf) } else { // This class will manage the state store provider by itself. stateStoreProvider = StateStoreProvider.createAndInit( @@ -500,7 +532,7 @@ class SymmetricHashJoinStateManager( stateStoreProvider.asInstanceOf[SupportsFineGrainedReplay] .replayStateFromSnapshot(snapshotStartVersion.get, stateInfo.get.storeVersion) } else { - stateStoreProvider.getStore(stateInfo.get.storeVersion) + stateStoreProvider.getStore(stateInfo.get.storeVersion, stateStoreCkptId) } } logInfo(log"Loaded store ${MDC(STATE_STORE_ID, store.id)}") @@ -522,7 +554,8 @@ class SymmetricHashJoinStateManager( /** A wrapper around a [[StateStore]] that stores [key -> number of values]. 
*/ - private class KeyToNumValuesStore extends StateStoreHandler(KeyToNumValuesType) { + private class KeyToNumValuesStore + extends StateStoreHandler(KeyToNumValuesType, keyToNumValuesStateStoreCkptId) { private val longValueSchema = new StructType().add("value", "long") private val longToUnsafeRow = UnsafeProjection.create(longValueSchema) private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema)) @@ -669,7 +702,7 @@ class SymmetricHashJoinStateManager( * state format version - please refer implementations of [[KeyWithIndexToValueRowConverter]]. */ private class KeyWithIndexToValueStore(stateFormatVersion: Int) - extends StateStoreHandler(KeyWithIndexToValueType) { + extends StateStoreHandler(KeyWithIndexToValueType, keyWithIndexToValueStateStoreCkptId) { private val keyWithIndexExprs = keyAttributes :+ Literal(1L) private val keyWithIndexSchema = keySchema.add("index", LongType) @@ -808,6 +841,80 @@ object SymmetricHashJoinStateManager { result } + /** + * Stream-stream join has 4 state stores instead of one. So it will generate 4 different + * checkpoint IDs. The approach we take here is to merge them into one array in the checkpointing + * path. The driver will process this single checkpoint ID. When it is passed back to the + * executors, they will split it back into 4 IDs and use them to load the state. This function is + * used to merge the four checkpoint IDs (one per state store) into one array. + * The merged array is expected to be read back by `getStateStoreCheckpointIds()`. + */ + def mergeStateStoreCheckpointInfo(joinCkptInfo: JoinStateStoreCkptInfo): + StatefulOpStateStoreCheckpointInfo = { + assert( + joinCkptInfo.left.keyToNumValues.partitionId == joinCkptInfo.right.keyToNumValues.partitionId, + "state store info returned from two Stream-Stream Join sides have different partition IDs") + assert( + joinCkptInfo.left.keyToNumValues.batchVersion == + joinCkptInfo.right.keyToNumValues.batchVersion, + "state store info returned from two Stream-Stream Join sides have different batch versions") + assert( + joinCkptInfo.left.keyToNumValues.stateStoreCkptId.isDefined == + joinCkptInfo.right.keyToNumValues.stateStoreCkptId.isDefined, + "state store info returned from two Stream-Stream Join sides should both return " + + "checkpoint ID or not") + + val ckptIds = joinCkptInfo.left.keyToNumValues.stateStoreCkptId.map( + Array( + _, + joinCkptInfo.left.valueToNumKeys.stateStoreCkptId.get, + joinCkptInfo.right.keyToNumValues.stateStoreCkptId.get, + joinCkptInfo.right.valueToNumKeys.stateStoreCkptId.get + ) + ) + val baseCkptIds = joinCkptInfo.left.keyToNumValues.baseStateStoreCkptId.map( + Array( + _, + joinCkptInfo.left.valueToNumKeys.baseStateStoreCkptId.get, + joinCkptInfo.right.keyToNumValues.baseStateStoreCkptId.get, + joinCkptInfo.right.valueToNumKeys.baseStateStoreCkptId.get + ) + ) + + StatefulOpStateStoreCheckpointInfo( + joinCkptInfo.left.keyToNumValues.partitionId, + joinCkptInfo.left.keyToNumValues.batchVersion, + ckptIds, + baseCkptIds) + } + + /** + * Stream-stream join has 4 state stores instead of one. So it will generate 4 different + * checkpoint IDs. They are translated from each joiner's state stores into an array through + * mergeStateStoreCheckpointInfo(). This function is used to read it back into individual state + * store checkpoint IDs.
+ * @param partitionId the shuffle partition whose checkpoint IDs should be extracted + * @param stateInfo the operator state info carrying the merged checkpoint IDs for all partitions + * @return the checkpoint IDs of the four state stores used by this partition's two joiners + */ + def getStateStoreCheckpointIds( + partitionId: Int, + stateInfo: StatefulOperatorStateInfo): JoinStateStoreCheckpointId = { + + val stateStoreCkptIds = stateInfo + .stateStoreCkptIds + .map(_(partitionId)) + .map(_.map(Option(_))) + .getOrElse(Array.fill[Option[String]](4)(None)) + JoinStateStoreCheckpointId( + left = JoinerStateStoreCheckpointId( + keyToNumValues = stateStoreCkptIds(0), + valueToNumKeys = stateStoreCkptIds(1)), + right = JoinerStateStoreCheckpointId( + keyToNumValues = stateStoreCkptIds(2), + valueToNumKeys = stateStoreCkptIds(3))) + } + private sealed trait StateStoreType private case object KeyToNumValuesType extends StateStoreType { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 44e939424f55b..e1a95dd10be74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -79,6 +79,7 @@ package object state { stateInfo.queryRunId, stateInfo.operatorId, stateInfo.storeVersion, + stateInfo.stateStoreCkptIds, keySchema, valueSchema, keyStateEncoderSpec, @@ -118,6 +119,7 @@ package object state { stateInfo.queryRunId, stateInfo.operatorId, stateInfo.storeVersion, + stateInfo.stateStoreCkptIds, keySchema, valueSchema, keyStateEncoderSpec, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3cb41710a22c8..8f800b9f0252c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow @@ -44,19 +45,55 @@ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ -import org.apache.spark.util.{CompletionIterator, NextIterator, Utils} +import org.apache.spark.util.{CollectionAccumulator, CompletionIterator, NextIterator, Utils} -/** Used to identify the state store for a given operator. */ +/** Used to identify the state store for a given operator. + + + * + * stateStoreCkptIds is used to identify the checkpoint used for a specific stateful operator. + * The basic workflow is as follows: + * 1. When a stateful operator is created, it passes in the checkpoint IDs for each stateful + * operator through the StatefulOperatorStateInfo. + * 2. When a stateful task starts to execute, it will find the checkpointID for its shuffle + * partition and use it to recover the state store. The ID is eventually passed into + * the StateStore layer and ultimately the RocksDB State Store, where it is used to make sure + * it loads the correct checkpoint. + * 3.
When the stateful task is finishing, after the state store is committed, the checkpoint ID + * is fetched from the state store by calling StateStore.getStateStoreCheckpointInfo() and added + * to the stateStoreCkptIds accumulator by calling + * StateStoreWriter.setStateStoreCheckpointInfo(). + * 4. When ending the batch, MicroBatchExecution calls each stateful operator's + * getStateStoreCheckpointInfo() which aggregates checkpointIDs from different partitions. The + * driver will persist it into the commit log (not implemented yet). + * 5. When forming the next batch, the driver constructs the StatefulOperatorStateInfo with the + * checkpoint IDs for the previous batch. + * */ case class StatefulOperatorStateInfo( checkpointLocation: String, queryRunId: UUID, operatorId: Long, storeVersion: Long, - numPartitions: Int) { + numPartitions: Int, + stateStoreCkptIds: Option[Array[Array[String]]] = None) { + + def getStateStoreCkptId(partitionId: Int): Option[Array[String]] = { + stateStoreCkptIds.map(_(partitionId)) + } + override def toString(): String = { s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " + - s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]" + s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions] " + + s"stateStoreCkptIds = $stateStoreCkptIds" + } +} + +object StatefulOperatorStateInfo { + /** + * Whether the state store checkpoint version requires a checkpoint ID to be used. + * @return true if state store checkpointID should be used. + */ + def enableStateStoreCheckpointIds(conf: SQLConf): Boolean = { + conf.stateStoreCheckpointFormatVersion >= 2 } } @@ -108,13 +145,27 @@ case class StatefulOperatorCustomSumMetric(name: String, desc: String) } /** An operator that reads from a StateStore. */ -trait StateStoreReader extends StatefulOperator { +trait StateStoreReader extends StatefulOperator with Logging { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) } +/** + * Used to pass state store checkpoint information to load the correct state store checkpoint for a + * stateful operator. Will be passed from the driver to executors to load the state store. + */ +case class StatefulOpStateStoreCheckpointInfo( + partitionId: Int, + batchVersion: Long, + // The checkpoint ID for a checkpoint at `batchVersion`. This is used to identify the checkpoint + stateStoreCkptId: Option[Array[String]], + // The checkpoint ID for `batchVersion` - 1, that is used to finish this batch. This is used + // to validate the batch is processed based on the correct checkpoint. + baseStateStoreCkptId: Option[Array[String]]) + /** An operator that writes to a StateStore. */ -trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: SparkPlan => +trait StateStoreWriter + extends StatefulOperator with PythonSQLMetrics with Logging { self: SparkPlan => /** * Produce the output watermark for given input watermark (ms). @@ -182,6 +233,58 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp } } + /** + * Aggregator used for the executors to pass new state store checkpoints' IDs to the driver.
+ * For the general checkpoint ID workflow, see the comments of + * class [[StatefulOperatorStateInfo]] + */ + val checkpointInfoAccumulator: CollectionAccumulator[StatefulOpStateStoreCheckpointInfo] = { + SparkContext.getActive.map(_.collectionAccumulator[StatefulOpStateStoreCheckpointInfo]).get + } + + /** + * Get aggregated checkpoint ID info for all shuffle partitions. + * For the general checkpoint ID workflow, see the comments of + * class [[StatefulOperatorStateInfo]] + */ + def getStateStoreCheckpointInfo(): Array[StatefulOpStateStoreCheckpointInfo] = { + assert( + StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf), + "Should not fetch checkpoint info if the state store checkpoint IDs are not enabled") + // Multiple entries can be returned for the same partitionID due to task speculative execution + // or task failures. All of them should represent a valid state store checkpoint and we just + // pick one of them. + // In the end, we sort them by partitionID. + val ret = checkpointInfoAccumulator + .value + .asScala + .toSeq + .groupBy(_.partitionId) + .map { + case (key, values) => key -> values.head + } + .toSeq + .sortBy(_._1) + .map(_._2) + .toArray + assert( + ret.length == getStateInfo.numPartitions, + s"CheckpointInfo length: ${ret.length}, numPartitions: ${getStateInfo.numPartitions}") + ret + } + + /** + * The executor reports its state store checkpoint ID, which will be sent back to the driver. + * For the general checkpoint ID workflow, see the comments of + * class [[StatefulOperatorStateInfo]] + */ + protected def setStateStoreCheckpointInfo( + checkpointInfo: StatefulOpStateStoreCheckpointInfo): Unit = { + if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) { + checkpointInfoAccumulator.add(checkpointInfo) + } + } + /** * Get the progress made by this stateful operator after execution. This should be called in * the driver after this SparkPlan has been executed and metrics have been updated.
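To make the driver-side hand-off concrete, here is a minimal sketch (an illustration only, not code from this patch) of how the per-partition results of `getStateStoreCheckpointInfo()` could be packed into the `Option[Array[Array[String]]]` shape that `StatefulOperatorStateInfo.stateStoreCkptIds` expects for the next batch. The helper name and the assumption that every partition reported an ID are hypothetical; the commit-log persistence itself is noted above as not implemented yet.

```scala
// Illustrative sketch only: assemble per-partition checkpoint info reported by a
// StateStoreWriter into the stateStoreCkptIds layout used by StatefulOperatorStateInfo.
def toNextBatchCkptIds(
    perPartition: Array[StatefulOpStateStoreCheckpointInfo]): Option[Array[Array[String]]] = {
  // getStateStoreCheckpointInfo() returns one entry per shuffle partition, sorted by partition ID,
  // so positional indexing matches the partition ID.
  val ids: Array[Option[Array[String]]] = perPartition.map(_.stateStoreCkptId)
  // Only build the array when every partition reported a checkpoint ID (format version 2).
  if (ids.forall(_.isDefined)) Some(ids.map(_.get)) else None
}
```

A driver that eventually persists this array in the commit log could then feed it back through `StatefulOperatorStateInfo(..., stateStoreCkptIds = ...)` when constructing the state info for the following batch, matching step 5 of the workflow described above.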
@@ -243,6 +346,19 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp storeMetrics.customMetrics.foreach { case (metric, value) => longMetric(metric.name) += value } + + if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) { + // Set the state store checkpoint information for the driver to collect + val ssInfo = store.getStateStoreCheckpointInfo() + setStateStoreCheckpointInfo( + StatefulOpStateStoreCheckpointInfo( + ssInfo.partitionId, + ssInfo.batchVersion, + ssInfo.stateStoreCkptId.map(Array(_)), + ssInfo.baseStateStoreCkptId.map(Array(_)) + ) + ) + } } private def stateStoreCustomMetrics: Map[String, SQLMetric] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala index fc0c239a5d996..55e22c6771bca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala @@ -209,7 +209,8 @@ class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with Bef f: (StreamingSessionWindowStateManager, StateStore) => Unit): Unit = { withTempDir { file => val storeConf = new StateStoreConf() - val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) + val stateInfo = StatefulOperatorStateInfo( + file.getAbsolutePath, UUID.randomUUID, 0, 0, 5, None) val manager = StreamingSessionWindowStateManager.createStateManager( keysWithoutSessionAttributes, @@ -221,7 +222,7 @@ class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with Bef val store = StateStore.get( storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema, PrefixKeyScanStateEncoderSpec(manager.getStateKeySchema, manager.getNumColsForPrefixKey), - stateInfo.storeVersion, useColumnFamilies = false, storeConf, new Configuration) + stateInfo.storeVersion, None, useColumnFamilies = false, storeConf, new Configuration) try { f(manager, store) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 6a476635a6dbe..9a04a0c759ac4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -74,4 +74,8 @@ class MemoryStateStore extends StateStore() { override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = { throw new UnsupportedOperationException("Doesn't support multiple values per key") } + + override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = { + StateStoreCheckpointInfo(id.partitionId, version + 1, None, None) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala new file mode 100644 index 0000000000000..9ac74eb5b9e8f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -0,0 +1,544 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration +import org.scalatest.Tag + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.count +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.OutputMode.Update +import org.apache.spark.sql.types.StructType + +object CkptIdCollectingStateStoreWrapper { + // Internal list to hold checkpoint IDs (strings) + private var checkpointInfos: List[StateStoreCheckpointInfo] = List.empty + + // Method to add a string (checkpoint ID) to the list in a synchronized way + def addCheckpointInfo(checkpointID: StateStoreCheckpointInfo): Unit = synchronized { + checkpointInfos = checkpointID :: checkpointInfos + } + + // Method to read the list of checkpoint IDs in a synchronized way + def getStateStoreCheckpointInfos: List[StateStoreCheckpointInfo] = synchronized { + checkpointInfos + } + + def clear(): Unit = synchronized { + checkpointInfos = List.empty + } +} + +case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends StateStore { + + // Implement methods from ReadStateStore (parent trait) + + override def id: StateStoreId = innerStore.id + override def version: Long = innerStore.version + + override def get( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = { + innerStore.get(key, colFamilyName) + } + + override def valuesIterator( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRow] = { + innerStore.valuesIterator(key, colFamilyName) + } + + override def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + innerStore.prefixScan(prefixKey, colFamilyName) + } + + override def iterator( + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + innerStore.iterator(colFamilyName) + } + + override def abort(): Unit = innerStore.abort() + + // Implement methods from StateStore (current trait) + + override def removeColFamilyIfExists(colFamilyName: String): Boolean = { + innerStore.removeColFamilyIfExists(colFamilyName) + } + + override def createColFamilyIfAbsent( + colFamilyName: String, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useMultipleValuesPerKey: Boolean = false, + isInternal: Boolean = false): Unit = { + innerStore.createColFamilyIfAbsent( + colFamilyName, + keySchema, + valueSchema, + keyStateEncoderSpec, + useMultipleValuesPerKey, + isInternal + ) + } + + override def put( + 
key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.put(key, value, colFamilyName) + } + + override def remove( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.remove(key, colFamilyName) + } + + override def merge( + key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.merge(key, value, colFamilyName) + } + + override def commit(): Long = innerStore.commit() + override def metrics: StateStoreMetrics = innerStore.metrics + override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = { + val ret = innerStore.getStateStoreCheckpointInfo() + CkptIdCollectingStateStoreWrapper.addCheckpointInfo(ret) + ret + } + override def hasCommitted: Boolean = innerStore.hasCommitted +} + +class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { + + val innerProvider = new RocksDBStateStoreProvider() + + // Now, delegate all methods in the wrapper class to the inner object + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useColumnFamilies: Boolean, + storeConfs: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false): Unit = { + innerProvider.init( + stateStoreId, + keySchema, + valueSchema, + keyStateEncoderSpec, + useColumnFamilies, + storeConfs, + hadoopConf, + useMultipleValuesPerKey + ) + } + + override def stateStoreId: StateStoreId = innerProvider.stateStoreId + + override def close(): Unit = innerProvider.close() + + override def getStore(version: Long, stateStoreCkptId: Option[String] = None): StateStore = { + val innerStateStore = innerProvider.getStore(version, stateStoreCkptId) + CkptIdCollectingStateStoreWrapper(innerStateStore) + } + + override def getReadStore(version: Long, uniqueId: Option[String] = None): ReadStateStore = { + new WrappedReadStateStore( + CkptIdCollectingStateStoreWrapper(innerProvider.getReadStore(version, uniqueId))) + } + + override def doMaintenance(): Unit = innerProvider.doMaintenance() + + override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = + innerProvider.supportedCustomMetrics +} + +// TODO add a test case for two of the tasks for the same shuffle partitions to finish and +// return their own state store checkpointID. This can happen because of task retry or +// speculative execution. +class RocksDBStateStoreCheckpointFormatV2Suite extends StreamTest + with AlsoTestWithChangelogCheckpointingEnabled { + import testImplicits._ + + val providerClassName = classOf[CkptIdCollectingStateStoreProviderWrapper].getCanonicalName + + + override protected def beforeAll(): Unit = { + super.beforeAll() + } + + override def beforeEach(): Unit = { + CkptIdCollectingStateStoreWrapper.clear() + } + + def testWithCheckpointInfoTracked(testName: String, testTags: Tag*)( + testBody: => Any): Unit = { + super.testWithChangelogCheckpointingEnabled(testName, testTags: _*) { + super.beforeEach() + withSQLConf( + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName), + (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"), + (SQLConf.SHUFFLE_PARTITIONS.key, "2")) { + testBody + } + // in case tests have any code that needs to execute after every test + super.afterEach() + } + } + + // This test enable checkpoint format V2 without validating the checkpoint ID. 
Just to make + // sure it doesn't break and return the correct query results. + testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") { + withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Run the stream with changelog checkpointing enabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + + // Run the stream with changelog checkpointing disabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + } + + def validateBaseCheckpointInfo(): Unit = { + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // Here we assume for every task, we fetch checkpointID from the N state stores in the same + // order. So we can separate stateStoreCkptId for different stores based on the order inside the + // same (batchId, partitionId) group. + val grouped = checkpointInfoList + .groupBy(info => (info.batchVersion, info.partitionId)) + .values + .flatMap { infos => + infos.zipWithIndex.map { case (info, index) => index -> info } + } + .groupBy(_._1) + .map { + case (_, grouped) => + grouped.map { case (_, info) => info } + } + + grouped.foreach { l => + for { + a <- l + b <- l + if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1 + } { + // if batch version exists, it should be the same as the checkpoint ID of the previous batch + assert(!a.baseStateStoreCkptId.isDefined || b.stateStoreCkptId == a.baseStateStoreCkptId) + } + } + } + + def validateCheckpointInfo( + numBatches: Int, + numStateStores: Int, + batchVersionSet: Set[Long]): Unit = { + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // We have 6 batches, 2 partitions, and 1 state store per batch + assert(checkpointInfoList.size == numBatches * numStateStores * 2) + checkpointInfoList.foreach { l => + assert(l.stateStoreCkptId.isDefined) + if (batchVersionSet.contains(l.batchVersion)) { + assert(l.baseStateStoreCkptId.isDefined) + } + } + assert(checkpointInfoList.count(_.partitionId == 0) == numBatches * numStateStores) + assert(checkpointInfoList.count(_.partitionId == 1) == numBatches * numStateStores) + for (i <- 1 to numBatches) { + assert(checkpointInfoList.count(_.batchVersion == i) == numStateStores * 2) + } + validateBaseCheckpointInfo() + } + + testWithCheckpointInfoTracked(s"checkpointFormatVersion2 validate ID") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + 
+ // Test recovery + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)), + StopStream + ) + + // crash recovery again + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + + validateCheckpointInfo(6, 1, Set(2, 4, 5)) + } + + testWithCheckpointInfoTracked( + s"checkpointFormatVersion2 validate ID with dedup and groupBy") { + withTempDir { checkpointDir => + + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .dropDuplicates("value") // Deduplication operation + .groupBy($"value") // Group-by operation + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((2, 1)), // 3 is deduplicated + StopStream + ) + // Test recovery + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((1, 1)), // 2,3 is deduplicated + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 1)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 1)), + StopStream + ) + // Crash recovery again + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch(), // 4 is deduplicated + StopStream + ) + } + validateCheckpointInfo(6, 2, Set(2, 4, 5)) + } + + testWithCheckpointInfoTracked( + s"checkpointFormatVersion2 validate ID for stream-stream join") { + withTempDir { checkpointDir => + val inputData1 = MemoryStream[Int] + val inputData2 = MemoryStream[Int] + + val df1 = inputData1.toDS().toDF("value") + val df2 = inputData2.toDS().toDF("value") + + val joined = df1.join(df2, df1("value") === df2("value")) + + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 3, 2), + AddData(inputData2, 3), + CheckLastBatch((3, 3)), + AddData(inputData2, 2), + // This data will be used after restarting the query + AddData(inputData1, 5), + CheckLastBatch((2, 2)), + StopStream + ) + + // Test recovery. + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 4), + AddData(inputData2, 5), + CheckLastBatch((5, 5)), + AddData(inputData2, 4), + // This data will be used after restarting the query + AddData(inputData1, 7), + CheckLastBatch((4, 4)), + StopStream + ) + + // recovery again + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 6), + AddData(inputData2, 6), + CheckLastBatch((6, 6)), + AddData(inputData2, 7), + CheckLastBatch((7, 7)), + StopStream + ) + } + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // We sometimes add data to both data sources before CheckLastBatch(). They could be picked + // up by one or two batches. There will be at least 6 batches, but less than 12. 
+ assert(checkpointInfoList.size % 8 == 0) + val numBatches = checkpointInfoList.size / 8 + + // We don't pass batch versions that would need base checkpoint IDs because we don't know + // batchIDs for that. We only know that there are 3 batches without it. + validateCheckpointInfo(numBatches, 4, Set()) + assert(CkptIdCollectingStateStoreWrapper + .getStateStoreCheckpointInfos + .count(_.baseStateStoreCkptId.isDefined) == (numBatches - 3) * 8) + } + + testWithCheckpointInfoTracked(s"checkpointFormatVersion2 validate DropDuplicates") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val deduplicated = inputData + .toDF() + .dropDuplicates("value") + .as[Int] + + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch(3), + AddData(inputData, 3, 2), + CheckLastBatch(2), + AddData(inputData, 3, 2, 1), + CheckLastBatch(1), + StopStream + ) + + // Test recovery + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 1, 3), + CheckLastBatch(4), + AddData(inputData, 5, 4, 4), + CheckLastBatch(5), + StopStream + ) + + // crash recovery again + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 7), + CheckLastBatch(7) + ) + } + validateCheckpointInfo(6, 1, Set(2, 3, 5)) + } + + testWithCheckpointInfoTracked( + s"checkpointFormatVersion2 validate FlatMapGroupsWithState") { + withTempDir { checkpointDir => + val stateFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { + val count: Int = state.getOption.getOrElse(0) + values.size + state.update(count) + Iterator((key, count)) + } + + val inputData = MemoryStream[Int] + val aggregated = inputData + .toDF() + .toDF("key") + .selectExpr("key") + .as[Int] + .repartition($"key") + .groupByKey(x => x) + .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.NoTimeout())(stateFunc) + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Test recovery + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 1, 3), + CheckLastBatch((4, 1), (1, 1), (3, 3)), + AddData(inputData, 5, 4, 4), + CheckLastBatch((5, 1), (4, 3)), + StopStream + ) + + // crash recovery again + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 7), + CheckLastBatch((4, 4), (7, 1)), + AddData (inputData, 5), + CheckLastBatch((5, 2)), + StopStream + ) + } + validateCheckpointInfo(6, 1, Set(2, 4, 6)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 32467a2dd11bf..e1bd9dd38066b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -101,7 +101,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid val testSchema = StructType(Seq(StructField("key", StringType, true))) val testStateInfo = StatefulOperatorStateInfo( checkpointLocation 
= Utils.createTempDir().getAbsolutePath, - queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5) + queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5, None) // Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 3127c9f602492..947fccdfce72c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -234,7 +234,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter { path: String, queryRunId: UUID = UUID.randomUUID, version: Int = 0): StatefulOperatorStateInfo = { - StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5) + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5, None) } private val increment = (store: StateStore, iter: Iterator[(String, Int)]) => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 2a9944a81cb2a..031f5a8b87641 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -108,7 +108,7 @@ class FakeStateStoreProviderWithMaintenanceError extends StateStoreProvider { override def close(): Unit = {} - override def getStore(version: Long): StateStore = null + override def getStore(version: Long, uniqueId: Option[String]): StateStore = null override def doMaintenance(): Unit = { Thread.currentThread.setUncaughtExceptionHandler(exceptionHandler) @@ -483,7 +483,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] for (i <- 1 to 20) { val store = StateStore.get(storeProviderId1, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - latestStoreVersion, useColumnFamilies = false, storeConf, hadoopConf) + latestStoreVersion, None, useColumnFamilies = false, storeConf, hadoopConf) put(store, "a", 0, i) store.commit() latestStoreVersion += 1 @@ -538,7 +538,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // Reload the store and verify StateStore.get(storeProviderId1, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - latestStoreVersion, useColumnFamilies = false, storeConf, hadoopConf) + latestStoreVersion, None, useColumnFamilies = false, storeConf, hadoopConf) assert(StateStore.isLoaded(storeProviderId1)) // If some other executor loads the store, then this instance should be unloaded @@ -551,7 +551,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // Reload the store and verify StateStore.get(storeProviderId1, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - latestStoreVersion, useColumnFamilies = false, storeConf, hadoopConf) + latestStoreVersion, None, useColumnFamilies = false, storeConf, hadoopConf) assert(StateStore.isLoaded(storeProviderId1)) // If some other executor loads the store, and when this 
executor loads other store, @@ -560,7 +560,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty) StateStore.get(storeProviderId2, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, useColumnFamilies = false, storeConf, hadoopConf) + 0, None, useColumnFamilies = false, storeConf, hadoopConf) assert(!StateStore.isLoaded(storeProviderId1)) assert(StateStore.isLoaded(storeProviderId2)) } @@ -597,7 +597,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] for (i <- 1 to 20) { val store = StateStore.get(storeProviderId1, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - latestStoreVersion, useColumnFamilies = false, storeConf, hadoopConf) + latestStoreVersion, None, useColumnFamilies = false, storeConf, hadoopConf) put(store, "a", 0, i) store.commit() latestStoreVersion += 1 @@ -731,7 +731,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] StateStore.get( storeId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - version = 0, useColumnFamilies = false, storeConf, hadoopConf) + version = 0, None, useColumnFamilies = false, storeConf, hadoopConf) } // Put should create a temp file @@ -749,7 +749,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] StateStore.get( storeId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - version = 1, useColumnFamilies = false, storeConf, hadoopConf) + version = 1, None, useColumnFamilies = false, storeConf, hadoopConf) } remove(store1, _._1 == "a") assert(numTempFiles === 1) @@ -765,7 +765,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] StateStore.get( storeId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - version = 2, useColumnFamilies = false, storeConf, hadoopConf) + version = 2, None, useColumnFamilies = false, storeConf, hadoopConf) } store2.commit() assert(numTempFiles === 0) @@ -1473,7 +1473,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] var e = intercept[SparkException] { StateStore.get( storeId, keySchema, valueSchema, - NoPrefixKeyStateEncoderSpec(keySchema), -1, useColumnFamilies = false, + NoPrefixKeyStateEncoderSpec(keySchema), -1, None, useColumnFamilies = false, storeConf, hadoopConf) } checkError( @@ -1488,7 +1488,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] StateStore.get( storeId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 1, useColumnFamilies = false, + 1, None, useColumnFamilies = false, storeConf, hadoopConf) } checkError( @@ -1505,7 +1505,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] val store0 = StateStore.get( storeId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, useColumnFamilies = false, + 0, None, useColumnFamilies = false, storeConf, hadoopConf) assert(store0.version === 0) put(store0, "a", 0, 1) @@ -1514,7 +1514,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] val store1 = StateStore.get( storeId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 1, useColumnFamilies = false, + 1, None, useColumnFamilies = false, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) assert(store1.version === 1) @@ -1524,7 +1524,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] val store0reloaded = StateStore.get( 
storeId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, useColumnFamilies = false, + 0, None, useColumnFamilies = false, storeConf, hadoopConf) assert(store0reloaded.version === 0) assert(rowPairsToDataSet(store0reloaded.iterator()) === Set.empty) @@ -1536,7 +1536,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] val store1reloaded = StateStore.get( storeId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 1, useColumnFamilies = false, + 1, None, useColumnFamilies = false, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) assert(store1reloaded.version === 1) @@ -1639,20 +1639,20 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] StateStore.get( provider2Id, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration() + 0, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration() ) // The following 2 calls to `get` will cause the associated maintenance to fail StateStore.get( provider0Id, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration() + 0, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration() ) StateStore.get( provider1Id, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - 0, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration() + 0, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration() ) // Wait for the maintenance task for all the providers to run: it should happen relatively diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala index 1607d7e699d5d..a28366fbaa147 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala @@ -173,7 +173,8 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA f: (StreamingSessionWindowStateManager, StateStore) => Unit): Unit = { withTempDir { file => val storeConf = new StateStoreConf() - val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) + val stateInfo = StatefulOperatorStateInfo( + file.getAbsolutePath, UUID.randomUUID, 0, 0, 5, None) val manager = StreamingSessionWindowStateManager.createStateManager( keysWithoutSessionAttributes, @@ -185,7 +186,7 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA val store = StateStore.get( storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema, PrefixKeyScanStateEncoderSpec(manager.getStateKeySchema, manager.getNumColsForPrefixKey), - stateInfo.storeVersion, useColumnFamilies = false, storeConf, new Configuration) + stateInfo.storeVersion, None, useColumnFamilies = false, storeConf, new Configuration) try { f(manager, store) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index 16f3e972c7697..96be142cfd5a5 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -318,10 +318,11 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter withSQLConf(SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key -> skipNullsForStreamStreamJoins.toString) { val storeConf = new StateStoreConf(spark.sessionState.conf) - val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) + val stateInfo = StatefulOperatorStateInfo( + file.getAbsolutePath, UUID.randomUUID, 0, 0, 5, None) val manager = new SymmetricHashJoinStateManager( LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration, - partitionId = 0, stateFormatVersion, metric) + partitionId = 0, None, None, stateFormatVersion, metric) try { f(manager) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9edd1acaddbc1..761354a05cc86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -1487,7 +1487,7 @@ class TestStateStoreProvider extends StateStoreProvider { override def close(): Unit = { } - override def getStore(version: Long): StateStore = null + override def getStore(version: Long, stateStoreCkptId: Option[String] = None): StateStore = null } /** A fake source that throws `ThrowingExceptionInCreateSource.exception` in `createSource` */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 20b627fbb42ba..19ab272827441 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -561,7 +561,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val opId = 0 val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextFloat().toString).toString - val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5) + val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5, None) implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator From ff5e0a11f6a18e08884bc2d53ce8aa2912e33015 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 18 Oct 2024 13:53:53 +0800 Subject: [PATCH 040/108] [SPARK-49558][SQL] Add SQL pipe syntax for LIMIT/OFFSET and ORDER/SORT/CLUSTER/DISTRIBUTE BY ### What changes were proposed in this pull request? This PR adds SQL pipe syntax support for LIMIT/OFFSET and ORDER/SORT/CLUSTER/DISTRIBUTE BY. For example: ``` CREATE TABLE t(x INT, y STRING) USING CSV; INSERT INTO t VALUES (0, 'abc'), (1, 'def'); TABLE t |> ORDER BY x |> LIMIT 1 OFFSET 1 1 def ``` ### Why are the changes needed? The SQL pipe operator syntax will let users compose queries in a more flexible fashion. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds a few unit test cases, but mostly relies on golden file test coverage. 
I did this to make sure the answers are correct as this feature is implemented, and also so we can look at the analyzer output plans to ensure they look right as well.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48413 from dtenedor/pipe-order-by.

Authored-by: Daniel Tenedorio
Signed-off-by: Wenchen Fan
---
 .../resources/error/error-conditions.json     |  11 +
 .../sql/catalyst/parser/SqlBaseParser.g4      |   1 +
 .../spark/sql/errors/QueryParsingErrors.scala |  21 ++
 .../{PipeSelect.scala => pipeOperators.scala} |  12 +
 .../sql/catalyst/parser/AstBuilder.scala      |  73 +++--
 .../analyzer-results/pipe-operators.sql.out   | 295 ++++++++++++++++++
 .../sql-tests/inputs/pipe-operators.sql       | 107 ++++++-
 .../sql-tests/results/pipe-operators.sql.out  | 289 +++++++++++++++++
 .../sql/execution/SparkSqlParserSuite.scala   |  11 +
 9 files changed, 795 insertions(+), 25 deletions(-)
 rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{PipeSelect.scala => pipeOperators.scala} (84%)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index fdf3cf7ccbeb3..99c91f1f18e86 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3460,6 +3460,12 @@
     ],
     "sqlState" : "42P20"
   },
+  "MULTIPLE_QUERY_RESULT_CLAUSES_WITH_PIPE_OPERATORS" : {
+    "message" : [
+      "<clause1> and <clause2> cannot coexist in the same SQL pipe operator using '|>'. Please separate the multiple result clauses into separate pipe operators and then retry the query again."
+    ],
+    "sqlState" : "42000"
+  },
   "MULTIPLE_TIME_TRAVEL_SPEC" : {
     "message" : [
       "Cannot specify time travel in both the time travel clause and options."
@@ -4986,6 +4992,11 @@
         "Catalog <catalogName> does not support <operation>."
       ]
     },
+    "CLAUSE_WITH_PIPE_OPERATORS" : {
+      "message" : [
+        "The SQL pipe operator syntax using |> does not support <clauses>."
+      ]
+    },
     "COMBINATION_QUERY_RESULT_CLAUSES" : {
       "message" : [
         "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY."
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 9d237f069132a..5eb4c276f39bc 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -1512,6 +1512,7 @@ operatorPipeRightSide
     | sample
     | joinRelation
     | operator=(UNION | EXCEPT | SETMINUS | INTERSECT) setQuantifier? right=queryTerm
+    | queryOrganization
     ;

 // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 53cbf086c96e3..6164f2585f0fe 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -82,6 +82,27 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { ctx) } + def clausesWithPipeOperatorsUnsupportedError( + ctx: QueryOrganizationContext, + clauses: String): Throwable = { + new ParseException( + errorClass = "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", + messageParameters = Map("clauses" -> clauses), + ctx) + } + + def multipleQueryResultClausesWithPipeOperatorsUnsupportedError( + ctx: QueryOrganizationContext, + clause1: String, + clause2: String): Throwable = { + new ParseException( + errorClass = "MULTIPLE_QUERY_RESULT_CLAUSES_WITH_PIPE_OPERATORS", + messageParameters = Map( + "clause1" -> clause1, + "clause2" -> clause2), + ctx) + } + def combinationQueryResultClausesUnsupportedError(ctx: QueryOrganizationContext): Throwable = { new ParseException(errorClass = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", ctx) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala similarity index 84% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala index 0b5479cc8f0ee..a0f2198212689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala @@ -45,3 +45,15 @@ case class PipeSelect(child: Expression) child } } + +object PipeOperators { + // These are definitions of query result clauses that can be used with the pipe operator. + val clusterByClause = "CLUSTER BY" + val distributeByClause = "DISTRIBUTE BY" + val limitClause = "LIMIT" + val offsetClause = "OFFSET" + val orderByClause = "ORDER BY" + val sortByClause = "SORT BY" + val sortByDistributeByClause = "SORT BY ... DISTRIBUTE BY ..." + val windowClause = "WINDOW" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3ecb680cf6427..25dd423791005 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -428,7 +428,8 @@ class AstBuilder extends DataTypeAstBuilder * Create a top-level plan with Common Table Expressions. */ override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { - val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)( + withQueryResultClauses(_, _, forPipeOperators = false)) // Apply CTEs query.optionalMap(ctx.ctes)(withCTE) @@ -491,7 +492,7 @@ class AstBuilder extends DataTypeAstBuilder val selects = ctx.fromStatementBody.asScala.map { body => withFromStatementBody(body, from). // Add organization statements. 
- optionalMap(body.queryOrganization)(withQueryResultClauses) + optionalMap(body.queryOrganization)(withQueryResultClauses(_, _, forPipeOperators = false)) } // If there are multiple SELECT just UNION them together into one query. if (selects.length == 1) { @@ -537,7 +538,8 @@ class AstBuilder extends DataTypeAstBuilder val inserts = ctx.multiInsertQueryBody.asScala.map { body => withInsertInto(body.insertInto, withFromStatementBody(body.fromStatementBody, from). - optionalMap(body.fromStatementBody.queryOrganization)(withQueryResultClauses)) + optionalMap(body.fromStatementBody.queryOrganization)( + withQueryResultClauses(_, _, forPipeOperators = false))) } // If there are multiple INSERTS just UNION them together into one query. @@ -976,31 +978,37 @@ class AstBuilder extends DataTypeAstBuilder /** * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These * clauses determine the shape (ordering/partitioning/rows) of the query result. + * + * If 'forPipeOperators' is true, throws an error if the WINDOW clause is present (since this is + * not currently supported) or if more than one clause is present (this can be useful when parsing + * clauses used with pipe operations which only allow one instance of these clauses each). */ private def withQueryResultClauses( ctx: QueryOrganizationContext, - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + query: LogicalPlan, + forPipeOperators: Boolean): LogicalPlan = withOrigin(ctx) { import ctx._ + var clause = "" // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. val withOrder = if ( !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { - // ORDER BY ... + clause = PipeOperators.orderByClause Sort(order.asScala.map(visitSortItem).toSeq, global = true, query) } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { - // SORT BY ... + clause = PipeOperators.sortByClause Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query) } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { - // DISTRIBUTE BY ... + clause = PipeOperators.distributeByClause withRepartitionByExpression(ctx, expressionList(distributeBy), query) } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { - // SORT BY ... DISTRIBUTE BY ... + clause = PipeOperators.sortByDistributeByClause Sort( sort.asScala.map(visitSortItem).toSeq, global = false, withRepartitionByExpression(ctx, expressionList(distributeBy), query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { - // CLUSTER BY ... 
+ clause = PipeOperators.clusterByClause val expressions = expressionList(clusterBy) Sort( expressions.map(SortOrder(_, Ascending)), @@ -1014,17 +1022,33 @@ class AstBuilder extends DataTypeAstBuilder } // WINDOWS - val withWindow = withOrder.optionalMap(windowClause)(withWindowClause) + val withWindow = withOrder.optionalMap(windowClause) { + withWindowClause + } + if (forPipeOperators && windowClause != null) { + throw QueryParsingErrors.clausesWithPipeOperatorsUnsupportedError( + ctx, s"the ${PipeOperators.windowClause} clause") + } // OFFSET // - OFFSET 0 is the same as omitting the OFFSET clause val withOffset = withWindow.optional(offset) { + if (forPipeOperators && clause.nonEmpty) { + throw QueryParsingErrors.multipleQueryResultClausesWithPipeOperatorsUnsupportedError( + ctx, clause, PipeOperators.offsetClause) + } + clause = PipeOperators.offsetClause Offset(typedVisit(offset), withWindow) } // LIMIT // - LIMIT ALL is the same as omitting the LIMIT clause withOffset.optional(limit) { + if (forPipeOperators && clause.nonEmpty && clause != PipeOperators.offsetClause) { + throw QueryParsingErrors.multipleQueryResultClausesWithPipeOperatorsUnsupportedError( + ctx, clause, PipeOperators.limitClause) + } + clause = PipeOperators.limitClause Limit(typedVisit(limit), withOffset) } } @@ -5883,6 +5907,18 @@ class AstBuilder extends DataTypeAstBuilder if (!SQLConf.get.getConf(SQLConf.OPERATOR_PIPE_SYNTAX_ENABLED)) { operationNotAllowed("Operator pipe SQL syntax using |>", ctx) } + // This helper function adds a table subquery boundary between the new operator to be added + // (such as a filter or sort) and the input plan if one does not already exist. This helps the + // analyzer behave as if we had added the corresponding SQL clause after a table subquery + // containing the input plan. + def withSubqueryAlias(): LogicalPlan = left match { + case s: SubqueryAlias => + s + case u: UnresolvedRelation => + u + case _ => + SubqueryAlias(SubqueryAlias.generateSubqueryName(), left) + } Option(ctx.selectClause).map { c => withSelectQuerySpecification( ctx = ctx, @@ -5895,18 +5931,7 @@ class AstBuilder extends DataTypeAstBuilder relation = left, isPipeOperatorSelect = true) }.getOrElse(Option(ctx.whereClause).map { c => - // Add a table subquery boundary between the new filter and the input plan if one does not - // already exist. This helps the analyzer behave as if we had added the WHERE clause after a - // table subquery containing the input plan. 
- val withSubqueryAlias = left match { - case s: SubqueryAlias => - s - case u: UnresolvedRelation => - u - case _ => - SubqueryAlias(SubqueryAlias.generateSubqueryName(), left) - } - withWhereClause(c, withSubqueryAlias) + withWhereClause(c, withSubqueryAlias()) }.getOrElse(Option(ctx.pivotClause()).map { c => if (ctx.unpivotClause() != null) { throw QueryParsingErrors.unpivotWithPivotInFromClauseNotAllowedError(ctx) @@ -5924,7 +5949,9 @@ class AstBuilder extends DataTypeAstBuilder }.getOrElse(Option(ctx.operator).map { c => val all = Option(ctx.setQuantifier()).exists(_.ALL != null) visitSetOperationImpl(left, plan(ctx.right), all, c.getType) - }.get)))))) + }.getOrElse(Option(ctx.queryOrganization).map { c => + withQueryResultClauses(c, withSubqueryAlias(), forPipeOperators = true) + }.get))))))) } /** diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index 7fa4ec0514ff0..30f340ca834e0 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -236,6 +236,35 @@ CreateViewCommand `natural_join_test_t3`, select * from values +- LocalRelation [k#x, v3#x] +-- !query +create temporary view windowTestData as select * from values + (null, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + (1, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + (1, 2L, 2.5D, date("2017-08-02"), timestamp_seconds(1502000000), "a"), + (2, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "a"), + (1, null, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "b"), + (2, 3L, 3.3D, date("2017-08-03"), timestamp_seconds(1503000000), "b"), + (3, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "b"), + (null, null, null, null, null, null), + (3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null) + AS testData(val, val_long, val_double, val_date, val_timestamp, cate) +-- !query analysis +CreateViewCommand `windowTestData`, select * from values + (null, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + (1, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + (1, 2L, 2.5D, date("2017-08-02"), timestamp_seconds(1502000000), "a"), + (2, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "a"), + (1, null, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "b"), + (2, 3L, 3.3D, date("2017-08-03"), timestamp_seconds(1503000000), "b"), + (3, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "b"), + (null, null, null, null, null, null), + (3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null) + AS testData(val, val_long, val_double, val_date, val_timestamp, cate), false, false, LocalTempView, UNSUPPORTED, true + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + -- !query table t |> select 1 as x @@ -2022,6 +2051,272 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +table t +|> order by x +-- !query analysis +Sort [x#x ASC NULLS FIRST], true ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select * from t) +|> order by x +-- !query analysis +Sort [x#x ASC 
NULLS FIRST], true ++- SubqueryAlias __auto_generated_subquery_name + +- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +values (0, 'abc') tab(x, y) +|> order by x +-- !query analysis +Sort [x#x ASC NULLS FIRST], true ++- SubqueryAlias tab + +- LocalRelation [x#x, y#x] + + +-- !query +table t +|> order by x +|> limit 1 +-- !query analysis +GlobalLimit 1 ++- LocalLimit 1 + +- SubqueryAlias __auto_generated_subquery_name + +- Sort [x#x ASC NULLS FIRST], true + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x = 1 +|> select y +|> limit 2 offset 1 +-- !query analysis +GlobalLimit 2 ++- LocalLimit 2 + +- Offset 1 + +- SubqueryAlias __auto_generated_subquery_name + +- Project [y#x] + +- Filter (x#x = 1) + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x = 1 +|> select y +|> offset 1 +-- !query analysis +Offset 1 ++- SubqueryAlias __auto_generated_subquery_name + +- Project [y#x] + +- Filter (x#x = 1) + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> limit all offset 0 +-- !query analysis +Offset 0 ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> distribute by x +-- !query analysis +RepartitionByExpression [x#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> cluster by x +-- !query analysis +Sort [x#x ASC NULLS FIRST], false ++- RepartitionByExpression [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> sort by x distribute by x +-- !query analysis +RepartitionByExpression [x#x] ++- Sort [x#x ASC NULLS FIRST], false + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> order by x desc +order by y +-- !query analysis +Sort [y#x ASC NULLS FIRST], true ++- Sort [x#x DESC NULLS LAST], true + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> order by x desc order by x + y +order by y +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'order'", + "hint" : "" + } +} + + +-- !query +table t +|> select 1 + 2 as result +|> order by x +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`x`", + "proposal" : "`result`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 47, + "stopIndex" : 47, + "fragment" : "x" + } ] +} + + +-- !query +table t +|> select 1 + 2 as result +|> distribute by x +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`x`", + "proposal" : "`result`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 52, + "stopIndex" : 52, + "fragment" : "x" + } ] +} + + +-- !query +table t +|> order by x limit 1 +-- !query analysis 
+org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "MULTIPLE_QUERY_RESULT_CLAUSES_WITH_PIPE_OPERATORS", + "sqlState" : "42000", + "messageParameters" : { + "clause1" : "ORDER BY", + "clause2" : "LIMIT" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 29, + "fragment" : "order by x limit 1" + } ] +} + + +-- !query +table t +|> order by x sort by x +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", + "sqlState" : "0A000", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 31, + "fragment" : "order by x sort by x" + } ] +} + + +-- !query +table windowTestData +|> window w as (partition by cte order by val) +|> select cate, sum(val) over w +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", + "sqlState" : "0A000", + "messageParameters" : { + "clauses" : "the WINDOW clause" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 67, + "fragment" : "window w as (partition by cte order by val)" + } ] +} + + +-- !query +table windowTestData +|> window w as (partition by cate order by val) limit 5 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", + "sqlState" : "0A000", + "messageParameters" : { + "clauses" : "the WINDOW clause" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 76, + "fragment" : "window w as (partition by cate order by val) limit 5" + } ] +} + + -- !query drop table t -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 61890f5cb146d..5e0c502a77e84 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -59,6 +59,18 @@ create temporary view natural_join_test_t2 as select * from values create temporary view natural_join_test_t3 as select * from values ("one", 4), ("two", 5), ("one", 6) as natural_join_test_t3(k, v3); +create temporary view windowTestData as select * from values + (null, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + (1, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + (1, 2L, 2.5D, date("2017-08-02"), timestamp_seconds(1502000000), "a"), + (2, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "a"), + (1, null, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "b"), + (2, 3L, 3.3D, date("2017-08-03"), timestamp_seconds(1503000000), "b"), + (3, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "b"), + (null, null, null, null, null, null), + (3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null) + AS testData(val, val_long, val_double, val_date, val_timestamp, cate); + -- SELECT operators: positive tests. --------------------------------------- @@ -505,7 +517,7 @@ table join_test_t1 jt |> cross join (select * from jt); -- Set operations: positive tests. ------------------------------------ +---------------------------------- -- Union all. table t @@ -560,7 +572,7 @@ table t |> minus table t; -- Set operations: negative tests. 
------------------------------------ +---------------------------------- -- The UNION operator requires the same number of columns in the input relations. table t @@ -571,6 +583,97 @@ table t table t |> union all table st; +-- Sorting and repartitioning operators: positive tests. +-------------------------------------------------------- + +-- Order by. +table t +|> order by x; + +-- Order by with a table subquery. +(select * from t) +|> order by x; + +-- Order by with a VALUES list. +values (0, 'abc') tab(x, y) +|> order by x; + +-- Limit. +table t +|> order by x +|> limit 1; + +-- Limit with offset. +table t +|> where x = 1 +|> select y +|> limit 2 offset 1; + +-- Offset is allowed without limit. +table t +|> where x = 1 +|> select y +|> offset 1; + +-- LIMIT ALL and OFFSET 0 are equivalent to no LIMIT or OFFSET clause, respectively. +table t +|> limit all offset 0; + +-- Distribute by. +table t +|> distribute by x; + +-- Cluster by. +table t +|> cluster by x; + +-- Sort and distribute by. +table t +|> sort by x distribute by x; + +-- It is possible to apply a final ORDER BY clause on the result of a query containing pipe +-- operators. +table t +|> order by x desc +order by y; + +-- Sorting and repartitioning operators: negative tests. +-------------------------------------------------------- + +-- Multiple order by clauses are not supported in the same pipe operator. +-- We add an extra "ORDER BY y" clause at the end in this test to show that the "ORDER BY x + y" +-- clause was consumed end the of the final query, not as part of the pipe operator. +table t +|> order by x desc order by x + y +order by y; + +-- The ORDER BY clause may only refer to column names from the previous input relation. +table t +|> select 1 + 2 as result +|> order by x; + +-- The DISTRIBUTE BY clause may only refer to column names from the previous input relation. +table t +|> select 1 + 2 as result +|> distribute by x; + +-- Combinations of multiple ordering and limit clauses are not supported. +table t +|> order by x limit 1; + +-- ORDER BY and SORT BY are not supported at the same time. +table t +|> order by x sort by x; + +-- The WINDOW clause is not supported yet. +table windowTestData +|> window w as (partition by cte order by val) +|> select cate, sum(val) over w; + +-- WINDOW and LIMIT are not supported at the same time. +table windowTestData +|> window w as (partition by cate order by val) limit 5; + -- Cleanup. 
----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index 8cbc5357d78b6..64d9d38b36306 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -214,6 +214,24 @@ struct<> +-- !query +create temporary view windowTestData as select * from values + (null, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + (1, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + (1, 2L, 2.5D, date("2017-08-02"), timestamp_seconds(1502000000), "a"), + (2, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "a"), + (1, null, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "b"), + (2, 3L, 3.3D, date("2017-08-03"), timestamp_seconds(1503000000), "b"), + (3, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "b"), + (null, null, null, null, null, null), + (3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null) + AS testData(val, val_long, val_double, val_date, val_timestamp, cate) +-- !query schema +struct<> +-- !query output + + + -- !query table t |> select 1 as x @@ -1673,6 +1691,277 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +table t +|> order by x +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +(select * from t) +|> order by x +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +values (0, 'abc') tab(x, y) +|> order by x +-- !query schema +struct +-- !query output +0 abc + + +-- !query +table t +|> order by x +|> limit 1 +-- !query schema +struct +-- !query output +0 abc + + +-- !query +table t +|> where x = 1 +|> select y +|> limit 2 offset 1 +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> where x = 1 +|> select y +|> offset 1 +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> limit all offset 0 +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> distribute by x +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> cluster by x +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> sort by x distribute by x +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> order by x desc +order by y +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> order by x desc order by x + y +order by y +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'order'", + "hint" : "" + } +} + + +-- !query +table t +|> select 1 + 2 as result +|> order by x +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`x`", + "proposal" : "`result`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 47, + "stopIndex" : 47, + "fragment" : "x" + } ] +} + + +-- !query +table t +|> select 1 + 2 as result +|> distribute by x +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + 
"sqlState" : "42703", + "messageParameters" : { + "objectName" : "`x`", + "proposal" : "`result`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 52, + "stopIndex" : 52, + "fragment" : "x" + } ] +} + + +-- !query +table t +|> order by x limit 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "MULTIPLE_QUERY_RESULT_CLAUSES_WITH_PIPE_OPERATORS", + "sqlState" : "42000", + "messageParameters" : { + "clause1" : "ORDER BY", + "clause2" : "LIMIT" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 29, + "fragment" : "order by x limit 1" + } ] +} + + +-- !query +table t +|> order by x sort by x +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", + "sqlState" : "0A000", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 31, + "fragment" : "order by x sort by x" + } ] +} + + +-- !query +table windowTestData +|> window w as (partition by cte order by val) +|> select cate, sum(val) over w +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", + "sqlState" : "0A000", + "messageParameters" : { + "clauses" : "the WINDOW clause" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 67, + "fragment" : "window w as (partition by cte order by val)" + } ] +} + + +-- !query +table windowTestData +|> window w as (partition by cate order by val) limit 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", + "sqlState" : "0A000", + "messageParameters" : { + "clauses" : "the WINDOW clause" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 76, + "fragment" : "window w as (partition by cate order by val) limit 5" + } ] +} + + -- !query drop table t -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index fc1c9c6755572..644d73bca5b65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -955,6 +955,17 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { checkExcept("TABLE t |> MINUS DISTINCT TABLE t") checkIntersect("TABLE t |> INTERSECT ALL TABLE t") checkUnion("TABLE t |> UNION ALL TABLE t") + // Sorting and distributing operators. 
+ def checkSort(query: String): Unit = check(query, Seq(SORT)) + def checkRepartition(query: String): Unit = check(query, Seq(REPARTITION_OPERATION)) + def checkLimit(query: String): Unit = check(query, Seq(LIMIT)) + checkSort("TABLE t |> ORDER BY x") + checkSort("TABLE t |> SELECT x |> SORT BY x") + checkLimit("TABLE t |> LIMIT 1") + checkLimit("TABLE t |> LIMIT 2 OFFSET 1") + checkRepartition("TABLE t |> DISTRIBUTE BY x |> WHERE x = 1") + checkRepartition("TABLE t |> CLUSTER BY x |> TABLESAMPLE (100 PERCENT)") + checkRepartition("TABLE t |> SORT BY x DISTRIBUTE BY x") } } } From 2724909656fe5cebe1850f0fd81e32a998eed07e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 18 Oct 2024 14:16:42 +0800 Subject: [PATCH 041/108] [SPARK-50011][INFRA][FOLLOW-UP] Refresh the image cache job ### What changes were proposed in this pull request? Refresh the image cache job ### Why are the changes needed? this job has been broken: https://github.com/apache/spark/actions/runs/11387123331/job/31682246504 ### Does this PR introduce _any_ user-facing change? no, infra-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48533 from zhengruifeng/infra_build_cache. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .github/workflows/build_infra_images_cache.yml | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_infra_images_cache.yml b/.github/workflows/build_infra_images_cache.yml index 18419334836b2..18e1e43f36c75 100644 --- a/.github/workflows/build_infra_images_cache.yml +++ b/.github/workflows/build_infra_images_cache.yml @@ -26,7 +26,8 @@ on: - 'master' - 'branch-*' paths: - - 'dev/infra/Dockerfile' + - 'dev/infra/base/Dockerfile' + - 'dev/infra/docs/Dockerfile' - '.github/workflows/build_infra_images_cache.yml' # Create infra image when cutting down branches/tags create: @@ -53,10 +54,21 @@ jobs: id: docker_build uses: docker/build-push-action@v6 with: - context: ./dev/infra/ + context: ./dev/infra/base/ push: true tags: ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }}-static cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }} cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }},mode=max - name: Image digest run: echo ${{ steps.docker_build.outputs.digest }} + - name: Build and push (Documentation) + id: docker_build_docs + uses: docker/build-push-action@v6 + with: + context: ./dev/infra/docs/ + push: true + tags: ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }}-static + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }} + cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }},mode=max + - name: Image digest (Documentation) + run: echo ${{ steps.docker_build_docs.outputs.digest }} From 55c558a9fb0034ad4e5cfbce53511cc80e23a993 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Fri, 18 Oct 2024 08:46:06 +0200 Subject: [PATCH 042/108] [SPARK-50010][SQL] Expand implicit collation mismatch error ### What changes were proposed in this pull request? Include the implicit string collations in the `COLLATION_MISMATCH.IMPLICIT` error message. ### Why are the changes needed? Make the implicit collation mismatch error more user-friendly. 
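For example, using the existing collation test table `t5` from the golden files below (its `utf8_binary` and `utf8_lcase` columns carry different implicit collations), the error now names the conflicting collations. This is a sketch of the expected behavior; the rendered message is approximated from the new template and parameters:

```sql
-- Illustrative only: the table and column names come from the existing collation tests.
SELECT concat_ws(' ', utf8_binary, utf8_lcase) FROM t5;
-- [COLLATION_MISMATCH.IMPLICIT] Error occurred due to the mismatch between implicit
-- collations: ["STRING", "STRING COLLATE UTF8_LCASE"]. Use COLLATE function to set
-- the collation explicitly. SQLSTATE: 42P21
```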
### Does this PR introduce _any_ user-facing change?
Yes, the implicit collation mismatch error message is changed.

### How was this patch tested?
Updated existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48495 from uros-db/implicit-collation-mismatch.

Authored-by: Uros Bojanic
Signed-off-by: Max Gekk
---
 .../resources/error/error-conditions.json     |   2 +-
 ...nditions-collation-mismatch-error-class.md |   2 +-
 .../analysis/CollationTypeCasts.scala         |   4 +-
 .../sql/errors/QueryCompilationErrors.scala   |   6 +-
 .../analyzer-results/collations.sql.out       | 105 ++++++++++++++----
 .../sql-tests/results/collations.sql.out      | 105 ++++++++++++++----
 .../sql/CollationStringExpressionsSuite.scala |   4 +-
 .../org/apache/spark/sql/CollationSuite.scala |  40 +++++--
 8 files changed, 212 insertions(+), 56 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 99c91f1f18e86..752524eab7cc7 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -571,7 +571,7 @@
     },
     "IMPLICIT" : {
       "message" : [
-        "Error occurred due to the mismatch between multiple implicit non-default collations. Use COLLATE function to set the collation explicitly."
+        "Error occurred due to the mismatch between implicit collations: [<implicitTypes>]. Use COLLATE function to set the collation explicitly."
       ]
     }
   },
diff --git a/docs/sql-error-conditions-collation-mismatch-error-class.md b/docs/sql-error-conditions-collation-mismatch-error-class.md
index b6a63d87b36a0..79aaaf00ee47c 100644
--- a/docs/sql-error-conditions-collation-mismatch-error-class.md
+++ b/docs/sql-error-conditions-collation-mismatch-error-class.md
@@ -36,6 +36,6 @@ Error occurred due to the mismatch between explicit collations: `<explicitTypes>`

 ## IMPLICIT

-Error occurred due to the mismatch between multiple implicit non-default collations. Use COLLATE function to set the collation explicitly.
+Error occurred due to the mismatch between implicit collations: `<implicitTypes>`. Use COLLATE function to set the collation explicitly.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index ced54e590ecc2..cfde8c9d8aa1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -191,7 +191,9 @@ object CollationTypeCasts extends TypeCoercionRule { .distinct if (implicitTypes.length > 1) { - throw QueryCompilationErrors.implicitCollationMismatchError() + throw QueryCompilationErrors.implicitCollationMismatchError( + implicitTypes.map(t => StringType(t)) + ) } else { implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 5e08d463b9e07..3dc906ed98280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3649,10 +3649,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } - def implicitCollationMismatchError(): Throwable = { + def implicitCollationMismatchError(implicitTypes: Seq[StringType]): Throwable = { new AnalysisException( errorClass = "COLLATION_MISMATCH.IMPLICIT", - messageParameters = Map.empty + messageParameters = Map( + "implicitTypes" -> implicitTypes.map(toSQLType).mkString(", ") + ) ) } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index eed7fa73ab698..6a00d31f64316 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -411,7 +411,10 @@ select str_to_map(text, pairDelim, keyValueDelim) from t4 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -693,7 +696,10 @@ select concat_ws(' ', utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -748,7 +754,10 @@ select elt(2, utf8_binary, utf8_lcase, s) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -803,7 +812,10 @@ select split_part(utf8_binary, utf8_lcase, 3) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -890,7 +902,10 @@ select contains(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + 
} } @@ -977,7 +992,10 @@ select substring_index(utf8_binary, utf8_lcase, 2) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1064,7 +1082,10 @@ select instr(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1151,7 +1172,10 @@ select find_in_set(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1214,7 +1238,10 @@ select startswith(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1309,7 +1336,10 @@ select translate(utf8_binary, utf8_lcase, '12345') from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1380,7 +1410,10 @@ select replace(utf8_binary, utf8_lcase, 'abc') from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1467,7 +1500,10 @@ select endswith(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1754,7 +1790,10 @@ select overlay(utf8_binary, utf8_lcase, 2) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1913,7 +1952,10 @@ select levenshtein(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -2088,7 +2130,10 @@ select rpad(utf8_binary, 8, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -2151,7 +2196,10 @@ select lpad(utf8_binary, 8, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -2214,7 +2262,10 @@ select locate(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : 
"42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -2301,7 +2352,10 @@ select TRIM(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" + } } @@ -2388,7 +2442,10 @@ select BTRIM(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -2475,7 +2532,10 @@ select LTRIM(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" + } } @@ -2562,7 +2622,10 @@ select RTRIM(utf8_binary, utf8_lcase) from t5 org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" + } } diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index 9d29a46e5a0ef..fe77f88f2af45 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -453,7 +453,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -766,7 +769,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -873,7 +879,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -980,7 +989,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1121,7 +1133,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1274,7 +1289,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1427,7 +1445,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1580,7 +1601,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : 
"COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1707,7 +1731,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1880,7 +1907,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -1993,7 +2023,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -2146,7 +2179,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -2744,7 +2780,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -3064,7 +3103,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -3471,7 +3513,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -3598,7 +3643,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -3725,7 +3773,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -3878,7 +3929,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" + } } @@ -4019,7 +4073,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" + } } @@ -4172,7 +4229,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" + } } @@ -4313,7 +4373,10 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "COLLATION_MISMATCH.IMPLICIT", - "sqlState" : "42P21" + "sqlState" : "42P21", + "messageParameters" : { + "implicitTypes" : "\"STRING COLLATE 
UTF8_LCASE\", \"STRING\"" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index fe9872ddaf575..9407a0df1ed3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -164,7 +164,9 @@ class CollationStringExpressionsSuite }, condition = "COLLATION_MISMATCH.IMPLICIT", sqlState = "42P21", - parameters = Map.empty + parameters = Map( + "implicitTypes" -> """"STRING", "STRING COLLATE UTF8_LCASE"""" + ) ) checkError( exception = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index b6da0b169f050..f1b70959ce082 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -545,7 +545,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { s"WHERE c1 = SUBSTR(COLLATE('a', 'UNICODE'), 0)") }, condition = "COLLATION_MISMATCH.IMPLICIT", - parameters = Map.empty + parameters = Map( + "implicitTypes" -> """"STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE"""" + ) ) // in operator @@ -568,7 +570,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT c1 || c2 FROM $tableName") }, - condition = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT", + parameters = Map( + "implicitTypes" -> """"STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE"""" + ) ) @@ -583,7 +588,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE c1 = c3") }, - condition = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT", + parameters = Map( + "implicitTypes" -> """"STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE_CI"""" + ) ) // different explicit collations are set @@ -629,7 +637,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE c1 || c3 = 'aa'") }, - condition = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT", + parameters = Map( + "implicitTypes" -> """"STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE_CI"""" + ) ) // concat on different implicit collations should succeed, @@ -638,7 +649,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName ORDER BY c1 || c3") }, - condition = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT", + parameters = Map( + "implicitTypes" -> """"STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE_CI"""" + ) ) // concat + in @@ -655,14 +669,20 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName WHERE contains(c1||c3, 'a')") }, - condition = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT", + parameters = Map( + "implicitTypes" -> """"STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE_CI"""" + ) ) checkError( exception = 
intercept[AnalysisException] { sql(s"SELECT array('A', 'a' COLLATE UNICODE) == array('b' COLLATE UNICODE_CI)") }, - condition = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT", + parameters = Map( + "implicitTypes" -> """"STRING COLLATE UNICODE", "STRING COLLATE UNICODE_CI"""" + ) ) checkAnswer(sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate UNICODE_CI)"), @@ -811,7 +831,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"CREATE TABLE $newTableName AS SELECT c1 || c2 FROM $tableName") }, - condition = "COLLATION_MISMATCH.IMPLICIT") + condition = "COLLATION_MISMATCH.IMPLICIT", + parameters = Map( + "implicitTypes" -> """"STRING COLLATE UNICODE", "STRING COLLATE UTF8_LCASE"""" + ) + ) } } } From af0cc8ef39186caaa4e05e5e129e19752b3cae64 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Fri, 18 Oct 2024 08:48:22 +0200 Subject: [PATCH 043/108] [SPARK-50018][SQL] Make AbstractStringType serializable ### What changes were proposed in this pull request? Make `AbstractStringType` class serializable, so that the derived classes can be used in expressions that perform replacement using `Invoke`-like Spark expressions. ### Why are the changes needed? Objects with custom parameters cannot be used as inputTypes unless the underlying class is serializable. For example, `ValidateUTF8` is a string function that uses `StaticInvoke` replacement, and an object derived from `AbstractStringType` as one of the input types. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests suffice. More will be added with appropriate collation support in various Spark expressions. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48527 from uros-db/fix-abstractstringtype. Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- .../apache/spark/sql/internal/types/AbstractStringType.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 3a25bba32b530..49d8bf9e001ab 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} * AbstractStringType is an abstract class for StringType with collation support. */ abstract class AbstractStringType(supportsTrimCollation: Boolean = false) - extends AbstractDataType { + extends AbstractDataType + with Serializable { override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType override private[sql] def simpleString: String = "string" From fae0c12fea8866c000d1bf6263fbf5cae8526691 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 18 Oct 2024 16:21:16 +0900 Subject: [PATCH 044/108] [SPARK-49558][SQL][FOLLOW-UP] Run `./build/mvn scalafmt:format` to fix CI ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/48413 which formats the Scala API module ### Why are the changes needed? To fix the CI. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
Manually checked by `./dev/lint-scala` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48537 from HyukjinKwon/followup-fmt. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/sql/errors/QueryParsingErrors.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 6164f2585f0fe..0272d06ee1261 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -97,9 +97,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { clause2: String): Throwable = { new ParseException( errorClass = "MULTIPLE_QUERY_RESULT_CLAUSES_WITH_PIPE_OPERATORS", - messageParameters = Map( - "clause1" -> clause1, - "clause2" -> clause2), + messageParameters = Map("clause1" -> clause1, "clause2" -> clause2), ctx) } From 693d008de0e407183a278705dc40e0ca64e49053 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 18 Oct 2024 16:40:06 +0800 Subject: [PATCH 045/108] [SPARK-49808][SQL] Fix a deadlock in subquery execution due to lazy vals ### What changes were proposed in this pull request? 1. Introduce a helper class `Lazy` to replace the lazy vals. 2. Fix a deadlock in subquery execution. ### Why are the changes needed? We observed a deadlock between `QueryPlan.canonicalized` and `QueryPlan.references`: The main thread in `TakeOrderedAndProject.doExecute` is trying to compute `outputOrdering`; it traverses the tree top-down and needs the lock guarding `QueryPlan.canonicalized` for each node on the path. In this deadlock, it successfully obtained the lock of `WholeStageCodegenExec` and requires the lock of `HashAggregateExec`. Concurrently, a subquery execution thread is performing code generation and traverses the tree bottom-up via `def consume`, which checks `WholeStageCodegenExec.usedInputs` and references the lazy val `QueryPlan.references`; it needs the lock guarding `QueryPlan.references` for each node on the path. In this deadlock, it successfully obtained the lock of `HashAggregateExec` and requires the lock of `WholeStageCodegenExec`. This happens because a Scala lazy val internally calls `this.synchronized` on the instance that contains the val, which creates the potential for deadlocks. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual testing: before the fix, the deadlock happened twice in the first 20 runs; after the fix, it did not happen in 100+ consecutive runs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48391 from zhengruifeng/query_plan_lazy_ref.
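For illustration, the problematic pattern can be reproduced outside Spark with a minimal, self-contained sketch (the class and member names below only mirror the scenario described above; this is not Spark code). In Scala 2, evaluating a lazy val synchronizes on the enclosing instance, so two threads that initialize lazy vals on two mutually referencing objects in opposite order can deadlock:
```scala
// Minimal sketch of the lock-order inversion described above (illustrative names only).
object LazyValDeadlockSketch {
  final class Parent {
    var child: Child = _
    // Evaluating this lazy val holds the Parent's monitor, then needs the Child's.
    lazy val canonicalized: String = s"Parent(${child.canonicalized})"
    lazy val references: Set[String] = Set("a", "b")
  }
  final class Child {
    var parent: Parent = _
    lazy val canonicalized: String = "Child"
    // Evaluating this lazy val holds the Child's monitor, then needs the Parent's.
    lazy val generatedCode: String = s"/* uses ${parent.references.mkString(", ")} */"
  }

  def main(args: Array[String]): Unit = {
    val p = new Parent; val c = new Child
    p.child = c
    c.parent = p
    val topDown = new Thread(() => { p.canonicalized; () })   // like computing outputOrdering
    val bottomUp = new Thread(() => { c.generatedCode; () })  // like codegen calling back into the parent
    topDown.start(); bottomUp.start()
    topDown.join(); bottomUp.join()  // under unlucky timing, neither join ever returns
  }
}
```
As the diff below shows, the `TransientLazy` helper sidesteps this pattern by moving the lazy initialization into a small dedicated holder object, so the initialization lock is taken on the holder rather than on the query plan node itself.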
Authored-by: Ruifeng Zheng Signed-off-by: Wenchen Fan --- .../org/apache/spark/util/TransientLazy.scala | 43 ++++++++++++++ .../spark/util/TransientLazySuite.scala | 58 +++++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 10 ++-- 3 files changed, 107 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/TransientLazy.scala create mode 100644 core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/TransientLazy.scala b/core/src/main/scala/org/apache/spark/util/TransientLazy.scala new file mode 100644 index 0000000000000..2833ef93669a6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TransientLazy.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +/** + * Construct to lazily initialize a variable. + * This may be helpful for avoiding deadlocks in certain scenarios. For example, + * a) Thread 1 entered a synchronized method, grabbing a coarse lock on the parent object. + * b) Thread 2 gets spawned off, and tries to initialize a lazy value on the same parent object + * (in our case, this was the logger). This causes scala to also try to grab a coarse lock on + * the parent object. + * c) If thread 1 waits for thread 2 to join, a deadlock occurs. + * The main difference between this and [[LazyTry]] is that this does not cache failures. + * + * @note + * Scala 3 uses a different implementation of lazy vals which doesn't have this problem. + * Please refer to Lazy + * Vals Initialization for more details. + */ +private[spark] class TransientLazy[T](initializer: => T) extends Serializable { + + @transient + private[this] lazy val value: T = initializer + + def apply(): T = { + value + } +} diff --git a/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala b/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala new file mode 100644 index 0000000000000..c0754ee063d67 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.{ByteArrayOutputStream, ObjectOutputStream} + +import org.apache.spark.SparkFunSuite + +class TransientLazySuite extends SparkFunSuite { + + test("TransientLazy val works") { + var test: Option[Object] = None + + val lazyval = new TransientLazy({ + test = Some(new Object()) + test + }) + + // Ensure no initialization happened before the lazy value was dereferenced + assert(test.isEmpty) + + // Ensure the first invocation creates a new object + assert(lazyval() == test && test.isDefined) + + // Ensure the subsequent invocation serves the same object + assert(lazyval() == test && test.isDefined) + } + + test("TransientLazy val is serializable") { + val lazyval = new TransientLazy({ + new Object() + }) + + // Ensure serializable before the dereference + val oos = new ObjectOutputStream(new ByteArrayOutputStream()) + oos.writeObject(lazyval) + + val dereferenced = lazyval() + + // Ensure serializable after the dereference + val oos2 = new ObjectOutputStream(new ByteArrayOutputStream()) + oos2.writeObject(lazyval) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3f417644082c3..9418bf298b293 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.trees.TreePatternBits import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.TransientLazy import org.apache.spark.util.collection.BitSet /** @@ -94,10 +95,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] * All Attributes that appear in expressions from this operator. Note that this set does not * include attributes that are implicitly referenced by being passed through to the output tuple. */ - @transient - lazy val references: AttributeSet = { - AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes - } + def references: AttributeSet = _references() + + private val _references = new TransientLazy({ + AttributeSet(expressions) -- producedAttributes + }) /** * Returns true when the all the expressions in the current node as well as all of its children From 25c55a86c39fa6625ca71d24fa3c3aa1f5c9942f Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Fri, 18 Oct 2024 10:41:47 +0200 Subject: [PATCH 046/108] [SPARK-50016][SQL] Improve explicit collation mismatch error ### What changes were proposed in this pull request? Include the explicit string collations in the `COLLATION_MISMATCH.EXPLICIT` error message. ### Why are the changes needed? Make the explicit collation mismatch error more user-friendly. ### Does this PR introduce _any_ user-facing change? Yes, explicit collation mismatch error is changed. ### How was this patch tested? Updated existing tests. 
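For illustration, a hedged sketch of the user-facing difference (assumes a running `SparkSession` named `spark`; this is not a test taken from this PR):
```scala
// Combining two different explicit collations still fails with
// COLLATION_MISMATCH.EXPLICIT, but the error now lists the offending collations
// as full SQL types, e.g. "STRING COLLATE UNICODE", "STRING COLLATE UNICODE_CI",
// instead of the previous `string collate UNICODE`, `string collate UNICODE_CI`.
try {
  spark.sql("SELECT collate('a', 'UNICODE') = collate('A', 'UNICODE_CI')").collect()
} catch {
  case e: org.apache.spark.sql.AnalysisException =>
    println(e.getMessage) // the message now embeds the quoted SQL types shown above
}
```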
### Was this patch authored or co-authored using generative AI tooling? No. Closes #48525 from uros-db/explicit-collation-mismatch. Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- .../analysis/CollationTypeCasts.scala | 2 +- .../sql/errors/QueryCompilationErrors.scala | 4 +- .../analyzer-results/collations.sql.out | 42 +++++++++---------- .../sql-tests/results/collations.sql.out | 42 +++++++++---------- .../sql/CollationSQLExpressionsSuite.scala | 4 +- .../spark/sql/CollationSQLRegexpSuite.scala | 4 +- .../sql/CollationStringExpressionsSuite.scala | 4 +- .../org/apache/spark/sql/CollationSuite.scala | 14 +++---- 8 files changed, 61 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index cfde8c9d8aa1b..ee278180ce313 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -175,7 +175,7 @@ object CollationTypeCasts extends TypeCoercionRule { case size if size > 1 => throw QueryCompilationErrors .explicitCollationMismatchError( - explicitTypes.map(t => StringType(t).typeName) + explicitTypes.map(t => StringType(t)) ) // Only implicit or default collations present case 0 => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 3dc906ed98280..717ce4253acf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3658,11 +3658,11 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } - def explicitCollationMismatchError(explicitTypes: Seq[String]): Throwable = { + def explicitCollationMismatchError(explicitTypes: Seq[StringType]): Throwable = { new AnalysisException( errorClass = "COLLATION_MISMATCH.EXPLICIT", messageParameters = Map( - "explicitTypes" -> explicitTypes.map(toSQLId).mkString(", ") + "explicitTypes" -> explicitTypes.map(toSQLType).mkString(", ") ) ) } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 6a00d31f64316..739ef49627f2d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -426,7 +426,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -711,7 +711,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -769,7 +769,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -835,7 +835,7 @@ 
org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -925,7 +925,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1015,7 +1015,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1105,7 +1105,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1195,7 +1195,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1261,7 +1261,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1351,7 +1351,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + "explicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" } } @@ -1433,7 +1433,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1523,7 +1523,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1813,7 +1813,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1975,7 +1975,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -2153,7 +2153,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -2219,7 +2219,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate 
UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -2285,7 +2285,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -2375,7 +2375,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + "explicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" } } @@ -2465,7 +2465,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -2555,7 +2555,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + "explicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" } } @@ -2645,7 +2645,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + "explicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" } } diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index fe77f88f2af45..e0c5e2d0a4312 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -470,7 +470,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -786,7 +786,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -896,7 +896,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1026,7 +1026,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1170,7 +1170,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1326,7 +1326,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1482,7 +1482,7 @@ 
org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1638,7 +1638,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1768,7 +1768,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -1924,7 +1924,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + "explicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" } } @@ -2060,7 +2060,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -2216,7 +2216,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -2817,7 +2817,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -3140,7 +3140,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -3550,7 +3550,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -3680,7 +3680,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -3810,7 +3810,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -3954,7 +3954,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + "explicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" } } @@ -4110,7 +4110,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate 
UTF8_LCASE`" + "explicitTypes" : "\"STRING\", \"STRING COLLATE UTF8_LCASE\"" } } @@ -4254,7 +4254,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + "explicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" } } @@ -4398,7 +4398,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + "explicitTypes" : "\"STRING COLLATE UTF8_LCASE\", \"STRING\"" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index d568cd77050fd..daef9c772a65a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -1252,7 +1252,9 @@ class CollationSQLExpressionsSuite sql("SELECT mask(collate('ab-CD-12-@$','UNICODE'),collate('X','UNICODE_CI'),'x','0','#')") }, condition = "COLLATION_MISMATCH.EXPLICIT", - parameters = Map("explicitTypes" -> "`string collate UNICODE`, `string collate UNICODE_CI`") + parameters = Map( + "explicitTypes" -> """"STRING COLLATE UNICODE", "STRING COLLATE UNICODE_CI"""" + ) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala index 87dbbc65a3936..5bb8511d0d935 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala @@ -430,7 +430,9 @@ class CollationSQLRegexpSuite sql(s"SELECT regexp_replace(collate('ABCDE','$c1'), '.c.', collate('FFF','$c2'))") }, condition = "COLLATION_MISMATCH.EXPLICIT", - parameters = Map("explicitTypes" -> "`string`, `string collate UTF8_LCASE`") + parameters = Map( + "explicitTypes" -> """"STRING", "STRING COLLATE UTF8_LCASE"""" + ) ) // Unsupported collations case class RegExpReplaceTestFail(l: String, r: String, c: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 9407a0df1ed3a..6db30a7ed0c6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -177,7 +177,9 @@ class CollationStringExpressionsSuite }, condition = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", - parameters = Map("explicitTypes" -> "`string`, `string collate UTF8_LCASE`") + parameters = Map( + "explicitTypes" -> """"STRING", "STRING COLLATE UTF8_LCASE"""" + ) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index f1b70959ce082..128d558d6735e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -337,7 +337,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string collate $leftCollationName`, `string collate $rightCollationName`" + s""""STRING COLLATE $leftCollationName", 
"STRING COLLATE $rightCollationName"""" ) ) // startsWith @@ -351,7 +351,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string collate $leftCollationName`, `string collate $rightCollationName`" + s""""STRING COLLATE $leftCollationName", "STRING COLLATE $rightCollationName"""" ) ) // endsWith @@ -365,7 +365,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string collate $leftCollationName`, `string collate $rightCollationName`" + s""""STRING COLLATE $leftCollationName", "STRING COLLATE $rightCollationName"""" ) ) } @@ -605,7 +605,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, condition = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string`, `string collate UNICODE`" + "explicitTypes" -> """"STRING", "STRING COLLATE UNICODE"""" ) ) @@ -617,7 +617,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, condition = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string`, `string collate UNICODE`" + "explicitTypes" -> """"STRING", "STRING COLLATE UNICODE"""" ) ) checkError( @@ -627,7 +627,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }, condition = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( - "explicitTypes" -> "`string collate UNICODE`, `string`" + "explicitTypes" -> """"STRING COLLATE UNICODE", "STRING"""" ) ) @@ -715,7 +715,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sqlState = "42P21", parameters = Map( "explicitTypes" -> - s"`string collate UTF8_LCASE`, `string collate UNICODE`" + s""""STRING COLLATE UTF8_LCASE", "STRING COLLATE UNICODE"""" ) ) From b0414b19a0ea5ffa10067154c00009259419f045 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 18 Oct 2024 10:44:19 +0200 Subject: [PATCH 047/108] Revert "[SPARK-49909][SQL] Fix the pretty name of some expressions" ### What changes were proposed in this pull request? The pr aims to revert https://github.com/apache/spark/pull/48385. This reverts commit 52538f0d9bd1258dc2a0a2ab5bdb953f85d85da9. ### Why are the changes needed? When upgrading spark from `an old version` to `the latest version`, some end-users may rely on the `original schema` (`although it may not be correct`), which can make the `upgrade` very difficult. so, let's first restore it to its original state. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48530 from panbingkun/SPARK-49909_revert. 
Authored-by: panbingkun Signed-off-by: Max Gekk --- python/pyspark/sql/functions/builtin.py | 80 +++++++++---------- .../expressions/aggregate/collect.scala | 5 +- .../expressions/datetimeExpressions.scala | 5 +- .../spark/sql/catalyst/expressions/misc.scala | 3 +- .../expressions/numberFormatExpressions.scala | 7 +- .../expressions/randomExpressions.scala | 8 +- .../function_array_agg.explain | 2 +- .../explain-results/function_curdate.explain | 2 +- .../function_current_database.explain | 2 +- .../explain-results/function_dateadd.explain | 2 +- .../function_random_with_seed.explain | 2 +- .../function_to_varchar.explain | 2 +- .../sql-functions/sql-expression-schema.md | 12 +-- .../analyzer-results/charvarchar.sql.out | 6 +- .../current_database_catalog.sql.out | 2 +- .../analyzer-results/group-by.sql.out | 4 +- .../sql-session-variables.sql.out | 2 +- .../sql-tests/results/charvarchar.sql.out | 6 +- .../results/current_database_catalog.sql.out | 2 +- .../sql-tests/results/group-by.sql.out | 4 +- .../results/subexp-elimination.sql.out | 6 +- 21 files changed, 77 insertions(+), 87 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 67c2e23b40ed8..55da50fd4a5a5 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -4921,44 +4921,44 @@ def array_agg(col: "ColumnOrName") -> Column: >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +------------------------------+ - |sort_array(array_agg(c), true)| - +------------------------------+ - | [1, 1, 2]| - +------------------------------+ + +---------------------------------+ + |sort_array(collect_list(c), true)| + +---------------------------------+ + | [1, 1, 2]| + +---------------------------------+ Example 2: Using array_agg function on a string column >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([["apple"],["apple"],["banana"]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show(truncate=False) - +------------------------------+ - |sort_array(array_agg(c), true)| - +------------------------------+ - |[apple, apple, banana] | - +------------------------------+ + +---------------------------------+ + |sort_array(collect_list(c), true)| + +---------------------------------+ + |[apple, apple, banana] | + +---------------------------------+ Example 3: Using array_agg function on a column with null values >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],[None],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +------------------------------+ - |sort_array(array_agg(c), true)| - +------------------------------+ - | [1, 2]| - +------------------------------+ + +---------------------------------+ + |sort_array(collect_list(c), true)| + +---------------------------------+ + | [1, 2]| + +---------------------------------+ Example 4: Using array_agg function on a column with different data types >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],["apple"],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +------------------------------+ - |sort_array(array_agg(c), true)| - +------------------------------+ - | [1, 2, apple]| - +------------------------------+ + +---------------------------------+ + |sort_array(collect_list(c), true)| + +---------------------------------+ + | [1, 2, apple]| + 
+---------------------------------+ """ return _invoke_function_over_columns("array_agg", col) @@ -8712,31 +8712,31 @@ def dateadd(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", 1)).show() - +--------------+ - |dateadd(dt, 1)| - +--------------+ - | 2015-04-09| - +--------------+ + +---------------+ + |date_add(dt, 1)| + +---------------+ + | 2015-04-09| + +---------------+ >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", sf.lit(2))).show() - +--------------+ - |dateadd(dt, 2)| - +--------------+ - | 2015-04-10| - +--------------+ + +---------------+ + |date_add(dt, 2)| + +---------------+ + | 2015-04-10| + +---------------+ >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", -1)).show() - +---------------+ - |dateadd(dt, -1)| - +---------------+ - | 2015-04-07| - +---------------+ + +----------------+ + |date_add(dt, -1)| + +----------------+ + | 2015-04-07| + +----------------+ """ days = _enum_to_value(days) days = lit(days) if isinstance(days, int) else days @@ -10343,11 +10343,11 @@ def current_database() -> Column: Examples -------- >>> spark.range(1).select(current_database()).show() - +------------------+ - |current_database()| - +------------------+ - | default| - +------------------+ + +----------------+ + |current_schema()| + +----------------+ + | default| + +----------------+ """ return _invoke_function("current_database") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 3270c6e87e2cd..adad04e7d749c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.collection.mutable.Growable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike @@ -118,8 +118,7 @@ case class CollectList( override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty - override def prettyName: String = - getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("collect_list") + override def prettyName: String = "collect_list" override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { new GenericArrayData(buffer.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index de3501a671eb4..d0c4a53e491d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -150,8 +150,7 @@ case class CurrentDate(timeZoneId: Option[String] = None) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression 
= copy(timeZoneId = Option(timeZoneId)) - override def prettyName: String = - getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_date") + override def prettyName: String = "current_date" } // scalastyle:off line.size.limit @@ -330,7 +329,7 @@ case class DateAdd(startDate: Expression, days: Expression) }) } - override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("date_add") + override def prettyName: String = "date_add" override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): DateAdd = copy(startDate = newLeft, days = newRight) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index bef3bac17ffd2..5f1b3dc0a01ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -202,8 +202,7 @@ object AssertTrue { case class CurrentDatabase() extends LeafExpression with Unevaluable { override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false - override def prettyName: String = - getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_database") + override def prettyName: String = "current_schema" final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index f2fb735b163e1..0d137a9b8f6e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} @@ -307,10 +307,7 @@ case class ToCharacter(left: Expression, right: Expression) inputTypeCheck } } - - override def prettyName: String = - getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("to_char") - + override def prettyName: String = "to_char" override def nullSafeEval(decimal: Any, format: Any): Any = { val input = decimal.asInstanceOf[Decimal] numberFormatter.format(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 16bdaa1f7f708..706dc675d7f24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, UnresolvedSeed} +import 
org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} @@ -128,12 +128,8 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi } override def flatArguments: Iterator[Any] = Iterator(child) - - override def prettyName: String = - getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rand") - override def sql: String = { - s"$prettyName(${if (hideSeed) "" else child.sql})" + s"rand(${if (hideSeed) "" else child.sql})" } override protected def withNewChildInternal(newChild: Expression): Rand = copy(child = newChild) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain index 6668692f6cf1d..102f736c62ef6 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain @@ -1,2 +1,2 @@ -Aggregate [array_agg(a#0, 0, 0) AS array_agg(a)#0] +Aggregate [collect_list(a#0, 0, 0) AS collect_list(a)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain index be039d62a5494..5305b346c4f2d 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain @@ -1,2 +1,2 @@ -Project [curdate(Some(America/Los_Angeles)) AS curdate()#0] +Project [current_date(Some(America/Los_Angeles)) AS current_date()#0] +- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain index 93dfac524d9a1..481c0a478c8df 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain @@ -1,2 +1,2 @@ -Project [current_database() AS current_database()#0] +Project [current_schema() AS current_schema()#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain index 319428541760d..66325085b9c14 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain @@ -1,2 +1,2 @@ -Project [dateadd(d#0, 2) AS dateadd(d, 2)#0] +Project [date_add(d#0, 2) AS date_add(d, 2)#0] +- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain index 5854d2c7fa6be..81c81e95c2bdd 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain @@ -1,2 +1,2 @@ -Project [random(1) AS random(1)#0] +Project [random(1) AS rand(1)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain index cc5149bfed863..f0d9cacc61ac5 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain @@ -1,2 +1,2 @@ -Project [to_varchar(cast(b#0 as decimal(30,15)), $99.99) AS to_varchar(b, $99.99)#0] +Project [to_char(cast(b#0 as decimal(30,15)), $99.99) AS to_char(b, $99.99)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 79fd25aa3eb14..5ad1380e1fb82 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -99,9 +99,9 @@ | org.apache.spark.sql.catalyst.expressions.Csc | csc | SELECT csc(1) | struct | | org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b DOUBLE') | struct> | | org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | SELECT a, b, cume_dist() OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | -| org.apache.spark.sql.catalyst.expressions.CurDateExpressionBuilder | curdate | SELECT curdate() | struct | +| org.apache.spark.sql.catalyst.expressions.CurDateExpressionBuilder | curdate | SELECT curdate() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentCatalog | current_catalog | SELECT current_catalog() | struct | -| org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | +| org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_schema | SELECT current_schema() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDate | current_date | SELECT current_date() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentTimeZone | current_timezone | SELECT current_timezone() | struct | @@ -110,7 +110,7 @@ | org.apache.spark.sql.catalyst.expressions.CurrentUser | session_user | SELECT session_user() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentUser | user | SELECT user() | struct | | org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct | -| org.apache.spark.sql.catalyst.expressions.DateAdd | dateadd | SELECT dateadd('2016-07-30', 1) | struct | +| org.apache.spark.sql.catalyst.expressions.DateAdd | dateadd | SELECT dateadd('2016-07-30', 1) | struct | | org.apache.spark.sql.catalyst.expressions.DateDiff | date_diff | SELECT date_diff('2009-07-31', '2009-07-30') | struct | | 
org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.DateFormatClass | date_format | SELECT date_format('2016-04-08', 'y') | struct | @@ -264,7 +264,7 @@ | org.apache.spark.sql.catalyst.expressions.RPadExpressionBuilder | rpad | SELECT rpad('hi', 5, '??') | struct | | org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct | | org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct | -| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | +| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | | org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) AS result | struct | | org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct | | org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | @@ -340,7 +340,7 @@ | org.apache.spark.sql.catalyst.expressions.TimeWindow | window | SELECT a, window.start, window.end, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, start | struct | | org.apache.spark.sql.catalyst.expressions.ToBinary | to_binary | SELECT to_binary('abc', 'utf-8') | struct | | org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_char | SELECT to_char(454, '999') | struct | -| org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_varchar | SELECT to_varchar(454, '999') | struct | +| org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_varchar | SELECT to_varchar(454, '999') | struct | | org.apache.spark.sql.catalyst.expressions.ToDegrees | degrees | SELECT degrees(3.141592653589793) | struct | | org.apache.spark.sql.catalyst.expressions.ToNumber | to_number | SELECT to_number('454', '999') | struct | | org.apache.spark.sql.catalyst.expressions.ToRadians | radians | SELECT radians(180) | struct | @@ -402,7 +402,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | any | SELECT any(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | bool_or | SELECT bool_or(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | some | SELECT some(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | array_agg | SELECT array_agg(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | +| org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | array_agg | SELECT array_agg(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | collect_list | SELECT collect_list(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet | collect_set | SELECT collect_set(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.Corr | corr | SELECT corr(c1, c2) FROM VALUES (3, 2), (3, 3), (6, 4) as tab(c1, c2) | struct | diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out index d4bcb8f2ed042..524797015a2f6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out @@ -722,19 +722,19 @@ Project [chr(cast(167 as bigint)) AS chr(167)#x, chr(cast(247 as bigint)) AS chr -- !query SELECT to_varchar(78.12, '$99.99') -- !query analysis -Project [to_varchar(78.12, $99.99) AS to_varchar(78.12, $99.99)#x] +Project [to_char(78.12, $99.99) AS to_char(78.12, $99.99)#x] +- OneRowRelation -- !query SELECT to_varchar(111.11, '99.9') -- !query analysis -Project [to_varchar(111.11, 99.9) AS to_varchar(111.11, 99.9)#x] +Project [to_char(111.11, 99.9) AS to_char(111.11, 99.9)#x] +- OneRowRelation -- !query SELECT to_varchar(12454.8, '99,999.9S') -- !query analysis -Project [to_varchar(12454.8, 99,999.9S) AS to_varchar(12454.8, 99,999.9S)#x] +Project [to_char(12454.8, 99,999.9S) AS to_char(12454.8, 99,999.9S)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out index 2759f5e67507b..1a71594f84932 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out @@ -2,5 +2,5 @@ -- !query select current_database(), current_schema(), current_catalog() -- !query analysis -Project [current_database() AS current_database()#x, current_schema() AS current_schema()#x, current_catalog() AS current_catalog()#x] +Project [current_schema() AS current_schema()#x, current_schema() AS current_schema()#x, current_catalog() AS current_catalog()#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out index 6996eb913a21e..8849aa4452252 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out @@ -1133,7 +1133,7 @@ SELECT FROM VALUES (1), (2), (1) AS tab(col) -- !query analysis -Aggregate [collect_list(col#x, 0, 0) AS collect_list(col)#x, array_agg(col#x, 0, 0) AS array_agg(col)#x] +Aggregate [collect_list(col#x, 0, 0) AS collect_list(col)#x, collect_list(col#x, 0, 0) AS collect_list(col)#x] +- SubqueryAlias tab +- LocalRelation [col#x] @@ -1147,7 +1147,7 @@ FROM VALUES (1,4),(2,3),(1,4),(2,4) AS v(a,b) GROUP BY a -- !query analysis -Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, array_agg(b#x, 0, 0) AS array_agg(b)#x] +Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, collect_list(b#x, 0, 0) AS collect_list(b)#x] +- SubqueryAlias v +- LocalRelation [a#x, b#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out index 8c10d78405751..02e7c39ae83fd 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out @@ -776,7 +776,7 @@ Project [NULL AS Expected#x, variablereference(system.session.var1=CAST(NULL AS -- !query DECLARE OR REPLACE 
VARIABLE var1 STRING DEFAULT CURRENT_DATABASE() -- !query analysis -CreateVariable defaultvalueexpression(cast(current_database() as string), CURRENT_DATABASE()), true +CreateVariable defaultvalueexpression(cast(current_schema() as string), CURRENT_DATABASE()), true +- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 diff --git a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out index 2960c4ca4f4d4..8aafa25c5caaf 100644 --- a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out @@ -1235,7 +1235,7 @@ struct -- !query SELECT to_varchar(78.12, '$99.99') -- !query schema -struct +struct -- !query output $78.12 @@ -1243,7 +1243,7 @@ $78.12 -- !query SELECT to_varchar(111.11, '99.9') -- !query schema -struct +struct -- !query output ##.# @@ -1251,6 +1251,6 @@ struct -- !query SELECT to_varchar(12454.8, '99,999.9S') -- !query schema -struct +struct -- !query output 12,454.8+ diff --git a/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out b/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out index 7fbe2dfff4db1..67db0adee7f07 100644 --- a/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out @@ -2,6 +2,6 @@ -- !query select current_database(), current_schema(), current_catalog() -- !query schema -struct +struct -- !query output default default spark_catalog diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 5d220fc12b78e..d8a9f4c2e11f5 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1066,7 +1066,7 @@ SELECT FROM VALUES (1), (2), (1) AS tab(col) -- !query schema -struct,array_agg(col):array> +struct,collect_list(col):array> -- !query output [1,2,1] [1,2,1] @@ -1080,7 +1080,7 @@ FROM VALUES (1,4),(2,3),(1,4),(2,4) AS v(a,b) GROUP BY a -- !query schema -struct,array_agg(b):array> +struct,collect_list(b):array> -- !query output 1 [4,4] [4,4] 2 [3,4] [3,4] diff --git a/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out b/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out index 28457c0579e95..0f7ff3f107567 100644 --- a/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out @@ -72,7 +72,7 @@ NULL -- !query SELECT from_json(a, 'struct').a + random() > 2, from_json(a, 'struct').b, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].b + + random() > 2 FROM testData -- !query schema -struct<((from_json(a).a + random()) > 2):boolean,from_json(a).b:string,from_json(b)[0].a:int,((from_json(b)[0].b + (+ random())) > 2):boolean> +struct<((from_json(a).a + rand()) > 2):boolean,from_json(a).b:string,from_json(b)[0].a:int,((from_json(b)[0].b + (+ rand())) > 2):boolean> -- !query output NULL NULL 1 true false 2 1 true @@ -84,7 +84,7 @@ true 6 6 true -- !query SELECT if(from_json(a, 'struct').a + random() > 5, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].a + 1) FROM testData -- !query schema -struct<(IF(((from_json(a).a + random()) > 5), from_json(b)[0].a, (from_json(b)[0].a + 1))):int> 
+struct<(IF(((from_json(a).a + rand()) > 5), from_json(b)[0].a, (from_json(b)[0].a + 1))):int> -- !query output 2 2 @@ -96,7 +96,7 @@ NULL -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(a, 'struct').b + random() > 5 when from_json(a, 'struct').a > 4 then from_json(a, 'struct').b + 1 + random() > 2 else from_json(a, 'struct').b + 2 + random() > 5 end FROM testData -- !query schema -struct 5) THEN ((from_json(a).b + random()) > 5) WHEN (from_json(a).a > 4) THEN (((from_json(a).b + 1) + random()) > 2) ELSE (((from_json(a).b + 2) + random()) > 5) END:boolean> +struct 5) THEN ((from_json(a).b + rand()) > 5) WHEN (from_json(a).a > 4) THEN (((from_json(a).b + 1) + rand()) > 2) ELSE (((from_json(a).b + 2) + rand()) > 5) END:boolean> -- !query output NULL false From 0705e102292323551b3f6dfbb7edabc06d31ceb3 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 18 Oct 2024 19:44:10 +0900 Subject: [PATCH 048/108] [SPARK-50007][SQL][SS] Provide default values for metrics on observe API when physical node is lost in executed plan ### What changes were proposed in this pull request? This PR proposes to provide default values for metrics on observe API, when physical node (CollectMetricsExec) is lost in executed plan. This includes the case where logical node (CollectMetrics) is lost during optimization (and it's mostly the case). ### Why are the changes needed? When user defines the metrics via observe API, they expect the metrics to be retrieved via Observation (batch query) or update event of StreamingQueryListener. But when the node (CollectMetrics) is lost in any reason (e.g. subtree is pruned by PruneFilters), Spark does behave like the metrics were not defined, instead of providing default values. When the query runs successfully, user wouldn't expect the metric being bound to the query to be unavailable, hence they missed to guard the code for this case and encountered some issue. Arguably it's lot better to provide default values - when the node is pruned out from optimizer, it is mostly logically equivalent that there were no input being processed with the node (except the bug in analyzer/optimizer/etc which drop the node incorrectly), hence it's valid to just have default value. ### Does this PR introduce _any_ user-facing change? Yes, user can consistently query about metrics being defined with observe API. It is available even with aggressive optimization which drop the CollectMetrics(Exec) node. ### How was this patch tested? New UTs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48517 from HeartSaVioR/SPARK-50007. 
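For illustration, a minimal batch-side sketch of the behavior change, adapted from the new unit test added in this patch (it assumes an existing `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions.{count, expr, lit}

val obs = Observation("observation")
spark.range(10)
  .observe(obs, count(lit(1)).as("rows"))
  // PruneFilters prunes the whole subtree, dropping the CollectMetrics node with it.
  .filter(expr("false"))
  .collect()

// Before this change: obs.getOrEmpty is an empty Map, as if the metric was never defined.
// After this change: the metric is reported with its default value.
assert(obs.getOrEmpty.get("rows") == Some(0L))
```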
Authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../sql/execution/CollectMetricsExec.scala | 26 ++++++++++++++++--- .../spark/sql/execution/QueryExecution.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 14 ++++++++++ ...ingQueryOptimizationCorrectnessSuite.scala | 24 +++++++++++++++++ 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala index 0a487bac77696..19da27611bf6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -98,13 +99,30 @@ object CollectMetricsExec extends AdaptiveSparkPlanHelper { /** * Recursively collect all collected metrics from a query tree. */ - def collect(plan: SparkPlan): Map[String, Row] = { - val metrics = collectWithSubqueries(plan) { + def collect(physicalPlan: SparkPlan, analyzedOpt: Option[LogicalPlan]): Map[String, Row] = { + val collectorsInLogicalPlan = analyzedOpt.map { + _.collectWithSubqueries { + case c: CollectMetrics => c + } + }.getOrElse(Seq.empty[CollectMetrics]) + val metrics = collectWithSubqueries(physicalPlan) { case collector: CollectMetricsExec => Map(collector.name -> collector.collectedMetrics) case tableScan: InMemoryTableScanExec => - CollectMetricsExec.collect(tableScan.relation.cachedPlan) + CollectMetricsExec.collect( + tableScan.relation.cachedPlan, tableScan.relation.cachedPlan.logicalLink) } - metrics.reduceOption(_ ++ _).getOrElse(Map.empty) + val metricsCollected = metrics.reduceOption(_ ++ _).getOrElse(Map.empty) + val initialMetricsForMissing = collectorsInLogicalPlan.flatMap { collector => + if (!metricsCollected.contains(collector.name)) { + val exec = CollectMetricsExec(collector.name, collector.metrics, + EmptyRelationExec(LocalRelation(collector.child.output))) + val initialMetrics = exec.collectedMetrics + Some((collector.name -> initialMetrics)) + } else { + None + } + }.toMap + metricsCollected ++ initialMetricsForMissing } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 6ff2c5d4b9d32..1fa2ebce8ab81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -243,7 +243,7 @@ class QueryExecution( def toRdd: RDD[InternalRow] = lazyToRdd.get /** Get the metrics observed during the execution of the query plan. 
*/ - def observedMetrics: Map[String, Row] = CollectMetricsExec.collect(executedPlan) + def observedMetrics: Map[String, Row] = CollectMetricsExec.collect(executedPlan, Some(analyzed)) protected def preparations: Seq[Rule[SparkPlan]] = { QueryExecution.preparations(sparkSession, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e3346684285a9..d123800aa20ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4941,6 +4941,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Row(Array(0), Array(0)), Row(Array(1), Array(1)), Row(Array(2), Array(2))) checkAnswer(df, expectedAnswer) } + + test("SPARK-50007: default metrics is provided even observe node is pruned out") { + val namedObservation = Observation("observation") + spark.range(10) + .observe(namedObservation, count(lit(1)).as("rows")) + // Enforce PruneFilters to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + .collect() + + // This should produce the default value of metrics. Before SPARK-50007, the test fails + // because `namedObservation.getOrEmpty` is an empty Map. + assert(namedObservation.getOrEmpty.get("rows") === Some(0L)) + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala index f651bfb7f3c72..641f4668de3e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala @@ -586,4 +586,28 @@ class StreamingQueryOptimizationCorrectnessSuite extends StreamTest { } ) } + + test("SPARK-50007: default metrics is provided even observe node is pruned out") { + // We disable SPARK-49699 to test the case observe node is pruned out. + withSQLConf(SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .observe("observation", count(lit(1)).as("rows")) + // Enforce PruneFilters to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + val observeRow = qe.lastExecution.observedMetrics.get("observation") + // This should produce the default value of metrics. Before SPARK-50007, the test fails + // because `observeRow` is None. (Spark fails to find the metrics by name.) + assert(observeRow.get.getAs[Long]("rows") == 0L) + } + ) + } + } } From 7e9017ec8f6e6895a3251cc7250d6a72d63aed49 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Fri, 18 Oct 2024 12:51:52 +0200 Subject: [PATCH 049/108] [SPARK-50025][SQL] Integrate `_LEGACY_ERROR_TEMP_1253` into `EXPECT_VIEW_NOT_TABLE` ### What changes were proposed in this pull request? This PR proposes to Integrate `_LEGACY_ERROR_TEMP_1253` into `EXPECT_VIEW_NOT_TABLE` ### Why are the changes needed? 
To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48539 from itholic/LEGACY_1253. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 5 ----- .../spark/sql/errors/QueryCompilationErrors.scala | 10 +++++++--- .../org/apache/spark/sql/execution/command/ddl.scala | 4 +++- .../apache/spark/sql/hive/execution/HiveDDLSuite.scala | 4 ++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 752524eab7cc7..992e3f8e9cc74 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -6532,11 +6532,6 @@ " is not allowed on since its partition metadata is not stored in the Hive metastore. To import this information into the metastore, run `msck repair table `." ] }, - "_LEGACY_ERROR_TEMP_1253" : { - "message" : [ - "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead." - ] - }, "_LEGACY_ERROR_TEMP_1255" : { "message" : [ "Cannot drop built-in function ''." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 717ce4253acf1..149e6839424ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2834,10 +2834,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } - def cannotAlterTableWithAlterViewError(): Throwable = { + def cannotAlterTableWithAlterViewError(tableName: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1253", - messageParameters = Map.empty) + errorClass = "EXPECT_VIEW_NOT_TABLE.USE_ALTER_TABLE", + messageParameters = Map( + "operation" -> "ALTER VIEW", + "tableName" -> toSQLId(tableName) + ) + ) } def cannotOverwritePathBeingReadFromError(path: String): Throwable = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0b3469d3eb52d..3231c577c3498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -1012,7 +1012,9 @@ object DDLUtils extends Logging { viewName = tableMetadata.identifier.table ) case o if o != CatalogTableType.VIEW && isView => - throw QueryCompilationErrors.cannotAlterTableWithAlterViewError() + throw QueryCompilationErrors.cannotAlterTableWithAlterViewError( + tableName = tableMetadata.identifier.table + ) case _ => } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 94501d4c1c087..53a65e195e3f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -835,8 +835,8 @@ class HiveDDLSuite exception = 
intercept[AnalysisException] { sql(s"ALTER VIEW $tabName RENAME TO $newViewName") }, - condition = "_LEGACY_ERROR_TEMP_1253", - parameters = Map.empty + condition = "EXPECT_VIEW_NOT_TABLE.USE_ALTER_TABLE", + parameters = Map("operation" -> "ALTER VIEW", "tableName" -> "`tab1`") ) checkError( From 7e3b1680b09f12edd37b21bdfab1e5916a2f64cb Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 18 Oct 2024 15:01:22 +0200 Subject: [PATCH 050/108] [SPARK-49990][SQL] Improve performance of `randStr` ### What changes were proposed in this pull request? The pr aims to - improve `performance` of `randStr`. - make the `non-codegen` and `codegen` implementations of `randStr` as consistent as possible. ### Why are the changes needed? Obviously, when creating `UTF8String`, using `UTF8String.fromBytes(bytes)` performs better than using `UTF8String.fromString(new String(bytes))`. - `UTF8String.fromBytes(bytes)` https://github.com/apache/spark/blob/39112e4f2f8c1401ffa73c84398d3b8f0afa211a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L160-L170 - `UTF8String.fromString(new String(bytes))` https://github.com/apache/spark/blob/39112e4f2f8c1401ffa73c84398d3b8f0afa211a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L179-L184 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually check. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48493 from panbingkun/SPARK-49990. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../expressions/ExpressionImplUtils.java | 19 ++++++++++ .../expressions/randomExpressions.scala | 35 ++----------------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java index 18646f67975c0..2fad36efe8cc1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.util.VersionUtils; +import org.apache.spark.util.random.XORShiftRandom; /** * A utility class for constructing expressions. @@ -315,4 +316,22 @@ public static ArrayData getSentences( } return new GenericArrayData(res.toArray(new GenericArrayData[0])); } + + public static UTF8String randStr(XORShiftRandom rng, int length) { + byte[] bytes = new byte[length]; + for (int i = 0; i < bytes.length; i++) { + // We generate a random number between 0 and 61, inclusive. Between the 62 different choices + // we choose 0-9, a-z, or A-Z, where each category comprises 10 choices, 26 choices, or 26 + // choices, respectively (10 + 26 + 26 = 62). 
+ int v = Math.abs(rng.nextInt() % 62); + if (v < 10) { + bytes[i] = (byte)('0' + v); + } else if (v < 36) { + bytes[i] = (byte)('a' + (v - 10)); + } else { + bytes[i] = (byte)('A' + (v - 36)); + } + } + return UTF8String.fromBytes(bytes); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 706dc675d7f24..7148d3738f7fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.XORShiftRandom /** @@ -399,23 +398,7 @@ case class RandStr( override def evalInternal(input: InternalRow): Any = { val numChars = length.eval(input).asInstanceOf[Number].intValue() - val bytes = new Array[Byte](numChars) - (0 until numChars).foreach { i => - // We generate a random number between 0 and 61, inclusive. Between the 62 different choices - // we choose 0-9, a-z, or A-Z, where each category comprises 10 choices, 26 choices, or 26 - // choices, respectively (10 + 26 + 26 = 62). - val num = (rng.nextInt() % 62).abs - num match { - case _ if num < 10 => - bytes.update(i, ('0' + num).toByte) - case _ if num < 36 => - bytes.update(i, ('a' + num - 10).toByte) - case _ => - bytes.update(i, ('A' + num - 36).toByte) - } - } - val result: UTF8String = UTF8String.fromBytes(bytes.toArray) - result + ExpressionImplUtils.randStr(rng, numChars) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -427,19 +410,8 @@ case class RandStr( ev.copy(code = code""" |${eval.code} - |int length = (int)(${eval.value}); - |char[] chars = new char[length]; - |for (int i = 0; i < length; i++) { - | int v = Math.abs($rngTerm.nextInt() % 62); - | if (v < 10) { - | chars[i] = (char)('0' + v); - | } else if (v < 36) { - | chars[i] = (char)('a' + (v - 10)); - | } else { - | chars[i] = (char)('A' + (v - 36)); - | } - |} - |UTF8String ${ev.value} = UTF8String.fromString(new String(chars)); + |UTF8String ${ev.value} = + | ${classOf[ExpressionImplUtils].getName}.randStr($rngTerm, (int)(${eval.value})); |boolean ${ev.isNull} = false; |""".stripMargin, isNull = FalseLiteral) @@ -452,4 +424,3 @@ object RandStr { def apply(length: Expression, seedExpression: Expression): RandStr = RandStr(length, seedExpression, hideSeed = false) } - From 4c6367255d751df0e457541bb4db45c32f92a0fa Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Fri, 18 Oct 2024 15:04:11 +0200 Subject: [PATCH 051/108] [SPARK-50004][SQL] Integrate `_LEGACY_ERROR_TEMP_3327` into `FIELD_NOT_FOUND` ### What changes were proposed in this pull request? This PR proposes to integrate `_LEGACY_ERROR_TEMP_3327` into `FIELD_NOT_FOUND` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? 
No Closes #48514 from itholic/SPARK-50004. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 5 -- .../sql/connector/catalog/CatalogV2Util.scala | 7 +- .../sql/connector/catalog/CatalogSuite.scala | 24 +++--- .../v2/V2SessionCatalogSuite.scala | 82 +++++++++---------- 4 files changed, 58 insertions(+), 60 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 992e3f8e9cc74..fb1439cfe1a5e 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -8604,11 +8604,6 @@ "Cannot delete map key" ] }, - "_LEGACY_ERROR_TEMP_3227" : { - "message" : [ - "Cannot find field: " - ] - }, "_LEGACY_ERROR_TEMP_3228" : { "message" : [ "AFTER column not found: " diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 9b7f68070a1a4..e1f114a6170a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -336,8 +336,11 @@ private[sql] object CatalogV2Util { return struct } else { throw new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_3227", - messageParameters = Map("fieldName" -> fieldNames.head)) + errorClass = "FIELD_NOT_FOUND", + messageParameters = Map( + "fieldName" -> toSQLId(fieldNames.head), + "fields" -> struct.fields.map(f => toSQLId(f.name)).mkString(", ")) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index aca6931a0688d..51ea945984b50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -381,8 +381,8 @@ class CatalogSuite extends SparkFunSuite { catalog.alterTable(testIdent, TableChange.addColumn(Array("missing_col", "new_field"), StringType)) }, - condition = "_LEGACY_ERROR_TEMP_3227", - parameters = Map("fieldName" -> "missing_col")) + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) } test("alterTable: update column data type") { @@ -427,8 +427,8 @@ class CatalogSuite extends SparkFunSuite { catalog.alterTable(testIdent, TableChange.updateColumnType(Array("missing_col"), LongType)) }, - condition = "_LEGACY_ERROR_TEMP_3227", - parameters = Map("fieldName" -> "missing_col")) + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) } test("alterTable: add comment") { @@ -478,8 +478,8 @@ class CatalogSuite extends SparkFunSuite { catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("missing_col"), "comment")) }, - condition = "_LEGACY_ERROR_TEMP_3227", - parameters = Map("fieldName" -> "missing_col")) + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) } test("alterTable: rename top-level column") { @@ -546,8 +546,8 @@ class CatalogSuite extends SparkFunSuite { catalog.alterTable(testIdent, TableChange.renameColumn(Array("missing_col"), "new_name")) }, - condition = "_LEGACY_ERROR_TEMP_3227", - parameters = 
Map("fieldName" -> "missing_col")) + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) } test("alterTable: multiple changes") { @@ -614,8 +614,8 @@ class CatalogSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), false)) }, - condition = "_LEGACY_ERROR_TEMP_3227", - parameters = Map("fieldName" -> "missing_col")) + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) // with if exists it should pass catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), true)) @@ -636,8 +636,8 @@ class CatalogSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), false)) }, - condition = "_LEGACY_ERROR_TEMP_3227", - parameters = Map("fieldName" -> "z")) + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`z`", "fields" -> "`x`, `y`")) // with if exists it should pass catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index 8091d6e64fdc1..851dceeb8ac88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -26,7 +26,7 @@ import scala.jdk.CollectionConverters._ import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -443,13 +443,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(table.columns === columns) - val exc = intercept[IllegalArgumentException] { - catalog.alterTable(testIdent, - TableChange.addColumn(Array("missing_col", "new_field"), StringType)) - } - - assert(exc.getMessage.contains("missing_col")) - assert(exc.getMessage.contains("Cannot find")) + checkError( + exception = intercept[SparkIllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.addColumn(Array("missing_col", "new_field"), StringType)) + }, + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) } test("alterTable: update column data type") { @@ -498,13 +498,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(table.columns === columns) - val exc = intercept[IllegalArgumentException] { - catalog.alterTable(testIdent, - TableChange.updateColumnType(Array("missing_col"), LongType)) - } - - assert(exc.getMessage.contains("missing_col")) - assert(exc.getMessage.contains("Cannot find")) + checkError( + exception = intercept[SparkIllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.updateColumnType(Array("missing_col"), LongType)) + }, + condition = "FIELD_NOT_FOUND", + parameters = 
Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) } test("alterTable: add comment") { @@ -554,13 +554,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(table.columns === columns) - val exc = intercept[IllegalArgumentException] { - catalog.alterTable(testIdent, - TableChange.updateColumnComment(Array("missing_col"), "comment")) - } - - assert(exc.getMessage.contains("missing_col")) - assert(exc.getMessage.contains("Cannot find")) + checkError( + exception = intercept[SparkIllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("missing_col"), "comment")) + }, + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) } test("alterTable: rename top-level column") { @@ -628,13 +628,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(table.columns === columns) - val exc = intercept[IllegalArgumentException] { - catalog.alterTable(testIdent, - TableChange.renameColumn(Array("missing_col"), "new_name")) - } - - assert(exc.getMessage.contains("missing_col")) - assert(exc.getMessage.contains("Cannot find")) + checkError( + exception = intercept[SparkIllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.renameColumn(Array("missing_col"), "new_name")) + }, + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) } test("alterTable: multiple changes") { @@ -702,12 +702,12 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(table.columns === columns) - val exc = intercept[IllegalArgumentException] { - catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), false)) - } - - assert(exc.getMessage.contains("missing_col")) - assert(exc.getMessage.contains("Cannot find")) + checkError( + exception = intercept[SparkIllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), false)) + }, + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`missing_col`", "fields" -> "`id`, `data`")) // with if exists it should pass catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), true)) @@ -725,12 +725,12 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(table.columns === tableColumns) - val exc = intercept[IllegalArgumentException] { - catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), false)) - } - - assert(exc.getMessage.contains("z")) - assert(exc.getMessage.contains("Cannot find")) + checkError( + exception = intercept[SparkIllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), false)) + }, + condition = "FIELD_NOT_FOUND", + parameters = Map("fieldName" -> "`z`", "fields" -> "`x`, `y`")) // with if exists it should pass catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), true)) From ff47dd9516e83f23d8ab3731286ef18ddfdaba62 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 Oct 2024 15:09:38 +0200 Subject: [PATCH 052/108] [SPARK-50021][CORE][UI] Fix `ApplicationPage` to hide App UI links when UI is disabled ### What changes were proposed in this pull request? This PR aims to fix `ApplicationPage` to hide UI link when UI is disabled. ### Why are the changes needed? Previously, Spark throws `HTTP ERROR 500 java.lang.IllegalArgumentException: Invalid URI host: null (authority: null)` like the following **1. 
PREPARATION** ``` $ cat conf/spark-defaults.conf spark.ui.reverseProxy true spark.ui.reverseProxyUrl http://localhost:8080 $ sbin/start-master.sh $ sbin/start-worker.sh spark://$(hostname):7077 $ bin/spark-shell --master spark://$(hostname):7077 -c spark.ui.enabled=false ``` **2. BEFORE** Screenshot 2024-10-17 at 21 24 32 Screenshot 2024-10-17 at 21 24 51 **3. AFTER** Screenshot 2024-10-17 at 21 22 26 ### Does this PR introduce _any_ user-facing change? Yes, but previously it was a broken link. ### How was this patch tested? Pass the CIs with the newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48534 from dongjoon-hyun/SPARK-50021. Authored-by: Dongjoon Hyun Signed-off-by: Max Gekk --- .../spark/deploy/master/ui/ApplicationPage.scala | 12 ++++++++---- .../deploy/master/ui/ApplicationPageSuite.scala | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 98cc99c1a24b2..1a46688022341 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -94,10 +94,14 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
             <li><strong>State:</strong> {app.state}</li>
             {
               if (!app.isFinished) {
-                <li><strong>
-                  <a href={UIUtils.makeHref(parent.master.reverseProxy,
-                    app.id, app.desc.appUiUrl)}>Application Detail UI</a>
-                </strong></li>
+                if (app.desc.appUiUrl.isBlank()) {
+                  <li><strong>Application UI:</strong> Disabled</li>
+                } else {
+                  <li><strong>
+                    <a href={UIUtils.makeHref(parent.master.reverseProxy,
+                      app.id, app.desc.appUiUrl)}>Application Detail UI</a>
+                  </strong></li>
+                }
               } else if (parent.master.historyServerUrl.nonEmpty) {
  • diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala index a9b96f85808d0..ccfc4ee1600a5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala @@ -36,13 +36,16 @@ class ApplicationPageSuite extends SparkFunSuite { private val rp = new ResourceProfile(Map.empty, Map.empty) private val desc = ApplicationDescription("name", Some(4), null, "appUiUrl", rp) + private val descWithoutUI = ApplicationDescription("name", Some(4), null, "", rp) private val appFinished = new ApplicationInfo(0, "app-finished", desc, new Date, null, 1) appFinished.markFinished(ApplicationState.FINISHED) private val appLive = new ApplicationInfo(0, "app-live", desc, new Date, null, 1) + private val appLiveWithoutUI = + new ApplicationInfo(0, "app-live-without-ui", descWithoutUI, new Date, null, 1) private val state = mock(classOf[MasterStateResponse]) when(state.completedApps).thenReturn(Array(appFinished)) - when(state.activeApps).thenReturn(Array(appLive)) + when(state.activeApps).thenReturn(Array(appLive, appLiveWithoutUI)) private val rpc = mock(classOf[RpcEndpointRef]) when(rpc.askSync[MasterStateResponse](RequestMasterState)).thenReturn(state) @@ -61,6 +64,16 @@ class ApplicationPageSuite extends SparkFunSuite { assert(!result.contains(master.historyServerUrl.get)) } + test("SPARK-50021: Application Detail UI is empty when spark.ui.enabled=false") { + val request = mock(classOf[HttpServletRequest]) + when(request.getParameter("appId")).thenReturn("app-live-without-ui") + + val result = new ApplicationPage(masterWebUI).render(request).toString() + assert(result.contains("Application UI: Disabled")) + assert(!result.contains("Application History UI")) + assert(!result.contains(master.historyServerUrl.get)) + } + test("SPARK-45774: Application History UI") { val request = mock(classOf[HttpServletRequest]) when(request.getParameter("appId")).thenReturn("app-finished") From a1fc7e66883ac61abfd4be15f50455d03cac040c Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 18 Oct 2024 15:13:05 +0200 Subject: [PATCH 053/108] [SPARK-49966][SQL] Use `Invoke` to implement `JsonToStructs`(`from_json`) ### What changes were proposed in this pull request? The pr aims to use `Invoke` to implement `JsonToStructs`(`from_json`). ### Why are the changes needed? Based on cloud-fan's suggestion, I believe that implementing `JsonToStructs`(`from_json`) with `Invoke` can greatly simplify the code. https://github.com/apache/spark/pull/48466#discussion_r1802533505 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Update existed UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48509 from panbingkun/SPARK-49966_FOLLOWUP. 
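At its core, the simplification replaces the hand-written `nullSafeEval`/`doGenCode` of `JsonToStructs` with a `replacement` expression, so the interpreted and codegen paths both go through the same evaluator object. A sketch of the pattern, taken from the diff below:

```scala
// RuntimeReplaceable pattern: the evaluator instance is embedded as an object literal
// and called through Invoke, so no dedicated eval/codegen code is needed anymore.
override def replacement: Expression = Invoke(
  Literal.create(evaluator, ObjectType(classOf[JsonToStructsEvaluator])),
  "evaluate",
  dataType,
  Seq(child),
  Seq(child.dataType))
```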
Authored-by: panbingkun Signed-off-by: Max Gekk --- .../json/JsonExpressionEvalUtils.scala | 2 +- .../expressions/jsonExpressions.scala | 51 ++++++++----------- .../expressions/JsonExpressionsSuite.scala | 5 +- .../optimizer/OptimizeJsonExprsSuite.scala | 2 +- .../function_from_json.explain | 2 +- .../function_from_json_orphaned.explain | 2 +- ...unction_from_json_with_json_schema.explain | 2 +- .../BaseScriptTransformationExec.scala | 2 +- 8 files changed, 29 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala index 6291e62304a38..efa5c930b73da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala @@ -57,7 +57,7 @@ object JsonExpressionEvalUtils { } } -class JsonToStructsEvaluator( +case class JsonToStructsEvaluator( options: Map[String, String], nullableSchema: DataType, nameOfCorruptRecord: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index a553336015b88..d884e76f5256d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator} -import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf @@ -637,9 +637,9 @@ case class JsonToStructs( timeZoneId: Option[String] = None, variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)) extends UnaryExpression - with TimeZoneAwareExpression + with RuntimeReplaceable with ExpectsInputTypes - with NullIntolerant + with TimeZoneAwareExpression with QueryErrorsBase { // The JSON input data might be missing certain fields. 
We force the nullability @@ -649,7 +649,7 @@ case class JsonToStructs( override def nullable: Boolean = true - final override def nodePatternsInternal(): Seq[TreePattern] = Seq(JSON_TO_STRUCT) + override def nodePatternsInternal(): Seq[TreePattern] = Seq(JSON_TO_STRUCT, RUNTIME_REPLACEABLE) // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = @@ -683,32 +683,6 @@ case class JsonToStructs( override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - @transient - private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) - - @transient - private lazy val evaluator = new JsonToStructsEvaluator( - options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys) - - override def nullSafeEval(json: Any): Any = evaluator.evaluate(json.asInstanceOf[UTF8String]) - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val refEvaluator = ctx.addReferenceObj("evaluator", evaluator) - val eval = child.genCode(ctx) - val resultType = CodeGenerator.boxedType(dataType) - val resultTerm = ctx.freshName("result") - ev.copy(code = - code""" - |${eval.code} - |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value}); - |boolean ${ev.isNull} = $resultTerm == null; - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${ev.isNull}) { - | ${ev.value} = $resultTerm; - |} - |""".stripMargin) - } - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil override def sql: String = schema match { @@ -718,6 +692,21 @@ case class JsonToStructs( override def prettyName: String = "from_json" + @transient + private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) + + @transient + lazy val evaluator: JsonToStructsEvaluator = JsonToStructsEvaluator( + options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys) + + override def replacement: Expression = Invoke( + Literal.create(evaluator, ObjectType(classOf[JsonToStructsEvaluator])), + "evaluate", + dataType, + Seq(child), + Seq(child.dataType) + ) + override protected def withNewChildInternal(newChild: Expression): JsonToStructs = copy(child = newChild) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 0afaf4ec097c8..edb7b93ecdf68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -420,7 +420,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json escaping") { val schema = StructType(StructField("\"quote", IntegerType) :: Nil) GenerateUnsafeProjection.generate( - JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT) :: Nil) + JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT).replacement :: Nil) } test("from_json") { @@ -729,7 +729,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from/to json - interval support") { val schema = StructType(StructField("i", CalendarIntervalType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType)), + 
JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType), + UTC_OPT), InternalRow(new CalendarInterval(12, 1, 0))) Seq(MapType(CalendarIntervalType, IntegerType), MapType(IntegerType, CalendarIntervalType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala index eed06da609f8e..7af2be2db01d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala @@ -292,7 +292,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""", null).foreach(v => { val row = create_row(v) - checkEvaluation(e1, e2.eval(row), row) + checkEvaluation(e1, replace(e2).eval(row), row) }) } diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain index 8d1d122d156ff..9bc33b3b97d2c 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain index 8d1d122d156ff..9bc33b3b97d2c 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain index 8d1d122d156ff..9bc33b3b97d2c 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS 
from_json(g)#0] +Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index af3a8d67e3c29..2a1554d287a8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -239,7 +239,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { val complexTypeFactory = JsonToStructs(attr.dataType, ioschema.outputSerdeProps.toMap, Literal(null), Some(conf.sessionLocalTimeZone)) wrapperConvertException(data => - complexTypeFactory.nullSafeEval(UTF8String.fromString(data)), any => any) + complexTypeFactory.evaluator.evaluate(UTF8String.fromString(data)), any => any) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) case dt => From ae87ce6c7f9e3e765c171043fccb676ac9943a4d Mon Sep 17 00:00:00 2001 From: "zhipeng.mao" Date: Fri, 18 Oct 2024 15:25:36 +0200 Subject: [PATCH 054/108] [SPARK-50026][SQL] Fix Identity Column metadata get ### What changes were proposed in this pull request? Change identity column get metadata from getString.toLong/toBoolean to getLong/getBoolean. ### Why are the changes needed? The metadata is stored as Long and Boolean rather than String. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48540 from zhipengmao-db/zhipengmao-db/fix-id-column. Authored-by: zhipeng.mao Signed-off-by: Max Gekk --- .../org/apache/spark/sql/catalyst/util/IdentityColumn.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IdentityColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IdentityColumn.scala index 26a3cb026d317..07ab8731de891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IdentityColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IdentityColumn.scala @@ -61,9 +61,9 @@ object IdentityColumn { def getIdentityInfo(field: StructField): Option[IdentityColumnSpec] = { if (isIdentityColumn(field)) { Some(new IdentityColumnSpec( - field.metadata.getString(IDENTITY_INFO_START).toLong, - field.metadata.getString(IDENTITY_INFO_STEP).toLong, - field.metadata.getString(IDENTITY_INFO_ALLOW_EXPLICIT_INSERT).toBoolean)) + field.metadata.getLong(IDENTITY_INFO_START), + field.metadata.getLong(IDENTITY_INFO_STEP), + field.metadata.getBoolean(IDENTITY_INFO_ALLOW_EXPLICIT_INSERT))) } else { None } From 9d0e31dda53dced0a61601d652d750706a18b11e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 Oct 2024 08:48:13 -0700 Subject: [PATCH 055/108] [SPARK-50022][CORE][UI] Fix `MasterPage` to hide App UI links when UI is disabled ### What changes were proposed in this pull request? This PR aims to fix `MasterPage` to hide App UI links when UI is disabled. Previously, the link leads the following errors if a user clicks it. Screenshot 2024-10-17 at 22 06 22 ### Why are the changes needed? 
**1. PREPARATION** ``` $ cat conf/spark-defaults.conf spark.ui.reverseProxy true spark.ui.reverseProxyUrl http://localhost:8080 $ sbin/start-master.sh $ sbin/start-worker.sh spark://$(hostname):7077 $ bin/spark-shell --master spark://$(hostname):7077 -c spark.ui.enabled=false ``` **2. BEFORE** Screenshot 2024-10-17 at 22 01 16 **3. AFTER** Screenshot 2024-10-17 at 22 04 12 ### Does this PR introduce _any_ user-facing change? The previous behavior shows HTTP 500 error. ### How was this patch tested? Pass the CIs with a newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48535 from dongjoon-hyun/SPARK-50022. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/deploy/master/ui/MasterPage.scala | 2 +- .../master/ui/ReadOnlyMasterWebUISuite.scala | 22 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 1d15088b5c546..a396444ebe9c5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -316,7 +316,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { { - if (app.isFinished) { + if (app.isFinished || app.desc.appUiUrl.isBlank()) { app.desc.name } else { SPARK_SCALA_VERSION2.1")) } + + test("SPARK-50022: Fix 'MasterPage' to hide App UI links when UI is disabled") { + val url = s"http://${Utils.localHostNameForURI()}:${masterWebUI.boundPort}/" + val conn = sendHttpRequest(url, "GET") + assert(conn.getResponseCode === SC_OK) + val result = Source.fromInputStream(conn.getInputStream).mkString + assert(result.contains("WithUI")) + assert(result.contains(" WithoutUI\n")) + } } From aaecab3d3c8b116e4aa32b2b26ad6a1b32f2a80a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 Oct 2024 09:59:33 -0700 Subject: [PATCH 056/108] Revert "[SPARK-50011][INFRA][FOLLOW-UP] Refresh the image cache job" This reverts commit 2724909656fe5cebe1850f0fd81e32a998eed07e. 
--- .github/workflows/build_infra_images_cache.yml | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/.github/workflows/build_infra_images_cache.yml b/.github/workflows/build_infra_images_cache.yml index 18e1e43f36c75..18419334836b2 100644 --- a/.github/workflows/build_infra_images_cache.yml +++ b/.github/workflows/build_infra_images_cache.yml @@ -26,8 +26,7 @@ on: - 'master' - 'branch-*' paths: - - 'dev/infra/base/Dockerfile' - - 'dev/infra/docs/Dockerfile' + - 'dev/infra/Dockerfile' - '.github/workflows/build_infra_images_cache.yml' # Create infra image when cutting down branches/tags create: @@ -54,21 +53,10 @@ jobs: id: docker_build uses: docker/build-push-action@v6 with: - context: ./dev/infra/base/ + context: ./dev/infra/ push: true tags: ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }}-static cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }} cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }},mode=max - name: Image digest run: echo ${{ steps.docker_build.outputs.digest }} - - name: Build and push (Documentation) - id: docker_build_docs - uses: docker/build-push-action@v6 - with: - context: ./dev/infra/docs/ - push: true - tags: ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }}-static - cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }} - cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }},mode=max - - name: Image digest (Documentation) - run: echo ${{ steps.docker_build_docs.outputs.digest }} From 6f710cd9c561c911066108b7d659d5de44099757 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 Oct 2024 10:08:07 -0700 Subject: [PATCH 057/108] Revert "[SPARK-50011][INFRA] Add a separate docker file for doc build" This reverts commit 360ce0a82bc9675982ef77dd24310a4432e74b62. --- .github/workflows/build_and_test.yml | 38 +++++------- dev/infra/{base => }/Dockerfile | 0 dev/infra/docs/Dockerfile | 91 ---------------------------- 3 files changed, 14 insertions(+), 115 deletions(-) rename dev/infra/{base => }/Dockerfile (100%) delete mode 100644 dev/infra/docs/Dockerfile diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index eadcaaedc5829..14d93a498fc59 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -58,7 +58,6 @@ jobs: outputs: required: ${{ steps.set-outputs.outputs.required }} image_url: ${{ steps.infra-image-outputs.outputs.image_url }} - image_docs_url: ${{ steps.infra-image-docs-outputs.outputs.image_docs_url }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -135,14 +134,6 @@ jobs: IMG_NAME="apache-spark-ci-image:${{ inputs.branch }}-${{ github.run_id }}" IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" echo "image_url=$IMG_URL" >> $GITHUB_OUTPUT - - name: Generate infra image URL (Documentation) - id: infra-image-docs-outputs - run: | - # Convert to lowercase to meet Docker repo name requirement - REPO_OWNER=$(echo "${{ github.repository_owner }}" | tr '[:upper:]' '[:lower:]') - IMG_NAME="apache-spark-ci-image-docs:${{ inputs.branch }}-${{ github.run_id }}" - IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" - echo "image_docs_url=$IMG_URL" >> $GITHUB_OUTPUT # Build: build Spark and run the tests for specified modules. 
build: @@ -354,23 +345,12 @@ jobs: id: docker_build uses: docker/build-push-action@v6 with: - context: ./dev/infra/base/ + context: ./dev/infra/ push: true tags: | ${{ needs.precondition.outputs.image_url }} # Use the infra image cache to speed up cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ inputs.branch }} - - name: Build and push (Documentation) - id: docker_build_docs - uses: docker/build-push-action@v6 - with: - context: ./dev/infra/docs/ - push: true - tags: | - ${{ needs.precondition.outputs.image_docs_url }} - # Use the infra image cache to speed up - cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ inputs.branch }} - pyspark: needs: [precondition, infra-image] @@ -803,7 +783,7 @@ jobs: PYSPARK_PYTHON: python3.9 GITHUB_PREV_SHA: ${{ github.event.before }} container: - image: ${{ needs.precondition.outputs.image_docs_url }} + image: ${{ needs.precondition.outputs.image_url }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -853,8 +833,18 @@ jobs: with: distribution: zulu java-version: ${{ inputs.java }} - - name: List Python packages - run: python3.9 -m pip list + - name: Install Python dependencies for python linter and documentation generation + if: inputs.branch != 'branch-3.4' && inputs.branch != 'branch-3.5' + run: | + # Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 + # See 'ipython_genutils' in SPARK-38517 + # See 'docutils<0.18.0' in SPARK-39421 + python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ + 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ + 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ + 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' + python3.9 -m pip list - name: Install dependencies for documentation generation for branch-3.4, branch-3.5 if: inputs.branch == 'branch-3.4' || inputs.branch == 'branch-3.5' run: | diff --git a/dev/infra/base/Dockerfile b/dev/infra/Dockerfile similarity index 100% rename from dev/infra/base/Dockerfile rename to dev/infra/Dockerfile diff --git a/dev/infra/docs/Dockerfile b/dev/infra/docs/Dockerfile deleted file mode 100644 index 8a8e1680182c5..0000000000000 --- a/dev/infra/docs/Dockerfile +++ /dev/null @@ -1,91 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Image for building and testing Spark branches. Based on Ubuntu 22.04. 
-# See also in https://hub.docker.com/_/ubuntu -FROM ubuntu:jammy-20240227 -LABEL org.opencontainers.image.authors="Apache Spark project " -LABEL org.opencontainers.image.licenses="Apache-2.0" -LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image for Documentation" -# Overwrite this label to avoid exposing the underlying Ubuntu OS version label -LABEL org.opencontainers.image.version="" - -ENV FULL_REFRESH_DATE 20241016 - -ENV DEBIAN_FRONTEND noninteractive -ENV DEBCONF_NONINTERACTIVE_SEEN true - -RUN apt-get update && apt-get install -y \ - build-essential \ - ca-certificates \ - curl \ - gfortran \ - git \ - gnupg \ - libcurl4-openssl-dev \ - libfontconfig1-dev \ - libfreetype6-dev \ - libfribidi-dev \ - libgit2-dev \ - libharfbuzz-dev \ - libjpeg-dev \ - liblapack-dev \ - libopenblas-dev \ - libpng-dev \ - libpython3-dev \ - libssl-dev \ - libtiff5-dev \ - libxml2-dev \ - nodejs \ - npm \ - openjdk-17-jdk-headless \ - pandoc \ - pkg-config \ - qpdf \ - r-base \ - ruby \ - ruby-dev \ - software-properties-common \ - wget \ - zlib1g-dev \ - && rm -rf /var/lib/apt/lists/* - - -# See more in SPARK-39959, roxygen2 < 7.2.1 -RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', 'rmarkdown', 'testthat'), repos='https://cloud.r-project.org/')" && \ - Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='https://cloud.r-project.org')" && \ - Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" && \ - Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" - -# See more in SPARK-39735 -ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library" - -# Install Python 3.9 -RUN add-apt-repository ppa:deadsnakes/ppa -RUN apt-get update && apt-get install -y python3.9 python3.9-distutils \ - && rm -rf /var/lib/apt/lists/* -RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9 - -# Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 -# See 'ipython_genutils' in SPARK-38517 -# See 'docutils<0.18.0' in SPARK-39421 -RUN python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ - 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ - 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ - 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' \ - && python3.9 -m pip cache purge From b1d1f10f96b1a92037a0205854745efeac5717ea Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Sat, 19 Oct 2024 06:22:02 +0900 Subject: [PATCH 058/108] [SPARK-49846][SS] Add numUpdatedStateRows and numRemovedStateRows metrics for use with transformWithState operator ### What changes were proposed in this pull request? Add numUpdatedStateRows and numRemovedStateRows metrics for use with transformWithState operator ### Why are the changes needed? Without this change, metrics around these operations are not available in the query progress metrics ### Does this PR introduce _any_ user-facing change? 
No Metrics updated as part of the streaming query progress ``` "operatorName" : "transformWithStateExec", "numRowsTotal" : 1, "numRowsUpdated" : 1, "numRowsRemoved" : 1, ``` ### How was this patch tested? Added unit tests ``` [info] Run completed in 25 seconds, 697 milliseconds. [info] Total number of tests run: 2 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48317 from anishshri-db/task/SPARK-49846. Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../execution/streaming/ListStateImpl.scala | 37 ++++++-- .../streaming/ListStateImplWithTTL.scala | 38 +++++++- .../streaming/ListStateMetricsImpl.scala | 86 +++++++++++++++++++ .../execution/streaming/MapStateImpl.scala | 19 +++- .../streaming/MapStateImplWithTTL.scala | 10 ++- .../StatefulProcessorHandleImpl.scala | 54 ++++++++---- .../execution/streaming/ValueStateImpl.scala | 7 +- .../streaming/ValueStateImplWithTTL.scala | 8 +- .../TransformWithListStateSuite.scala | 2 + .../TransformWithListStateTTLSuite.scala | 9 +- .../TransformWithMapStateSuite.scala | 4 + .../TransformWithMapStateTTLSuite.scala | 11 ++- .../streaming/TransformWithStateSuite.scala | 5 ++ .../TransformWithValueStateTTLSuite.scala | 14 ++- 14 files changed, 266 insertions(+), 38 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 497472ce63676..77c481a8ba0ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState +import org.apache.spark.sql.types.StructType /** * Provides concrete implementation for list of values associated with a state variable @@ -30,14 +32,22 @@ import org.apache.spark.sql.streaming.ListState * @param stateName - name of logical state partition * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value + * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored in the list */ class ListStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) - extends ListState[S] with Logging { + valEncoder: Encoder[S], + metrics: Map[String, SQLMetric] = Map.empty) + extends ListStateMetricsImpl + with ListState[S] + with Logging { + + override def stateStore: StateStore = store + override def baseStateName: String = stateName + override def exprEncSchema: StructType = keyExprEnc.schema private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName) @@ -76,6 +86,8 @@ class ListStateImpl[S]( val encodedKey = stateTypesEncoder.encodeGroupingKey() var isFirst = true + var entryCount = 0L + 
TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows") newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v) @@ -83,16 +95,23 @@ class ListStateImpl[S]( store.put(encodedKey, encodedValue, stateName) isFirst = false } else { - store.merge(encodedKey, encodedValue, stateName) + store.merge(encodedKey, encodedValue, stateName) } + entryCount += 1 + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } + updateEntryCount(encodedKey, entryCount) } /** Append an entry to the list. */ override def appendValue(newState: S): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) - store.merge(stateTypesEncoder.encodeGroupingKey(), + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val entryCount = getEntryCount(encodedKey) + store.merge(encodedKey, stateTypesEncoder.encodeValue(newState), stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") + updateEntryCount(encodedKey, entryCount + 1) } /** Append an entire list to the existing value. */ @@ -100,15 +119,23 @@ class ListStateImpl[S]( validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() + var entryCount = getEntryCount(encodedKey) newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v) store.merge(encodedKey, encodedValue, stateName) + entryCount += 1 + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } + updateEntryCount(encodedKey, entryCount) } /** Remove this state. */ override def clear(): Unit = { - store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + val encodedKey = stateTypesEncoder.encodeGroupingKey() + store.remove(encodedKey, stateName) + val entryCount = getEntryCount(encodedKey) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount) + removeEntryCount(encodedKey) } private def validateNewState(newState: Array[S]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index faeec7cb93863..be47f566bc6a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{ListState, TTLConfig} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.NextIterator /** @@ -34,6 +36,7 @@ import org.apache.spark.util.NextIterator * @param valEncoder - Spark SQL encoder for value * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. 
+ * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ class ListStateImplWithTTL[S]( @@ -42,9 +45,15 @@ class ListStateImplWithTTL[S]( keyExprEnc: ExpressionEncoder[Any], valEncoder: Encoder[S], ttlConfig: TTLConfig, - batchTimestampMs: Long) - extends SingleKeyTTLStateImpl( - stateName, store, keyExprEnc, batchTimestampMs) with ListState[S] { + batchTimestampMs: Long, + metrics: Map[String, SQLMetric] = Map.empty) + extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs) + with ListStateMetricsImpl + with ListState[S] { + + override def stateStore: StateStore = store + override def baseStateName: String = stateName + override def exprEncSchema: StructType = keyExprEnc.schema private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) @@ -99,6 +108,8 @@ class ListStateImplWithTTL[S]( val encodedKey = stateTypesEncoder.encodeGroupingKey() var isFirst = true + var entryCount = 0L + TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows") newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs) @@ -108,17 +119,23 @@ class ListStateImplWithTTL[S]( } else { store.merge(encodedKey, encodedValue, stateName) } + entryCount += 1 + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } upsertTTLForStateKey(encodedKey) + updateEntryCount(encodedKey, entryCount) } /** Append an entry to the list. */ override def appendValue(newState: S): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) val encodedKey = stateTypesEncoder.encodeGroupingKey() + val entryCount = getEntryCount(encodedKey) store.merge(encodedKey, stateTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") upsertTTLForStateKey(encodedKey) + updateEntryCount(encodedKey, entryCount + 1) } /** Append an entire list to the existing value. */ @@ -126,16 +143,24 @@ class ListStateImplWithTTL[S]( validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() + var entryCount = getEntryCount(encodedKey) newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs) store.merge(encodedKey, encodedValue, stateName) + entryCount += 1 + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } upsertTTLForStateKey(encodedKey) + updateEntryCount(encodedKey, entryCount) } /** Remove this state. 
*/ override def clear(): Unit = { - store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + val encodedKey = stateTypesEncoder.encodeGroupingKey() + store.remove(encodedKey, stateName) + val entryCount = getEntryCount(encodedKey) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount) + removeEntryCount(encodedKey) clearTTLState() } @@ -158,7 +183,9 @@ class ListStateImplWithTTL[S]( val unsafeRowValuesIterator = store.valuesIterator(groupingKey, stateName) // We clear the list, and use the iterator to put back all of the non-expired values store.remove(groupingKey, stateName) + removeEntryCount(groupingKey) var isFirst = true + var entryCount = 0L unsafeRowValuesIterator.foreach { encodedValue => if (!stateTypesEncoder.isExpired(encodedValue, batchTimestampMs)) { if (isFirst) { @@ -167,10 +194,13 @@ class ListStateImplWithTTL[S]( } else { store.merge(groupingKey, encodedValue, stateName) } + entryCount += 1 } else { numValuesExpired += 1 } } + updateEntryCount(groupingKey, entryCount) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", numValuesExpired) numValuesExpired } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala new file mode 100644 index 0000000000000..ea43c3f4fcd3b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.types._ + +/** + * Trait that provides helper methods to maintain metrics for a list state. + * For list state, we keep track of the count of entries in the list in a separate column family + * to get an accurate view of the number of entries that are updated/removed from the list and + * reported as part of the query progress metrics. + */ +trait ListStateMetricsImpl { + def stateStore: StateStore + + def baseStateName: String + + def exprEncSchema: StructType + + // We keep track of the count of entries in the list in a separate column family + // to avoid scanning the entire list to get the count. 
+ private val counterCFValueSchema: StructType = + StructType(Seq(StructField("count", LongType, nullable = false))) + + private val counterCFProjection = UnsafeProjection.create(counterCFValueSchema) + + private val updatedCountRow = new GenericInternalRow(1) + + private def getRowCounterCFName(stateName: String) = "$rowCounter_" + stateName + + stateStore.createColFamilyIfAbsent(getRowCounterCFName(baseStateName), exprEncSchema, + counterCFValueSchema, NoPrefixKeyStateEncoderSpec(exprEncSchema), isInternal = true) + + /** + * Function to get the number of entries in the list state for a given grouping key + * @param encodedKey - encoded grouping key + * @return - number of entries in the list state + */ + def getEntryCount(encodedKey: UnsafeRow): Long = { + val countRow = stateStore.get(encodedKey, getRowCounterCFName(baseStateName)) + if (countRow != null) { + countRow.getLong(0) + } else { + 0L + } + } + + /** + * Function to update the number of entries in the list state for a given grouping key + * @param encodedKey - encoded grouping key + * @param updatedCount - updated count of entries in the list state + */ + def updateEntryCount( + encodedKey: UnsafeRow, + updatedCount: Long): Unit = { + updatedCountRow.setLong(0, updatedCount) + stateStore.put(encodedKey, + counterCFProjection(updatedCountRow.asInstanceOf[InternalRow]), + getRowCounterCFName(baseStateName)) + } + + /** + * Function to remove the number of entries in the list state for a given grouping key + * @param encodedKey - encoded grouping key + */ + def removeEntryCount(encodedKey: UnsafeRow): Unit = { + stateStore.remove(encodedKey, getRowCounterCFName(baseStateName)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index 2fa6fa415a77b..cb3db19496dd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -19,17 +19,30 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState import org.apache.spark.sql.types.StructType +/** + * Class that provides a concrete implementation for map state associated with state + * variables used in the streaming transformWithState operator. 
+ * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @param keyExprEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value + * @param metrics - metrics to be updated as part of stateful processing + * @tparam K - type of key for map state variable + * @tparam V - type of value for map state variable + */ class MapStateImpl[K, V]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], userKeyEnc: Encoder[K], - valEncoder: Encoder[V]) extends MapState[K, V] with Logging { + valEncoder: Encoder[V], + metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with Logging { // Pack grouping key and user key together as a prefixed composite key private val schemaForCompositeKeyRow: StructType = { @@ -70,6 +83,7 @@ class MapStateImpl[K, V]( val encodedValue = stateTypesEncoder.encodeValue(value) val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key) store.put(encodedCompositeKey, encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } /** Get the map associated with grouping key */ @@ -98,6 +112,9 @@ class MapStateImpl[K, V]( StateStoreErrors.requireNonNullStateValue(key, stateName) val compositeKey = stateTypesEncoder.encodeCompositeKey(key) store.remove(compositeKey, stateName) + // Note that for mapState, the rows are flattened. So we count the number of rows removed + // proportional to the number of keys in the map per grouping key. + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") } /** Remove this state. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index a6234636a94f7..6a3685ad6c46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -20,6 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{MapState, TTLConfig} @@ -35,6 +36,7 @@ import org.apache.spark.util.NextIterator * @param valEncoder - SQL encoder for state variable * @param ttlConfig - the ttl configuration (time to live duration etc.) * @param batchTimestampMs - current batch processing timestamp. 
+ * @param metrics - metrics to be updated as part of stateful processing * @tparam K - type of key for map state variable * @tparam V - type of value for map state variable * @return - instance of MapState of type [K,V] that can be used to store state persistently @@ -46,7 +48,8 @@ class MapStateImplWithTTL[K, V]( userKeyEnc: Encoder[K], valEncoder: Encoder[V], ttlConfig: TTLConfig, - batchTimestampMs: Long) + batchTimestampMs: Long, + metrics: Map[String, SQLMetric] = Map.empty) extends CompositeKeyTTLStateImpl[K](stateName, store, keyExprEnc, userKeyEnc, batchTimestampMs) with MapState[K, V] with Logging { @@ -106,6 +109,7 @@ class MapStateImplWithTTL[K, V]( val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs) val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key) store.put(encodedCompositeKey, encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") upsertTTLForStateKey(ttlExpirationMs, encodedGroupingKey, encodedUserKey) } @@ -149,6 +153,9 @@ class MapStateImplWithTTL[K, V]( StateStoreErrors.requireNonNullStateValue(key, stateName) val compositeKey = stateTypesEncoder.encodeCompositeKey(key) store.remove(compositeKey, stateName) + // Note that for mapState, the rows are flattened. So we count the number of rows removed + // proportional to the number of keys in the map per grouping key. + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") } /** Remove this state. */ @@ -184,6 +191,7 @@ class MapStateImplWithTTL[K, V]( if (stateTypesEncoder.isExpired(retRow, batchTimestampMs)) { store.remove(compositeKeyRow, stateName) numRemovedElements += 1 + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") } } numRemovedElements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 8beacbec7e6ef..762dfc7d08920 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -45,6 +45,24 @@ object ImplicitGroupingKeyTracker { def removeImplicitKey(): Unit = implicitKey.remove() } +/** + * Utility object to perform metrics updates + */ +object TWSMetricsUtils { + def resetMetric( + metrics: Map[String, SQLMetric], + metricName: String): Unit = { + metrics.get(metricName).foreach(_.reset()) + } + + def incrementMetric( + metrics: Map[String, SQLMetric], + metricName: String, + countValue: Long = 1L): Unit = { + metrics.get(metricName).foreach(_.add(countValue)) + } +} + /** * Enum used to track valid states for the StatefulProcessorHandle */ @@ -117,16 +135,12 @@ class StatefulProcessorHandleImpl( private lazy val currQueryInfo: QueryInfo = buildQueryInfo() - private def incrementMetric(metricName: String): Unit = { - metrics.get(metricName).foreach(_.add(1)) - } - override def getValueState[T]( stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", CREATED) - incrementMetric("numValueStateVars") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numValueStateVars") resultState } @@ -139,9 +153,10 @@ class StatefulProcessorHandleImpl( assert(batchTimestampMs.isDefined) val 
valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numValueStateWithTTLVars") + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, metrics) ttlStates.add(valueStateWithTTL) + TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars") + valueStateWithTTL } @@ -155,8 +170,8 @@ class StatefulProcessorHandleImpl( */ override def registerTimer(expiryTimestampMs: Long): Unit = { verifyTimerOperations("register_timer") - incrementMetric("numRegisteredTimers") timerState.registerTimer(expiryTimestampMs) + TWSMetricsUtils.incrementMetric(metrics, "numRegisteredTimers") } /** @@ -165,8 +180,8 @@ class StatefulProcessorHandleImpl( */ override def deleteTimer(expiryTimestampMs: Long): Unit = { verifyTimerOperations("delete_timer") - incrementMetric("numDeletedTimers") timerState.deleteTimer(expiryTimestampMs) + TWSMetricsUtils.incrementMetric(metrics, "numDeletedTimers") } /** @@ -211,14 +226,14 @@ class StatefulProcessorHandleImpl( override def deleteIfExists(stateName: String): Unit = { verifyStateVarOperations("delete_if_exists", CREATED) if (store.removeColFamilyIfExists(stateName)) { - incrementMetric("numDeletedStateVars") + TWSMetricsUtils.incrementMetric(metrics, "numDeletedStateVars") } } override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { verifyStateVarOperations("get_list_state", CREATED) - incrementMetric("numListStateVars") - val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numListStateVars") resultState } @@ -247,8 +262,8 @@ class StatefulProcessorHandleImpl( assert(batchTimestampMs.isDefined) val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numListStateWithTTLVars") + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars") ttlStates.add(listStateWithTTL) listStateWithTTL @@ -259,8 +274,9 @@ class StatefulProcessorHandleImpl( userKeyEnc: Encoder[K], valEncoder: Encoder[V]): MapState[K, V] = { verifyStateVarOperations("get_map_state", CREATED) - incrementMetric("numMapStateVars") - val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) + val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, + userKeyEnc, valEncoder, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numMapStateVars") resultState } @@ -274,8 +290,8 @@ class StatefulProcessorHandleImpl( assert(batchTimestampMs.isDefined) val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, - valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numMapStateWithTTLVars") + valEncoder, ttlConfig, batchTimestampMs.get, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) mapStateWithTTL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 63cac4a3b68cb..b1b87feeb263b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ 
-19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.ValueState @@ -29,13 +30,15 @@ import org.apache.spark.sql.streaming.ValueState * @param stateName - name of logical state partition * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value + * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) + valEncoder: Encoder[S], + metrics: Map[String, SQLMetric] = Map.empty) extends ValueState[S] with Logging { private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName) @@ -74,10 +77,12 @@ class ValueStateImpl[S]( val encodedValue = stateTypesEncoder.encodeValue(newState) store.put(stateTypesEncoder.encodeGroupingKey(), encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index c6d11b155866b..145cd90264910 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.{TTLConfig, ValueState} @@ -33,6 +34,7 @@ import org.apache.spark.sql.streaming.{TTLConfig, ValueState} * @param valEncoder - Spark SQL encoder for value * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. 
+ * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ class ValueStateImplWithTTL[S]( @@ -41,7 +43,8 @@ class ValueStateImplWithTTL[S]( keyExprEnc: ExpressionEncoder[Any], valEncoder: Encoder[S], ttlConfig: TTLConfig, - batchTimestampMs: Long) + batchTimestampMs: Long, + metrics: Map[String, SQLMetric] = Map.empty) extends SingleKeyTTLStateImpl( stateName, store, keyExprEnc, batchTimestampMs) with ValueState[S] { @@ -92,12 +95,14 @@ class ValueStateImplWithTTL[S]( val serializedGroupingKey = stateTypesEncoder.encodeGroupingKey() store.put(serializedGroupingKey, encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey) } /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") clearTTLState() } @@ -108,6 +113,7 @@ class ValueStateImplWithTTL[S]( if (retRow != null) { if (stateTypesEncoder.isExpired(retRow, batchTimestampMs)) { store.remove(groupingKey, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") result = 1L } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index dea16e5298975..71b8c8ac923d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -301,6 +301,8 @@ class TransformWithListStateSuite extends StreamTest CheckNewAnswer(("k5", "v5"), ("k5", "v6")), Execute { q => assert(q.lastProgress.stateOperators(0).customMetrics.get("numListStateVars") > 0) + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 2) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 2) } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala index 299a3346b2e51..d11d8ef9a9b36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala @@ -147,6 +147,9 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { ), AdvanceManualClock(1 * 1000), CheckNewAnswer(), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 3) + }, // get ttl values AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), AdvanceManualClock(1 * 1000), @@ -158,15 +161,17 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { OutputEvent("k1", 5, isTTLValue = true, 109000), OutputEvent("k1", 6, isTTLValue = true, 109000) ), + AddData(inputStream, InputEvent("k1", "get", -1, null)), // advance clock to expire the first three elements AdvanceManualClock(15 * 1000), // batch timestamp: 65000 - AddData(inputStream, InputEvent("k1", "get", -1, null)), - AdvanceManualClock(1 * 1000), CheckNewAnswer( OutputEvent("k1", 4, isTTLValue = false, -1), OutputEvent("k1", 5, isTTLValue = false, -1), OutputEvent("k1", 6, isTTLValue = false, -1) ), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 3) + }, // ensure that 
expired elements are no longer in state AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)), AdvanceManualClock(1 * 1000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index fe88fbaa91cb7..e4e6862f7f937 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -209,9 +209,13 @@ class TransformWithMapStateSuite extends StreamTest AddData(inputData, InputMapRow("k2", "iterator", ("", ""))), CheckNewAnswer(), AddData(inputData, InputMapRow("k2", "exists", ("", ""))), + AddData(inputData, InputMapRow("k1", "clear", ("", ""))), + AddData(inputData, InputMapRow("k3", "updateValue", ("v7", "11"))), CheckNewAnswer(("k2", "exists", "false")), Execute { q => assert(q.lastProgress.stateOperators(0).customMetrics.get("numMapStateVars") > 0) + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala index bf46c802fdea4..3794bcc9ea271 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala @@ -210,12 +210,17 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest { AddData(inputStream, MapInputEvent("k1", "key2", "put", 2)), AdvanceManualClock(1 * 1000), CheckNewAnswer(), - // advance clock to expire first key - AdvanceManualClock(30 * 1000), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + }, AddData(inputStream, MapInputEvent("k1", "key1", "get", -1), MapInputEvent("k1", "key2", "get", -1)), - AdvanceManualClock(1 * 1000), + // advance clock to expire first key + AdvanceManualClock(30 * 1000), CheckNewAnswer(MapOutputEvent("k1", "key2", 2, isTTLValue = false, -1)), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) + }, StopStream ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 0c02fbf97820b..257578ee65447 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -528,6 +528,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest Execute { q => assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) }, AddData(inputData, "a", "b"), CheckNewAnswer(("a", "2"), ("b", "1")), @@ -536,6 +537,10 @@ class TransformWithStateSuite extends StateStoreMetricsTest AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a CheckNewAnswer(("b", "2")), StopStream, + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) + }, StartStream(), 
AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and CheckNewAnswer(("a", "1"), ("c", "1")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 1fbeaeb817bd9..e2b31de1f66b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -247,7 +247,19 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { // validate ttl value is removed in the value state column family AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1)), AdvanceManualClock(1 * 1000), - CheckNewAnswer() + CheckNewAnswer(), + AddData(inputStream, InputEvent(ttlKey, "put", 3)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + }, + AddData(inputStream, InputEvent(noTtlKey, "get", -1)), + AdvanceManualClock(60 * 1000), + CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) + } ) } } From f8d92224b9af4ffffbb83ca2c9dd3c3b909b135d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 Oct 2024 16:31:22 -0700 Subject: [PATCH 059/108] [SPARK-49550][BUILD] Upgrade Hadoop to 3.4.1 ### What changes were proposed in this pull request? This PR aims to upgrade Apache Hadoop to 3.4.1. ### Why are the changes needed? To bring the latest bug fixes. - http://hadoop.apache.org/docs/r3.4.1/index.html - http://hadoop.apache.org/docs/r3.4.1/hadoop-project-dist/hadoop-common/release/3.4.1/RELEASENOTES.3.4.1.html - http://hadoop.apache.org/docs/r3.4.1/hadoop-project-dist/hadoop-common/release/3.4.1/CHANGELOG.3.4.1.html ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48295 from dongjoon-hyun/SPARK-49550. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 20 ++++++++++---------- pom.xml | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 4620a51904461..5bb578a807124 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -71,16 +71,16 @@ gcs-connector/hadoop3-2.2.25/shaded/gcs-connector-hadoop3-2.2.25-shaded.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.11.0//gson-2.11.0.jar guava/33.2.1-jre//guava-33.2.1-jre.jar -hadoop-aliyun/3.4.0//hadoop-aliyun-3.4.0.jar -hadoop-annotations/3.4.0//hadoop-annotations-3.4.0.jar -hadoop-aws/3.4.0//hadoop-aws-3.4.0.jar -hadoop-azure-datalake/3.4.0//hadoop-azure-datalake-3.4.0.jar -hadoop-azure/3.4.0//hadoop-azure-3.4.0.jar -hadoop-client-api/3.4.0//hadoop-client-api-3.4.0.jar -hadoop-client-runtime/3.4.0//hadoop-client-runtime-3.4.0.jar -hadoop-cloud-storage/3.4.0//hadoop-cloud-storage-3.4.0.jar -hadoop-huaweicloud/3.4.0//hadoop-huaweicloud-3.4.0.jar -hadoop-shaded-guava/1.2.0//hadoop-shaded-guava-1.2.0.jar +hadoop-aliyun/3.4.1//hadoop-aliyun-3.4.1.jar +hadoop-annotations/3.4.1//hadoop-annotations-3.4.1.jar +hadoop-aws/3.4.1//hadoop-aws-3.4.1.jar +hadoop-azure-datalake/3.4.1//hadoop-azure-datalake-3.4.1.jar +hadoop-azure/3.4.1//hadoop-azure-3.4.1.jar +hadoop-client-api/3.4.1//hadoop-client-api-3.4.1.jar +hadoop-client-runtime/3.4.1//hadoop-client-runtime-3.4.1.jar +hadoop-cloud-storage/3.4.1//hadoop-cloud-storage-3.4.1.jar +hadoop-huaweicloud/3.4.1//hadoop-huaweicloud-3.4.1.jar +hadoop-shaded-guava/1.3.0//hadoop-shaded-guava-1.3.0.jar hive-beeline/2.3.10//hive-beeline-2.3.10.jar hive-cli/2.3.10//hive-cli-2.3.10.jar hive-common/2.3.10//hive-common-2.3.10.jar diff --git a/pom.xml b/pom.xml index fe49568d744a0..588e3f5e4161a 100644 --- a/pom.xml +++ b/pom.xml @@ -123,7 +123,7 @@ 2.0.16 2.24.1 - 3.4.0 + 3.4.1 3.25.5 3.11.4 From 14ed86e24dba00d7a87f9140c6e62e9cf2554a5b Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Sat, 19 Oct 2024 15:14:01 +0900 Subject: [PATCH 060/108] [SPARK-50030][PYTHON][CONNECT] API compatibility check for Window ### What changes were proposed in this pull request? This PR proposes to add an API compatibility check for Spark SQL Window functions ### Why are the changes needed? To guarantee the same behavior between Spark Classic and Spark Connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48541 from itholic/SPARK-50030.
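The check added below essentially diffs the public members of the classic and connect classes and asserts that the difference matches an expected set. A minimal standalone sketch of that idea, using only the standard-library `inspect` module and hypothetical stand-in classes (not the PR's actual `check_compatibility` helper):

```py
import inspect

def public_members(cls):
    # Public (non-underscore) attribute names exposed by a class.
    return {name for name, _ in inspect.getmembers(cls) if not name.startswith("_")}

def api_diff(classic_cls, connect_cls):
    # Members present on one side but missing on the other.
    classic, connect = public_members(classic_cls), public_members(connect_cls)
    return classic - connect, connect - classic

# Hypothetical stand-ins for the real classes, for illustration only.
class ClassicExample:
    def partitionBy(self): ...
    def orderBy(self): ...

class ConnectExample:
    def partitionBy(self): ...

print(api_diff(ClassicExample, ConnectExample))  # ({'orderBy'}, set())
```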
Authored-by: Haejoon Lee Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/window.py | 18 +++++----- .../sql/tests/test_connect_compatibility.py | 36 +++++++++++++++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py index b1bf080ded315..bf6d60df63505 100644 --- a/python/pyspark/sql/connect/window.py +++ b/python/pyspark/sql/connect/window.py @@ -84,23 +84,21 @@ def __init__( self._orderSpec = orderSpec self._frame = frame - def partitionBy( - self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]] - ) -> ParentWindowSpec: + def partitionBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec": return WindowSpec( partitionSpec=[c._expr for c in _to_cols(cols)], # type: ignore[misc] orderSpec=self._orderSpec, frame=self._frame, ) - def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec: + def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec": return WindowSpec( partitionSpec=self._partitionSpec, orderSpec=[cast(SortOrder, F._sort_col(c)._expr) for c in _to_cols(cols)], frame=self._frame, ) - def rowsBetween(self, start: int, end: int) -> ParentWindowSpec: + def rowsBetween(self, start: int, end: int) -> "WindowSpec": if start <= Window._PRECEDING_THRESHOLD: start = Window.unboundedPreceding if end >= Window._FOLLOWING_THRESHOLD: @@ -112,7 +110,7 @@ def rowsBetween(self, start: int, end: int) -> ParentWindowSpec: frame=WindowFrame(isRowFrame=True, start=start, end=end), ) - def rangeBetween(self, start: int, end: int) -> ParentWindowSpec: + def rangeBetween(self, start: int, end: int) -> "WindowSpec": if start <= Window._PRECEDING_THRESHOLD: start = Window.unboundedPreceding if end >= Window._FOLLOWING_THRESHOLD: @@ -141,19 +139,19 @@ class Window(ParentWindow): _spec = WindowSpec(partitionSpec=[], orderSpec=[], frame=None) @staticmethod - def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec: + def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec": return Window._spec.partitionBy(*cols) @staticmethod - def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec: + def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec": return Window._spec.orderBy(*cols) @staticmethod - def rowsBetween(start: int, end: int) -> ParentWindowSpec: + def rowsBetween(start: int, end: int) -> "WindowSpec": return Window._spec.rowsBetween(start, end) @staticmethod - def rangeBetween(start: int, end: int) -> ParentWindowSpec: + def rangeBetween(start: int, end: int) -> "WindowSpec": return Window._spec.rangeBetween(start, end) diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index efef85862633e..f081385f44894 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -28,6 +28,8 @@ from pyspark.sql.readwriter import DataFrameReader as ClassicDataFrameReader from pyspark.sql.readwriter import DataFrameWriter as ClassicDataFrameWriter from pyspark.sql.readwriter import DataFrameWriterV2 as ClassicDataFrameWriterV2 +from pyspark.sql.window import Window as ClassicWindow +from pyspark.sql.window import WindowSpec as ClassicWindowSpec if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame @@ -37,6 +39,8 @@ 
from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader from pyspark.sql.connect.readwriter import DataFrameWriter as ConnectDataFrameWriter from pyspark.sql.connect.readwriter import DataFrameWriterV2 as ConnectDataFrameWriterV2 + from pyspark.sql.connect.window import Window as ConnectWindow + from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec class ConnectCompatibilityTestsMixin: @@ -303,6 +307,38 @@ def test_dataframe_writer_v2_compatibility(self): expected_missing_classic_methods, ) + def test_window_compatibility(self): + """Test Window compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicWindow, + ConnectWindow, + "Window", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + + def test_window_spec_compatibility(self): + """Test WindowSpec compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicWindowSpec, + ConnectWindowSpec, + "WindowSpec", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): From 25b03f9d3471ea57ba3d18a7d7fbe0be05a306eb Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Sat, 19 Oct 2024 19:56:37 +0900 Subject: [PATCH 061/108] [SPARK-50023][PYTHON][CONNECT] API compatibility check for Functions ### What changes were proposed in this pull request? This PR proposes to add an API compatibility check for Spark SQL Functions ### Why are the changes needed? To guarantee the same behavior between Spark Classic and Spark Connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48536 from itholic/SPARK-50023.
Authored-by: Haejoon Lee Signed-off-by: Hyukjin Kwon --- .../pyspark/sql/connect/functions/builtin.py | 7 ++++--- python/pyspark/sql/functions/builtin.py | 8 ++++---- .../sql/tests/test_connect_compatibility.py | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 9341442a1733b..1e3d41825f06c 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -1498,7 +1498,7 @@ def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> lead.__doc__ = pysparkfuncs.lead.__doc__ -def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = None) -> Column: +def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = False) -> Column: if ignoreNulls is None: return _invoke_function("nth_value", _to_col(col), lit(offset)) else: @@ -2236,7 +2236,7 @@ def size(col: "ColumnOrName") -> Column: def slice( - col: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int] + x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int] ) -> Column: start = _enum_to_value(start) if isinstance(start, (Column, str)): @@ -2260,7 +2260,7 @@ def slice( messageParameters={"arg_name": "length", "arg_type": type(length).__name__}, ) - return _invoke_function_over_columns("slice", col, _start, _length) + return _invoke_function_over_columns("slice", x, _start, _length) slice.__doc__ = pysparkfuncs.slice.__doc__ @@ -4195,6 +4195,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column: def udf( f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None, returnType: "DataTypeOrString" = StringType(), + *, useArrow: Optional[bool] = None, ) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]: if f is None or isinstance(f, (str, DataType)): diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 55da50fd4a5a5..dbc66cab3f9b3 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -41,7 +41,7 @@ from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql.column import Column -from pyspark.sql.dataframe import DataFrame +from pyspark.sql.dataframe import DataFrame as ParentDataFrame from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from_numpy_type # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 @@ -5590,7 +5590,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C @_try_remote_functions -def broadcast(df: DataFrame) -> DataFrame: +def broadcast(df: "ParentDataFrame") -> "ParentDataFrame": """ Marks a DataFrame as small enough for use in broadcast joins. 
@@ -5621,7 +5621,7 @@ def broadcast(df: DataFrame) -> DataFrame: from py4j.java_gateway import JVMView sc = _get_active_spark_context() - return DataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession) + return ParentDataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession) @_try_remote_functions @@ -9678,7 +9678,7 @@ def from_utc_timestamp(timestamp: "ColumnOrName", tz: Union[Column, str]) -> Col @_try_remote_functions -def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: +def to_utc_timestamp(timestamp: "ColumnOrName", tz: Union[Column, str]) -> Column: """ This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index f081385f44894..3ebb6b7aea7d0 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -30,6 +30,7 @@ from pyspark.sql.readwriter import DataFrameWriterV2 as ClassicDataFrameWriterV2 from pyspark.sql.window import Window as ClassicWindow from pyspark.sql.window import WindowSpec as ClassicWindowSpec +import pyspark.sql.functions as ClassicFunctions if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame @@ -41,6 +42,7 @@ from pyspark.sql.connect.readwriter import DataFrameWriterV2 as ConnectDataFrameWriterV2 from pyspark.sql.connect.window import Window as ConnectWindow from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec + import pyspark.sql.connect.functions as ConnectFunctions class ConnectCompatibilityTestsMixin: @@ -339,6 +341,22 @@ def test_window_spec_compatibility(self): expected_missing_classic_methods, ) + def test_functions_compatibility(self): + """Test Functions compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = {"check_dependencies"} + self.check_compatibility( + ClassicFunctions, + ConnectFunctions, + "Functions", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): From d5550f65a5da4c166e3efe924c5713b55405f67d Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Sun, 20 Oct 2024 08:36:36 +0800 Subject: [PATCH 062/108] [SPARK-49530][PYTHON][CONNECT] Support kde/density plots ### What changes were proposed in this pull request? Support kde/density plots with plotly backend on both Spark Connect and Spark classic. ### Why are the changes needed? While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress. Part of https://issues.apache.org/jira/browse/SPARK-49530. 
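For background, the curve drawn by a kde/density plot is the Gaussian kernel density estimate f(x) = (1 / (n * h)) * sum_i phi((x - x_i) / h), where h is the bandwidth and phi is the standard normal pdf. A tiny NumPy sketch of that quantity, with made-up sample data and bandwidth and independent of Spark's actual implementation:

```py
import numpy as np

samples = np.array([1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0])         # one data column
bandwidth = 0.3                                                  # plays the role of bw_method
grid = np.linspace(samples.min() - 1.0, samples.max() + 1.0, 5)  # plays the role of ind

def gaussian_kde(points, data, h):
    # f_hat(x) = (1 / (n * h)) * sum_i phi((x - x_i) / h), phi = standard normal pdf
    z = (points[:, None] - data[None, :]) / h
    phi = np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)
    return phi.mean(axis=1) / h

print(gaussian_kde(grid, samples, bandwidth))  # density values to plot against `grid`
```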
### Does this PR introduce _any_ user-facing change? Yes. kde/density plots are supported as shown below. ```py >>> data = [ ... (1.0, 4.0), ... (2.0, 4.0), ... (2.5, 4.5), ... (3.0, 5.0), ... (3.5, 5.5), ... (4.0, 6.0), ... (5.0, 6.0) ... ] >>> columns = ["x", "y"] >>> df = spark.createDataFrame(data, columns) >>> fig1 = df.plot.kde(column=["x", "y"], bw_method=0.3, ind=100) >>> fig1.show() # see below >>> fig2 = df.plot(kind="kde", column="x", bw_method=0.3, ind=20) >>> fig2.show() # see below ``` fig1: ![newplot (23)](https://github.com/user-attachments/assets/2cb84a78-7d92-43b5-afec-df83e5c55f5c) fig2: ![newplot (22)](https://github.com/user-attachments/assets/90a70770-8d05-4c81-8f02-4f954c2c689e) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48492 from xinrong-meng/kde. Authored-by: Xinrong Meng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/pandas/utils.py | 30 ++++ python/pyspark/sql/plot/core.py | 131 +++++++++++++++++- python/pyspark/sql/plot/plotly.py | 47 ++++++- .../sql/tests/plot/test_frame_plot_plotly.py | 34 ++++- python/pyspark/testing/sqlutils.py | 8 ++ 5 files changed, 247 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/pandas/utils.py b/python/pyspark/sql/pandas/utils.py index d080448cfd3a1..5849ae0edd6d9 100644 --- a/python/pyspark/sql/pandas/utils.py +++ b/python/pyspark/sql/pandas/utils.py @@ -94,3 +94,33 @@ def require_minimum_pyarrow_version() -> None: errorClass="ARROW_LEGACY_IPC_FORMAT", messageParameters={}, ) + + +def require_minimum_numpy_version() -> None: + """Raise ImportError if minimum version of NumPy is not installed""" + minimum_numpy_version = "1.21" + + try: + import numpy + + have_numpy = True + except ImportError as error: + have_numpy = False + raised_error = error + if not have_numpy: + raise PySparkImportError( + errorClass="PACKAGE_NOT_INSTALLED", + messageParameters={ + "package_name": "NumPy", + "minimum_version": str(minimum_numpy_version), + }, + ) from raised_error + if LooseVersion(numpy.__version__) < LooseVersion(minimum_numpy_version): + raise PySparkImportError( + errorClass="UNSUPPORTED_PACKAGE_VERSION", + messageParameters={ + "package_name": "NumPy", + "minimum_version": str(minimum_numpy_version), + "current_version": str(numpy.__version__), + }, + ) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index e61af4ae3fa5d..f44c0768d4337 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -15,18 +15,26 @@ # limitations under the License. 
# +import math + from typing import Any, TYPE_CHECKING, List, Optional, Union from types import ModuleType -from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.errors import ( + PySparkRuntimeError, + PySparkTypeError, + PySparkValueError, +) from pyspark.sql import Column, functions as F from pyspark.sql.types import NumericType from pyspark.sql.utils import is_remote, require_minimum_plotly_version +from pandas.core.dtypes.inference import is_integer if TYPE_CHECKING: from pyspark.sql import DataFrame, Row from pyspark.sql._typing import ColumnOrName import pandas as pd + import numpy as np from plotly.graph_objs import Figure @@ -398,6 +406,127 @@ def box( """ return self(kind="box", column=column, precision=precision, **kwargs) + def kde( + self, + column: Union[str, List[str]], + bw_method: Union[int, float], + ind: Union["np.ndarray", int, None] = None, + **kwargs: Any, + ) -> "Figure": + """ + Generate Kernel Density Estimate plot using Gaussian kernels. + + In statistics, kernel density estimation (KDE) is a non-parametric way to + estimate the probability density function (PDF) of a random variable. This + function uses Gaussian kernels and includes automatic bandwidth determination. + + Parameters + ---------- + column: str or list of str + Column name or list of names to be used for creating the kde plot. + bw_method : int or float + The method used to calculate the estimator bandwidth. + See KernelDensity in PySpark for more information. + ind : NumPy array or integer, optional + Evaluation points for the estimated PDF. If None (default), + 1000 equally spaced points are used. If `ind` is a NumPy array, the + KDE is evaluated at the points passed. If `ind` is an integer, + `ind` number of equally spaced points are used. + **kwargs : optional + Additional keyword arguments. 
+ + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] + >>> columns = ["length", "width", "species"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.kde(column=["length", "width"], bw_method=0.3) # doctest: +SKIP + >>> df.plot.kde(column="length", bw_method=0.3) # doctest: +SKIP + """ + return self(kind="kde", column=column, bw_method=bw_method, ind=ind, **kwargs) + + +class PySparkKdePlotBase: + @staticmethod + def get_ind(sdf: "DataFrame", ind: Union["np.ndarray", int, None]) -> "np.ndarray": + from pyspark.sql.pandas.utils import require_minimum_numpy_version + + require_minimum_numpy_version() + import numpy as np + + def calc_min_max() -> "Row": + if len(sdf.columns) > 1: + min_col = F.least(*map(F.min, sdf)) # type: ignore + max_col = F.greatest(*map(F.max, sdf)) # type: ignore + else: + min_col = F.min(sdf.columns[-1]) + max_col = F.max(sdf.columns[-1]) + return sdf.select(min_col, max_col).first() # type: ignore + + if ind is None: + min_val, max_val = calc_min_max() + sample_range = max_val - min_val + ind = np.linspace( + min_val - 0.5 * sample_range, + max_val + 0.5 * sample_range, + 1000, + ) + elif is_integer(ind): + min_val, max_val = calc_min_max() + sample_range = max_val - min_val + ind = np.linspace( + min_val - 0.5 * sample_range, + max_val + 0.5 * sample_range, + ind, + ) + return ind # type: ignore + + @staticmethod + def compute_kde_col( + input_col: Column, + bw_method: Union[int, float], + ind: "np.ndarray", + ) -> Column: + # refers to org.apache.spark.mllib.stat.KernelDensity + assert bw_method is not None and isinstance( + bw_method, (int, float) + ), "'bw_method' must be set as a scalar number." + + assert ind is not None, "'ind' must be a scalar array." 
+ + bandwidth = float(bw_method) + points = [float(i) for i in ind] + log_std_plus_half_log2_pi = math.log(bandwidth) + 0.5 * math.log(2 * math.pi) + + def norm_pdf( + mean: Column, + std: Column, + log_std_plus_half_log2_pi: Column, + x: Column, + ) -> Column: + x0 = x - mean + x1 = x0 / std + log_density = -0.5 * x1 * x1 - log_std_plus_half_log2_pi + return F.exp(log_density) + + return F.array( + [ + F.avg( + norm_pdf( + input_col.cast("double"), + F.lit(bandwidth), + F.lit(log_std_plus_half_log2_pi), + F.lit(point), + ) + ) + for point in points + ] + ) + class PySparkBoxPlotBase: @staticmethod diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py index 71d40720e874d..884ee1da28aa4 100644 --- a/python/pyspark/sql/plot/plotly.py +++ b/python/pyspark/sql/plot/plotly.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any from pyspark.errors import PySparkValueError -from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase +from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase, PySparkKdePlotBase if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -32,6 +32,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure": return plot_pie(data, **kwargs) if kind == "box": return plot_box(data, **kwargs) + if kind == "kde" or kind == "density": + return plot_kde(data, **kwargs) return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs) @@ -118,3 +120,46 @@ def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure": fig["layout"]["yaxis"]["title"] = "value" return fig + + +def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure": + from pyspark.sql.pandas.utils import require_minimum_pandas_version + + require_minimum_pandas_version() + + import pandas as pd + from plotly import express + + if "color" not in kwargs: + kwargs["color"] = "names" + + bw_method = kwargs.pop("bw_method", None) + colnames = kwargs.pop("column", None) + if isinstance(colnames, str): + colnames = [colnames] + ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind", None)) + + kde_cols = [ + PySparkKdePlotBase.compute_kde_col( + input_col=data[col_name], + ind=ind, + bw_method=bw_method, + ).alias(f"kde_{i}") + for i, col_name in enumerate(colnames) + ] + kde_results = data.select(*kde_cols).first() + pdf = pd.concat( + [ + pd.DataFrame( # type: ignore + { + "Density": kde_result, + "names": col_name, + "index": ind, + } + ) + for col_name, kde_result in zip(colnames, list(kde_results)) # type: ignore[arg-type] + ] + ) + fig = express.line(pdf, x="index", y="Density", **kwargs) + fig["layout"]["xaxis"]["title"] = None + return fig diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index d870cdbf9959b..9764b4a277273 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -20,7 +20,13 @@ import pyspark.sql.plot # noqa: F401 from pyspark.errors import PySparkTypeError, PySparkValueError -from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_plotly, + have_numpy, + plotly_requirement_message, + numpy_requirement_message, +) @unittest.skipIf(not have_plotly, plotly_requirement_message) @@ -375,6 +381,32 @@ def test_box_plot(self): }, ) + @unittest.skipIf(not have_numpy, numpy_requirement_message) + def test_kde_plot(self): + fig = 
self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5) + expected_fig_data = { + "mode": "lines", + "name": "math_score", + "orientation": "v", + "xaxis": "x", + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + fig = self.sdf4.plot.kde(column=["math_score", "english_score"], bw_method=0.3, ind=5) + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "mode": "lines", + "name": "english_score", + "orientation": "v", + "xaxis": "x", + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"])) + class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 00ad40e68bd7c..dab382c37f42b 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -55,6 +55,13 @@ plotly_requirement_message = str(e) have_plotly = plotly_requirement_message is None +numpy_requirement_message = None +try: + import numpy +except ImportError as e: + numpy_requirement_message = str(e) +have_numpy = numpy_requirement_message is None + from pyspark.sql import SparkSession from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils @@ -63,6 +70,7 @@ have_pandas = pandas_requirement_message is None have_pyarrow = pyarrow_requirement_message is None test_compiled = test_not_compiled_message is None +have_numpy = numpy_requirement_message is None class UTCOffsetTimezone(datetime.tzinfo): From 64218d71951ee28e4a76a309e59c6284cfe87d7e Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sun, 20 Oct 2024 10:09:16 +0900 Subject: [PATCH 063/108] [MINOR][DOCS] Add Development Version of docs in README.md ### What changes were proposed in this pull request? This PR proposes to add Development Version of docs in README.md ### Why are the changes needed? For developers to easily navigate documentation dev version. ### Does this PR introduce _any_ user-facing change? No, dev-only. ### How was this patch tested? Manually ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48555 from HyukjinKwon/minor-link. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b9a20075f6a17..552b71215cb92 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,8 @@ rich set of higher-level tools including Spark SQL for SQL and DataFrames, pandas API on Spark for pandas workloads, MLlib for machine learning, GraphX for graph processing, and Structured Streaming for stream processing. - +- Official version: +- Development version: [![GitHub Actions Build](https://github.com/apache/spark/actions/workflows/build_main.yml/badge.svg)](https://github.com/apache/spark/actions/workflows/build_main.yml) [![PySpark Coverage](https://codecov.io/gh/apache/spark/branch/master/graph/badge.svg)](https://codecov.io/gh/apache/spark) From a31da6a37ef6369c271c4144dafc723a57c480c8 Mon Sep 17 00:00:00 2001 From: Niranjan Jayakar Date: Sun, 20 Oct 2024 10:14:10 +0900 Subject: [PATCH 064/108] [SPARK-50024][PYTHON][CONNECT] Switch to use logger instead of warnings module in client ### What changes were proposed in this pull request? 
ReleaseExecute, in some cases, can fail since the operation may already have been released or dropped by the server. The API call is best effort. By logging the error using the Python warnings module, we produce a user facing warning and can be confusing. The original intention is to add it to the log output for debugging when necessary. Switch to use a standard logger instead. The warnings module should generally be used to warn users of their choices - deprecated API, unapplied option, etc. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Ran pyspark locally and checked that the log statements are printed. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48519 from nija-at/log-insteadof-warn. Authored-by: Niranjan Jayakar Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/__init__.py | 2 +- python/pyspark/sql/connect/client/artifact.py | 2 +- python/pyspark/sql/connect/client/core.py | 2 +- python/pyspark/sql/connect/client/reattach.py | 6 +++--- python/pyspark/sql/connect/client/retries.py | 2 +- python/pyspark/sql/connect/{client => }/logging.py | 0 python/pyspark/sql/connect/plan.py | 4 ++-- python/pyspark/sql/connect/session.py | 13 +++++++------ 8 files changed, 16 insertions(+), 15 deletions(-) rename python/pyspark/sql/connect/{client => }/logging.py (100%) diff --git a/python/pyspark/sql/connect/client/__init__.py b/python/pyspark/sql/connect/client/__init__.py index 38523352e5b4a..40c05d4905c76 100644 --- a/python/pyspark/sql/connect/client/__init__.py +++ b/python/pyspark/sql/connect/client/__init__.py @@ -20,4 +20,4 @@ check_dependencies(__name__) from pyspark.sql.connect.client.core import * # noqa: F401,F403 -from pyspark.sql.connect.client.logging import getLogLevel # noqa: F401 +from pyspark.sql.connect.logging import getLogLevel # noqa: F401 diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py index dd243698136ec..ac33233a00ff3 100644 --- a/python/pyspark/sql/connect/client/artifact.py +++ b/python/pyspark/sql/connect/client/artifact.py @@ -16,7 +16,7 @@ # from pyspark.errors import PySparkRuntimeError, PySparkValueError from pyspark.sql.connect.utils import check_dependencies -from pyspark.sql.connect.client.logging import logger +from pyspark.sql.connect.logging import logger check_dependencies(__name__) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index adba1b42a8bd6..3de4255054053 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -65,7 +65,7 @@ from pyspark.resource.information import ResourceInformation from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics from pyspark.sql.connect.client.artifact import ArtifactManager -from pyspark.sql.connect.client.logging import logger +from pyspark.sql.connect.logging import logger from pyspark.sql.connect.profiler import ConnectProfilerCollector from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index e0c7cc448933d..e6dba6e0073f7 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -20,7 +20,6 @@ check_dependencies(__name__) from threading import RLock -import 
warnings import uuid from collections.abc import Generator from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar @@ -30,6 +29,7 @@ import grpc from grpc_status import rpc_status +from pyspark.sql.connect.logging import logger import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib from pyspark.errors import PySparkRuntimeError @@ -206,7 +206,7 @@ def target() -> None: with attempt: self._stub.ReleaseExecute(request, metadata=self._metadata) except Exception as e: - warnings.warn(f"ReleaseExecute failed with exception: {e}.") + logger.warn(f"ReleaseExecute failed with exception: {e}.") with self._lock: if self._release_thread_pool_instance is not None: @@ -231,7 +231,7 @@ def target() -> None: with attempt: self._stub.ReleaseExecute(request, metadata=self._metadata) except Exception as e: - warnings.warn(f"ReleaseExecute failed with exception: {e}.") + logger.warn(f"ReleaseExecute failed with exception: {e}.") with self._lock: if self._release_thread_pool_instance is not None: diff --git a/python/pyspark/sql/connect/client/retries.py b/python/pyspark/sql/connect/client/retries.py index f2006ab5ec8bb..e27100133b5ae 100644 --- a/python/pyspark/sql/connect/client/retries.py +++ b/python/pyspark/sql/connect/client/retries.py @@ -21,7 +21,7 @@ import typing from typing import Optional, Callable, Generator, List, Type from types import TracebackType -from pyspark.sql.connect.client.logging import logger +from pyspark.sql.connect.logging import logger from pyspark.errors import PySparkRuntimeError, RetriesExceeded """ diff --git a/python/pyspark/sql/connect/client/logging.py b/python/pyspark/sql/connect/logging.py similarity index 100% rename from python/pyspark/sql/connect/client/logging.py rename to python/pyspark/sql/connect/logging.py diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index b74f863db1e83..b8268d46b3325 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -40,7 +40,6 @@ import pickle from threading import Lock from inspect import signature, isclass -import warnings import pyarrow as pa @@ -50,6 +49,7 @@ import pyspark.sql.connect.proto as proto from pyspark.sql.column import Column +from pyspark.sql.connect.logging import logger from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2 from pyspark.sql.connect.conversion import storage_level_to_proto from pyspark.sql.connect.expressions import Expression @@ -596,7 +596,7 @@ def __del__(self) -> None: metadata = session.client._builder.metadata() channel(req, metadata=metadata) # type: ignore[arg-type] except Exception as e: - warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.") + logger.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.") class Hint(LogicalPlan): diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index cacb479229bb7..a4047f09401ea 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -52,6 +52,7 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.dataframe import DataFrame as ParentDataFrame +from pyspark.sql.connect.logging import logger from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder from pyspark.sql.connect.conf import RuntimeConf from pyspark.sql.connect.plan import ( @@ -218,8 +219,8 @@ def _apply_options(self, session: "SparkSession") -> None: # simply ignore it for 
now. try: session.conf.set(k, v) - except Exception: - pass + except Exception as e: + logger.warn(f"Failed to set configuration {k} due to {e}") with self._lock: for k, v in self._options.items(): @@ -232,7 +233,7 @@ def _apply_options(self, session: "SparkSession") -> None: try: session.conf.set(k, v) except Exception as e: - warnings.warn(str(e)) + logger.warn(f"Failed to set configuration {k} due to {e}") def create(self) -> "SparkSession": has_channel_builder = self._channel_builder is not None @@ -860,12 +861,12 @@ def stop(self) -> None: try: self.client.release_session() except Exception as e: - warnings.warn(f"session.stop(): Session could not be released. Error: ${e}") + logger.warn(f"session.stop(): Session could not be released. Error: ${e}") try: self.client.close() except Exception as e: - warnings.warn(f"session.stop(): Client could not be closed. Error: ${e}") + logger.warn(f"session.stop(): Client could not be closed. Error: ${e}") if self is SparkSession._default_session: SparkSession._default_session = None @@ -881,7 +882,7 @@ def stop(self) -> None: try: PySparkSession._activeSession.stop() except Exception as e: - warnings.warn( + logger.warn( "session.stop(): Local Spark Connect Server could not be stopped. " f"Error: ${e}" ) From 1ca2d5620d006117937c4ef145e2730af95d490b Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sun, 20 Oct 2024 17:27:18 +0900 Subject: [PATCH 065/108] [MINOR][TESTS] Fix appname from `sql.protobuf.functions` to `sql.connect.protobuf.functions` ### What changes were proposed in this pull request? This PR proposes to fix appname from `sql.protobuf.functions` to `sql.connect.protobuf.functions` ### Why are the changes needed? For consistency ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually checked. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48562 from HyukjinKwon/minor-test-title. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/protobuf/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/protobuf/functions.py b/python/pyspark/sql/connect/protobuf/functions.py index 07e9b4b8c6861..ba43f94ce1eeb 100644 --- a/python/pyspark/sql/connect/protobuf/functions.py +++ b/python/pyspark/sql/connect/protobuf/functions.py @@ -142,7 +142,7 @@ def _test() -> None: globs = pyspark.sql.connect.protobuf.functions.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.protobuf.functions tests") + PySparkSession.builder.appName("sql.connect.protobuf.functions tests") .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .getOrCreate() ) From 4883689e90bdeddddc8dcadf859fd2979a58b40a Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sun, 20 Oct 2024 17:28:22 +0900 Subject: [PATCH 066/108] [SPARK-50042][PYTHON] Upgrade numpy 2 for python linter ### What changes were proposed in this pull request? Upgrade numpy for python linter ### Why are the changes needed? Upgrade numpy for python linter ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48565 from zhengruifeng/infra_numpy_lint. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .github/workflows/build_and_test.yml | 2 +- python/pyspark/ml/linalg/__init__.py | 4 ++-- python/pyspark/mllib/linalg/__init__.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 14d93a498fc59..c43c7df22039b 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -723,7 +723,7 @@ jobs: # See 'ipython_genutils' in SPARK-38517 # See 'docutils<0.18.0' in SPARK-39421 python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive 'numpy==1.26.4' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ + ipython ipython_genutils sphinx_plotly_directive numpy pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index d470f8b8b5c46..cedd3b04564ec 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -706,12 +706,12 @@ def dot(self, other: Iterable[float]) -> np.float64: elif isinstance(other, SparseVector): # Find out common indices. - self_cmind = np.in1d(self.indices, other.indices, assume_unique=True) + self_cmind = np.isin(self.indices, other.indices, assume_unique=True) self_values = self.values[self_cmind] if self_values.size == 0: return np.float64(0.0) else: - other_cmind = np.in1d(other.indices, self.indices, assume_unique=True) + other_cmind = np.isin(other.indices, self.indices, assume_unique=True) return np.dot(self_values, other.values[other_cmind]) else: diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 02cef36c11c46..40f0255a91bbe 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -813,12 +813,12 @@ def dot(self, other: Iterable[float]) -> np.float64: elif isinstance(other, SparseVector): # Find out common indices. - self_cmind = np.in1d(self.indices, other.indices, assume_unique=True) + self_cmind = np.isin(self.indices, other.indices, assume_unique=True) self_values = self.values[self_cmind] if self_values.size == 0: return np.float64(0.0) else: - other_cmind = np.in1d(other.indices, self.indices, assume_unique=True) + other_cmind = np.isin(other.indices, self.indices, assume_unique=True) return np.dot(self_values, other.values[other_cmind]) else: From 164a0f46d657a42fa5c47c05a515c50580bc8f17 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sun, 20 Oct 2024 19:53:49 +0900 Subject: [PATCH 067/108] [MINOR][DOCS][PYTHON] Clarify that profilers show the accumulated results ### What changes were proposed in this pull request? This PR proposes to clarify that profilers show the accumulated results ### Why are the changes needed? To make it clear that the results are accumulated. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user-facing documentation. ### How was this patch tested? CI should verify it. 
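For reference, the accumulation described above can be observed with the session-level UDF profiler; a minimal sketch, assuming an active `spark` session and the `spark.sql.pyspark.udf.profiler` conf / `spark.profile.show()` API of recent PySpark releases:

```py
from pyspark.sql.functions import udf

spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")

@udf("long")
def plus_one(v):
    return v + 1

# Each action adds to the accumulated profile, so the second show()
# reflects the rows processed by both executions of the UDF.
spark.range(1000).select(plus_one("id")).collect()
spark.profile.show(type="perf")
spark.range(1000).select(plus_one("id")).collect()
spark.profile.show(type="perf")
```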
### Was this patch authored or co-authored using generative AI tooling? No. Closes #48563 from HyukjinKwon/minor-docs. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/profiler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/pyspark/sql/profiler.py b/python/pyspark/sql/profiler.py index bd204877c0f55..6924cde9a292a 100644 --- a/python/pyspark/sql/profiler.py +++ b/python/pyspark/sql/profiler.py @@ -315,6 +315,12 @@ def show(self, id: Optional[int] = None, *, type: Optional[str] = None) -> None: A UDF ID to be shown. If not specified, all the results will be shown. type : str, optional The profiler type, which can be either "perf" or "memory". + + Notes + ----- + The results are gathered from all Python executions. For example, if there are + 8 tasks, each processing 1,000 rows, the total output will display the results + for 8,000 rows. """ if type == "memory": self.profiler_collector.show_memory_profiles(id) From ae75cac50d374f388ad79eac35dfcef9e2d07f83 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sun, 20 Oct 2024 21:48:13 +0900 Subject: [PATCH 068/108] [SPARK-50040][PYTHON][TESTS] Make pysaprk-connect tests passing without optional dependencies ### What changes were proposed in this pull request? This PR proposes to make pysaprk-connect tests passing without optional dependencies ### Why are the changes needed? To make the tests passing without optional dependencies. See https://github.com/apache/spark/actions/runs/11420354598/job/31775990587 ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually ran it locally ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48561 from HyukjinKwon/SPARK-50040. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../connect/test_connect_classification.py | 4 +- .../pyspark/sql/connect/functions/__init__.py | 6 ++- python/pyspark/sql/connect/merge.py | 7 ++- python/pyspark/sql/connect/observation.py | 4 ++ python/pyspark/sql/connect/utils.py | 2 +- .../streaming/test_parity_foreach_batch.py | 6 ++- .../sql/tests/connect/test_connect_column.py | 2 +- .../tests/connect/test_connect_creation.py | 2 +- .../test_connect_dataframe_property.py | 7 +-- .../sql/tests/connect/test_connect_error.py | 2 +- .../sql/tests/connect/test_connect_session.py | 51 ++++++++++--------- .../sql/tests/connect/test_connect_stat.py | 8 +-- .../sql/tests/connect/test_parity_udtf.py | 7 ++- .../pyspark/sql/tests/plot/test_frame_plot.py | 16 ++++-- .../sql/tests/plot/test_frame_plot_plotly.py | 10 +++- python/pyspark/testing/__init__.py | 45 ++++++++++++++++ python/pyspark/testing/connectutils.py | 50 +++++------------- 17 files changed, 137 insertions(+), 92 deletions(-) diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py index 8083090523a0e..910d2d2ec42f9 100644 --- a/python/pyspark/ml/tests/connect/test_connect_classification.py +++ b/python/pyspark/ml/tests/connect/test_connect_classification.py @@ -21,6 +21,7 @@ from pyspark.util import is_remote_only from pyspark.sql import SparkSession +from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin from pyspark.testing.connectutils import should_test_connect, connect_requirement_message torch_requirement_message = "torch is required" @@ -30,9 +31,6 @@ except ImportError: have_torch = False -if should_test_connect: - from 
pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin - @unittest.skipIf( not should_test_connect or not have_torch or is_remote_only(), diff --git a/python/pyspark/sql/connect/functions/__init__.py b/python/pyspark/sql/connect/functions/__init__.py index e0179d4d56cf8..087a51e8616b9 100644 --- a/python/pyspark/sql/connect/functions/__init__.py +++ b/python/pyspark/sql/connect/functions/__init__.py @@ -16,6 +16,8 @@ # """PySpark Functions with Spark Connect""" +from pyspark.testing import should_test_connect -from pyspark.sql.connect.functions.builtin import * # noqa: F401,F403 -from pyspark.sql.connect.functions import partitioning # noqa: F401,F403 +if should_test_connect: + from pyspark.sql.connect.functions.builtin import * # noqa: F401,F403 + from pyspark.sql.connect.functions import partitioning # noqa: F401,F403 diff --git a/python/pyspark/sql/connect/merge.py b/python/pyspark/sql/connect/merge.py index 9c3b3e4370a40..295e6089e092e 100644 --- a/python/pyspark/sql/connect/merge.py +++ b/python/pyspark/sql/connect/merge.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from pyspark.sql.connect.utils import check_dependencies + +check_dependencies(__name__) import sys from typing import Dict, Optional, TYPE_CHECKING, List, Callable @@ -235,12 +238,12 @@ def _test() -> None: globs = pyspark.sql.connect.merge.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.dataframe tests") + PySparkSession.builder.appName("sql.connect.merge tests") .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) (failure_count, test_count) = doctest.testmod( - pyspark.sql.merge, + pyspark.sql.connect.merge, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF, ) diff --git a/python/pyspark/sql/connect/observation.py b/python/pyspark/sql/connect/observation.py index e4b9b8a2d4fba..bfb8a0a9355fe 100644 --- a/python/pyspark/sql/connect/observation.py +++ b/python/pyspark/sql/connect/observation.py @@ -14,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from pyspark.sql.connect.utils import check_dependencies + +check_dependencies(__name__) + from typing import Any, Dict, Optional import uuid diff --git a/python/pyspark/sql/connect/utils.py b/python/pyspark/sql/connect/utils.py index ce57b490c4532..a2511836816c9 100644 --- a/python/pyspark/sql/connect/utils.py +++ b/python/pyspark/sql/connect/utils.py @@ -22,7 +22,7 @@ def check_dependencies(mod_name: str) -> None: - if mod_name == "__main__": + if mod_name == "__main__" or mod_name == "pyspark.sql.connect.utils": from pyspark.testing.connectutils import should_test_connect, connect_requirement_message if not should_test_connect: diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py index d79bfef2426a4..9d28ec0e19702 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py @@ -18,9 +18,11 @@ import unittest from pyspark.sql.tests.streaming.test_streaming_foreach_batch import StreamingTestsForeachBatchMixin -from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect from pyspark.errors import PySparkPicklingError -from pyspark.errors.exceptions.connect import SparkConnectGrpcException + +if should_test_connect: + from pyspark.errors.exceptions.connect import SparkConnectGrpcException class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase): diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 509f381f97fec..60ddcb6f22a54 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -40,7 +40,6 @@ BooleanType, ) from pyspark.errors import PySparkTypeError, PySparkValueError -from pyspark.errors.exceptions.connect import SparkConnectException from pyspark.testing.connectutils import should_test_connect from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase @@ -61,6 +60,7 @@ JVM_LONG_MIN, JVM_LONG_MAX, ) + from pyspark.errors.exceptions.connect import SparkConnectException class SparkConnectColumnTests(SparkConnectSQLTestCase): diff --git a/python/pyspark/sql/tests/connect/test_connect_creation.py b/python/pyspark/sql/tests/connect/test_connect_creation.py index cf6c2e86d2f5b..5352913f6609d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_creation.py +++ b/python/pyspark/sql/tests/connect/test_connect_creation.py @@ -35,7 +35,6 @@ from pyspark.testing.sqlutils import MyObject, PythonOnlyUDT from pyspark.testing.connectutils import should_test_connect -from pyspark.errors.exceptions.connect import ParseException from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase if should_test_connect: @@ -43,6 +42,7 @@ import numpy as np from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF + from pyspark.errors.exceptions.connect import ParseException class SparkConnectCreationTests(SparkConnectSQLTestCase): diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py index c712e5d6efcb6..1a8c7190e31a6 100644 --- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py +++ 
b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py @@ -19,11 +19,9 @@ from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType from pyspark.sql.utils import is_remote - from pyspark.sql import functions as SF -from pyspark.sql.connect import functions as CF - from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase +from pyspark.testing.connectutils import should_test_connect from pyspark.testing.sqlutils import ( have_pandas, have_pyarrow, @@ -38,6 +36,9 @@ if have_pandas: import pandas as pd +if should_test_connect: + from pyspark.sql.connect import functions as CF + class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase): def test_cached_property_is_copied(self): diff --git a/python/pyspark/sql/tests/connect/test_connect_error.py b/python/pyspark/sql/tests/connect/test_connect_error.py index 685e95a69ee74..01047741f6740 100644 --- a/python/pyspark/sql/tests/connect/test_connect_error.py +++ b/python/pyspark/sql/tests/connect/test_connect_error.py @@ -22,13 +22,13 @@ from pyspark.sql.types import Row from pyspark.testing.connectutils import should_test_connect from pyspark.errors import PySparkTypeError -from pyspark.errors.exceptions.connect import AnalysisException from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase if should_test_connect: from pyspark.sql.connect.session import SparkSession as RemoteSparkSession from pyspark.sql.connect import functions as CF from pyspark.sql.connect.column import Column + from pyspark.errors.exceptions.connect import AnalysisException class SparkConnectErrorTests(SparkConnectSQLTestCase): diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py index 4ddefc7385839..0028ecb95830d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_session.py +++ b/python/pyspark/sql/tests/connect/test_connect_session.py @@ -27,24 +27,24 @@ RetriesExceeded, ) from pyspark.sql import SparkSession as PySparkSession -from pyspark.sql.connect.client.retries import RetryPolicy from pyspark.testing.connectutils import ( should_test_connect, ReusedConnectTestCase, connect_requirement_message, ) -from pyspark.errors.exceptions.connect import ( - AnalysisException, - SparkConnectException, - SparkUpgradeException, -) if should_test_connect: import grpc from pyspark.sql.connect.session import SparkSession as RemoteSparkSession from pyspark.sql.connect.client import DefaultChannelBuilder, ChannelBuilder from pyspark.sql.connect.client.core import Retrying, SparkConnectClient + from pyspark.sql.connect.client.retries import RetryPolicy + from pyspark.errors.exceptions.connect import ( + AnalysisException, + SparkConnectException, + SparkUpgradeException, + ) @unittest.skipIf(is_remote_only(), "Session creation different from local mode") @@ -282,6 +282,7 @@ def test_stop_invalid_session(self): # SPARK-47986 session.stop() +@unittest.skipIf(not should_test_connect, connect_requirement_message) class SparkConnectSessionWithOptionsTest(unittest.TestCase): def setUp(self) -> None: self.spark = ( @@ -303,31 +304,31 @@ def test_config(self): self.assertEqual(self.spark.conf.get("integer"), "1") -class TestError(grpc.RpcError, Exception): - def __init__(self, code: grpc.StatusCode): - self._code = code - - def code(self): - return self._code +if should_test_connect: + class TestError(grpc.RpcError, Exception): + def __init__(self, code: grpc.StatusCode): + self._code = code -class 
TestPolicy(RetryPolicy): - # Put a small value for initial backoff so that tests don't spend - # Time waiting - def __init__(self, initial_backoff=10, **kwargs): - super().__init__(initial_backoff=initial_backoff, **kwargs) + def code(self): + return self._code - def can_retry(self, exception: BaseException): - return isinstance(exception, TestError) + class TestPolicy(RetryPolicy): + # Put a small value for initial backoff so that tests don't spend + # Time waiting + def __init__(self, initial_backoff=10, **kwargs): + super().__init__(initial_backoff=initial_backoff, **kwargs) + def can_retry(self, exception: BaseException): + return isinstance(exception, TestError) -class TestPolicySpecificError(TestPolicy): - def __init__(self, specific_code: grpc.StatusCode, **kwargs): - super().__init__(**kwargs) - self.specific_code = specific_code + class TestPolicySpecificError(TestPolicy): + def __init__(self, specific_code: grpc.StatusCode, **kwargs): + super().__init__(**kwargs) + self.specific_code = specific_code - def can_retry(self, exception: BaseException): - return exception.code() == self.specific_code + def can_retry(self, exception: BaseException): + return exception.code() == self.specific_code @unittest.skipIf(not should_test_connect, connect_requirement_message) diff --git a/python/pyspark/sql/tests/connect/test_connect_stat.py b/python/pyspark/sql/tests/connect/test_connect_stat.py index a2f23b44023d3..6e3cc2f58d814 100644 --- a/python/pyspark/sql/tests/connect/test_connect_stat.py +++ b/python/pyspark/sql/tests/connect/test_connect_stat.py @@ -19,15 +19,15 @@ from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.testing.connectutils import should_test_connect -from pyspark.errors.exceptions.connect import ( - AnalysisException, - SparkConnectException, -) from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase if should_test_connect: from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF + from pyspark.errors.exceptions.connect import ( + AnalysisException, + SparkConnectException, + ) class SparkConnectStatTests(SparkConnectSQLTestCase): diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 2ea6ef8cc389d..6955e7377b4c4 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -17,6 +17,8 @@ import unittest from pyspark.testing.connectutils import should_test_connect +from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase if should_test_connect: from pyspark import sql @@ -24,10 +26,7 @@ sql.udtf.UserDefinedTableFunction = UserDefinedTableFunction from pyspark.sql.connect.functions import lit, udtf - -from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin -from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException + from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py index 2a6971e896292..3221a408d153d 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot.py @@ -18,11 +18,21 @@ import 
unittest from pyspark.errors import PySparkValueError from pyspark.sql import Row -from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase -from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_plotly, + plotly_requirement_message, + have_pandas, + pandas_requirement_message, +) +if have_plotly and have_pandas: + from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase -@unittest.skipIf(not have_plotly, plotly_requirement_message) + +@unittest.skipIf( + not have_plotly or not have_pandas, plotly_requirement_message or pandas_requirement_message +) class DataFramePlotTestsMixin: def test_backend(self): accessor = self.spark.range(2).plot diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index 9764b4a277273..a6005b6f7c4d9 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -18,7 +18,6 @@ import unittest from datetime import datetime -import pyspark.sql.plot # noqa: F401 from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ( ReusedSQLTestCase, @@ -26,10 +25,17 @@ have_numpy, plotly_requirement_message, numpy_requirement_message, + have_pandas, + pandas_requirement_message, ) +if have_plotly and have_pandas: + import pyspark.sql.plot # noqa: F401 -@unittest.skipIf(not have_plotly, plotly_requirement_message) + +@unittest.skipIf( + not have_plotly or not have_pandas, plotly_requirement_message or pandas_requirement_message +) class DataFramePlotPlotlyTestsMixin: @property def sdf(self): diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py index 88853e925f801..2a20035e54898 100644 --- a/python/pyspark/testing/__init__.py +++ b/python/pyspark/testing/__init__.py @@ -14,6 +14,51 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import typing + from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual + +grpc_requirement_message = None +try: + import grpc +except ImportError as e: + grpc_requirement_message = str(e) +have_grpc = grpc_requirement_message is None + + +grpc_status_requirement_message = None +try: + import grpc_status +except ImportError as e: + grpc_status_requirement_message = str(e) +have_grpc_status = grpc_status_requirement_message is None + +googleapis_common_protos_requirement_message = None +try: + from google.rpc import error_details_pb2 +except ImportError as e: + googleapis_common_protos_requirement_message = str(e) +have_googleapis_common_protos = googleapis_common_protos_requirement_message is None + +graphviz_requirement_message = None +try: + import graphviz +except ImportError as e: + graphviz_requirement_message = str(e) +have_graphviz: bool = graphviz_requirement_message is None + +from pyspark.testing.utils import PySparkErrorTestUtils +from pyspark.testing.sqlutils import pandas_requirement_message, pyarrow_requirement_message + + +connect_requirement_message = ( + pandas_requirement_message + or pyarrow_requirement_message + or grpc_requirement_message + or googleapis_common_protos_requirement_message + or grpc_status_requirement_message +) +should_test_connect: str = typing.cast(str, connect_requirement_message is None) + __all__ = ["assertDataFrameEqual", "assertSchemaEqual"] diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 2f18cd8a6ccdc..7dea8a2103c3d 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -23,35 +23,18 @@ import uuid import contextlib -grpc_requirement_message = None -try: - import grpc -except ImportError as e: - grpc_requirement_message = str(e) -have_grpc = grpc_requirement_message is None - - -grpc_status_requirement_message = None -try: - import grpc_status -except ImportError as e: - grpc_status_requirement_message = str(e) -have_grpc_status = grpc_status_requirement_message is None - -googleapis_common_protos_requirement_message = None -try: - from google.rpc import error_details_pb2 -except ImportError as e: - googleapis_common_protos_requirement_message = str(e) -have_googleapis_common_protos = googleapis_common_protos_requirement_message is None - -graphviz_requirement_message = None -try: - import graphviz -except ImportError as e: - graphviz_requirement_message = str(e) -have_graphviz: bool = graphviz_requirement_message is None - +from pyspark.testing import ( + grpc_requirement_message, + have_grpc, + grpc_status_requirement_message, + have_grpc_status, + googleapis_common_protos_requirement_message, + have_googleapis_common_protos, + graphviz_requirement_message, + have_graphviz, + connect_requirement_message, + should_test_connect, +) from pyspark import Row, SparkConf from pyspark.util import is_remote_only from pyspark.testing.utils import PySparkErrorTestUtils @@ -64,15 +47,6 @@ from pyspark.sql.session import SparkSession as PySparkSession -connect_requirement_message = ( - pandas_requirement_message - or pyarrow_requirement_message - or grpc_requirement_message - or googleapis_common_protos_requirement_message - or grpc_status_requirement_message -) -should_test_connect: str = typing.cast(str, connect_requirement_message is None) - if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan From 76ea894bc69f07c3a971e3e2243306ebb8b02b75 Mon Sep 17 
00:00:00 2001 From: Haejoon Lee Date: Sun, 20 Oct 2024 21:50:26 +0900 Subject: [PATCH 069/108] [SPARK-50039][CONNECT][PYTHON] API compatibility check for Grouping ### What changes were proposed in this pull request? This PR proposes to add API compatibility check for Spark SQL Grouping functions ### Why are the changes needed? To guarantee of the same behavior between Spark Classic and Spark Connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48560 from itholic/compat_grouping. Authored-by: Haejoon Lee Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/group.py | 10 +++++----- python/pyspark/sql/group.py | 18 +++++++++--------- python/pyspark/sql/pandas/group_ops.py | 12 ++++++------ .../sql/tests/test_connect_compatibility.py | 18 ++++++++++++++++++ 4 files changed, 38 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index 46f13f893c7fa..863461da10ec9 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -193,29 +193,29 @@ def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame": session=self._df._session, ) - def min(self, *cols: str) -> "DataFrame": + def min(self: "GroupedData", *cols: str) -> "DataFrame": return self._numeric_agg("min", list(cols)) min.__doc__ = PySparkGroupedData.min.__doc__ - def max(self, *cols: str) -> "DataFrame": + def max(self: "GroupedData", *cols: str) -> "DataFrame": return self._numeric_agg("max", list(cols)) max.__doc__ = PySparkGroupedData.max.__doc__ - def sum(self, *cols: str) -> "DataFrame": + def sum(self: "GroupedData", *cols: str) -> "DataFrame": return self._numeric_agg("sum", list(cols)) sum.__doc__ = PySparkGroupedData.sum.__doc__ - def avg(self, *cols: str) -> "DataFrame": + def avg(self: "GroupedData", *cols: str) -> "DataFrame": return self._numeric_agg("avg", list(cols)) avg.__doc__ = PySparkGroupedData.avg.__doc__ mean = avg - def count(self) -> "DataFrame": + def count(self: "GroupedData") -> "DataFrame": return self.agg(F._invoke_function("count", F.lit(1)).alias("count")) count.__doc__ = PySparkGroupedData.count.__doc__ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ac4ac02a36b16..94b4b64a0b6f0 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -32,7 +32,7 @@ def dfapi(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]: - def _api(self: "GroupedData") -> DataFrame: + def _api(self: "GroupedData") -> "DataFrame": name = f.__name__ jdf = getattr(self._jgd, name)() return DataFrame(jdf, self.session) @@ -43,7 +43,7 @@ def _api(self: "GroupedData") -> DataFrame: def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]: - def _api(self: "GroupedData", *cols: str) -> DataFrame: + def _api(self: "GroupedData", *cols: str) -> "DataFrame": from pyspark.sql.classic.column import _to_seq name = f.__name__ @@ -80,14 +80,14 @@ def __repr__(self) -> str: return super().__repr__() @overload - def agg(self, *exprs: Column) -> DataFrame: + def agg(self, *exprs: Column) -> "DataFrame": ... @overload - def agg(self, __exprs: Dict[str, str]) -> DataFrame: + def agg(self, __exprs: Dict[str, str]) -> "DataFrame": ... 
- def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: + def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame": """Compute aggregates and returns the result as a :class:`DataFrame`. The available aggregate functions can be: @@ -190,7 +190,7 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: return DataFrame(jdf, self.session) @dfapi - def count(self) -> DataFrame: # type: ignore[empty-body] + def count(self) -> "DataFrame": # type: ignore[empty-body] """Counts the number of records for each group. .. versionadded:: 1.3.0 @@ -241,7 +241,7 @@ def mean(self, *cols: str) -> DataFrame: # type: ignore[empty-body] """ @df_varargs_api - def avg(self, *cols: str) -> DataFrame: # type: ignore[empty-body] + def avg(self, *cols: str) -> "DataFrame": # type: ignore[empty-body] """Computes average values for each numeric columns for each group. :func:`mean` is an alias for :func:`avg`. @@ -292,7 +292,7 @@ def avg(self, *cols: str) -> DataFrame: # type: ignore[empty-body] """ @df_varargs_api - def max(self, *cols: str) -> DataFrame: # type: ignore[empty-body] + def max(self, *cols: str) -> "DataFrame": # type: ignore[empty-body] """Computes the max value for each numeric columns for each group. .. versionadded:: 1.3.0 @@ -336,7 +336,7 @@ def max(self, *cols: str) -> DataFrame: # type: ignore[empty-body] """ @df_varargs_api - def min(self, *cols: str) -> DataFrame: # type: ignore[empty-body] + def min(self, *cols: str) -> "DataFrame": # type: ignore[empty-body] """Computes the min value for each numeric column for each group. .. versionadded:: 1.3.0 diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 3173534c03c91..0d21edc73b81a 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -49,7 +49,7 @@ class PandasGroupedOpsMixin: can use this class. """ - def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> DataFrame: + def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame": """ It is an alias of :meth:`pyspark.sql.GroupedData.applyInPandas`; however, it takes a :meth:`pyspark.sql.functions.pandas_udf` whereas @@ -121,8 +121,8 @@ def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> DataFrame: return self.applyInPandas(udf.func, schema=udf.returnType) # type: ignore[attr-defined] def applyInPandas( - self, func: "PandasGroupedMapFunction", schema: Union[StructType, str] - ) -> DataFrame: + self, func: "PandasGroupedMapFunction", schema: Union["StructType", str] + ) -> "DataFrame": """ Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result as a `DataFrame`. @@ -246,7 +246,7 @@ def applyInPandasWithState( stateStructType: Union[StructType, str], outputMode: str, timeoutConf: str, - ) -> DataFrame: + ) -> "DataFrame": """ Applies the given function to each group of data, while maintaining a user-defined per-group state. The result Dataset will represent the flattened record returned by the @@ -684,8 +684,8 @@ def __init__(self, gd1: "GroupedData", gd2: "GroupedData"): self._gd2 = gd2 def applyInPandas( - self, func: "PandasCogroupedMapFunction", schema: Union[StructType, str] - ) -> DataFrame: + self, func: "PandasCogroupedMapFunction", schema: Union["StructType", str] + ) -> "DataFrame": """ Applies a function to each cogroup using pandas and returns the result as a `DataFrame`. 
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index 3ebb6b7aea7d0..72f139ac4768c 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -31,6 +31,7 @@ from pyspark.sql.window import Window as ClassicWindow from pyspark.sql.window import WindowSpec as ClassicWindowSpec import pyspark.sql.functions as ClassicFunctions +from pyspark.sql.group import GroupedData as ClassicGroupedData if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame @@ -43,6 +44,7 @@ from pyspark.sql.connect.window import Window as ConnectWindow from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec import pyspark.sql.connect.functions as ConnectFunctions + from pyspark.sql.connect.group import GroupedData as ConnectGroupedData class ConnectCompatibilityTestsMixin: @@ -357,6 +359,22 @@ def test_functions_compatibility(self): expected_missing_classic_methods, ) + def test_grouping_compatibility(self): + """Test Grouping compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = {"transformWithStateInPandas"} + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicGroupedData, + ConnectGroupedData, + "Grouping", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): From f2f309970ff58b90a814eaa5735b3dbc8449f412 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Sun, 20 Oct 2024 18:55:44 +0200 Subject: [PATCH 070/108] [SPARK-50038][SQL] Assign appropriate error condition for `_LEGACY_ERROR_TEMP_0008`: `MERGE_WITHOUT_WHEN` ### What changes were proposed in this pull request? This PR proposes to assign proper error condition & sqlstate for `_LEGACY_ERROR_TEMP_0008`: `MERGE_WITHOUT_WHEN` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48559 from itholic/SPARK-50038. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 11 ++++++----- .../apache/spark/sql/errors/QueryParsingErrors.scala | 2 +- .../spark/sql/catalyst/parser/DDLParserSuite.scala | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index fb1439cfe1a5e..28b2bc9a857db 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3422,6 +3422,12 @@ ], "sqlState" : "23K01" }, + "MERGE_WITHOUT_WHEN" : { + "message" : [ + "There must be at least one WHEN clause in a MERGE statement." 
+ ], + "sqlState" : "42601" + }, "MISSING_AGGREGATION" : { "message" : [ "The non-aggregating expression is based on columns which are not participating in the GROUP BY clause.", @@ -5661,11 +5667,6 @@ "The number of inserted values cannot match the fields." ] }, - "_LEGACY_ERROR_TEMP_0008" : { - "message" : [ - "There must be at least one WHEN clause in a MERGE statement." - ] - }, "_LEGACY_ERROR_TEMP_0012" : { "message" : [ "DISTRIBUTE BY is not supported." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 0272d06ee1261..199a1ed868d6e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -60,7 +60,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def mergeStatementWithoutWhenClauseError(ctx: MergeIntoTableContext): Throwable = { - new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0008", ctx) + new ParseException(errorClass = "MERGE_WITHOUT_WHEN", ctx) } def nonLastMatchedClauseOmitConditionError(ctx: MergeIntoTableContext): Throwable = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 926beacc592a5..5e871208698af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -2120,7 +2120,7 @@ class DDLParserSuite extends AnalysisTest { |ON target.col1 = source.col1""".stripMargin checkError( exception = parseException(sql), - condition = "_LEGACY_ERROR_TEMP_0008", + condition = "MERGE_WITHOUT_WHEN", parameters = Map.empty, context = ExpectedContext( fragment = sql, From 450891109f448a7f05390e6fd38abd5abff61cc9 Mon Sep 17 00:00:00 2001 From: "zhipeng.mao" Date: Sun, 20 Oct 2024 21:17:28 +0200 Subject: [PATCH 071/108] [SPARK-50027][SQL] Move Identity Column SQL parsing code to DataTypeAstBuilder ### What changes were proposed in this pull request? It moves code parsing Identity Column DDL from AstBuilder to DataTypeAstBuilder. ### Why are the changes needed? `DataTypeAstBuilder` is intended for parsing code of column definitions. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existent tests cover this. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48543 from zhipengmao-db/zhipengmao-db/id-column-refactor. 
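For context, the moved `visitIdentityColumn`/`visitIdentityColSpec` logic parses column definitions of the following shape; the DDL below is illustrative only (the `START WITH`/`INCREMENT BY` spelling follows the grammar as I read it, and the target catalog must support identity columns):

```py
# Hypothetical example; requires a catalog/table provider with identity-column support.
spark.sql("""
    CREATE TABLE t (
        id BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 1),
        data STRING
    )
""")
```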
Authored-by: zhipeng.mao Signed-off-by: Max Gekk --- .../catalyst/parser/DataTypeAstBuilder.scala | 57 +++++++++++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 54 +----------------- 2 files changed, 58 insertions(+), 53 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 46fb4a3397c59..71e8517a4164e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -23,9 +23,11 @@ import scala.jdk.CollectionConverters._ import org.antlr.v4.runtime.Token import org.antlr.v4.runtime.tree.ParseTree +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} +import org.apache.spark.sql.connector.catalog.IdentityColumnSpec import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} @@ -220,4 +222,59 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) { ctx.identifier.getText } + + /** + * Parse and verify IDENTITY column definition. + * + * @param ctx + * The parser context. + * @param dataType + * The data type of column defined as IDENTITY column. Used for verification. + * @return + * Tuple containing start, step and allowExplicitInsert. + */ + protected def visitIdentityColumn( + ctx: IdentityColumnContext, + dataType: DataType): IdentityColumnSpec = { + if (dataType != LongType && dataType != IntegerType) { + throw QueryParsingErrors.identityColumnUnsupportedDataType(ctx, dataType.toString) + } + // We support two flavors of syntax: + // (1) GENERATED ALWAYS AS IDENTITY (...) + // (2) GENERATED BY DEFAULT AS IDENTITY (...) + // (1) forbids explicit inserts, while (2) allows. 
+ val allowExplicitInsert = ctx.BY() != null && ctx.DEFAULT() != null + val (start, step) = visitIdentityColSpec(ctx.identityColSpec()) + + new IdentityColumnSpec(start, step, allowExplicitInsert) + } + + override def visitIdentityColSpec(ctx: IdentityColSpecContext): (Long, Long) = { + val defaultStart = 1 + val defaultStep = 1 + if (ctx == null) { + return (defaultStart, defaultStep) + } + var (start, step): (Option[Long], Option[Long]) = (None, None) + ctx.sequenceGeneratorOption().asScala.foreach { option => + if (option.start != null) { + if (start.isDefined) { + throw QueryParsingErrors.identityColumnDuplicatedSequenceGeneratorOption(ctx, "START") + } + start = Some(option.start.getText.toLong) + } else if (option.step != null) { + if (step.isDefined) { + throw QueryParsingErrors.identityColumnDuplicatedSequenceGeneratorOption(ctx, "STEP") + } + step = Some(option.step.getText.toLong) + if (step.get == 0L) { + throw QueryParsingErrors.identityColumnIllegalStep(ctx) + } + } else { + throw SparkException + .internalError(s"Invalid identity column sequence generator option: ${option.getText}") + } + } + (start.getOrElse(defaultStart), step.getOrElse(defaultStep)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 25dd423791005..dde2ec6f5ae64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, IdentityColumnSpec, SupportsNamespaces, TableCatalog, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors} @@ -3734,58 +3734,6 @@ class AstBuilder extends DataTypeAstBuilder getDefaultExpression(ctx.expression(), "GENERATED").originalSQL } - /** - * Parse and verify IDENTITY column definition. - * - * @param ctx The parser context. - * @param dataType The data type of column defined as IDENTITY column. Used for verification. - * @return Tuple containing start, step and allowExplicitInsert. - */ - protected def visitIdentityColumn( - ctx: IdentityColumnContext, - dataType: DataType): IdentityColumnSpec = { - if (dataType != LongType && dataType != IntegerType) { - throw QueryParsingErrors.identityColumnUnsupportedDataType(ctx, dataType.toString) - } - // We support two flavors of syntax: - // (1) GENERATED ALWAYS AS IDENTITY (...) - // (2) GENERATED BY DEFAULT AS IDENTITY (...) - // (1) forbids explicit inserts, while (2) allows. 
- val allowExplicitInsert = ctx.BY() != null && ctx.DEFAULT() != null - val (start, step) = visitIdentityColSpec(ctx.identityColSpec()) - - new IdentityColumnSpec(start, step, allowExplicitInsert) - } - - override def visitIdentityColSpec(ctx: IdentityColSpecContext): (Long, Long) = { - val defaultStart = 1 - val defaultStep = 1 - if (ctx == null) { - return (defaultStart, defaultStep) - } - var (start, step): (Option[Long], Option[Long]) = (None, None) - ctx.sequenceGeneratorOption().asScala.foreach { option => - if (option.start != null) { - if (start.isDefined) { - throw QueryParsingErrors.identityColumnDuplicatedSequenceGeneratorOption(ctx, "START") - } - start = Some(option.start.getText.toLong) - } else if (option.step != null) { - if (step.isDefined) { - throw QueryParsingErrors.identityColumnDuplicatedSequenceGeneratorOption(ctx, "STEP") - } - step = Some(option.step.getText.toLong) - if (step.get == 0L) { - throw QueryParsingErrors.identityColumnIllegalStep(ctx) - } - } else { - throw SparkException - .internalError(s"Invalid identity column sequence generator option: ${option.getText}") - } - } - (start.getOrElse(defaultStart), step.getOrElse(defaultStep)) - } - /** * Create an optional comment string. */ From 32cc2ddc8e9aa9eed59e15b58f2aff71bffec229 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sun, 20 Oct 2024 21:19:43 +0200 Subject: [PATCH 072/108] [SPARK-50044][PYTHON] Refine the docstring of multiple math functions ### What changes were proposed in this pull request? Refine the docstring of multiple math functions ### Why are the changes needed? 1, make them copy-pasteable; 2, show the projection: input -> output ### Does this PR introduce _any_ user-facing change? doc changes ### How was this patch tested? updated doctests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48567 from zhengruifeng/doc_refine_ln. 
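The refreshed examples all follow the same pattern: select the input column alongside the function call so the rendered table reads as an input -> output projection. A copy-pasteable sketch in that style (output tables omitted here; the doctests in the diff below show the exact rendering):

```python
from pyspark.sql import functions as sf

# Show the input next to the result, as in the refreshed docstrings.
spark.range(5).select("*", sf.factorial("id")).show()
spark.range(1, 5).select("*", sf.log2("id")).show()
```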
Authored-by: Ruifeng Zheng Signed-off-by: Max Gekk --- python/pyspark/sql/functions/builtin.py | 131 +++++++++++++++++------- 1 file changed, 94 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index dbc66cab3f9b3..65d8bfde1411f 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -7157,27 +7157,46 @@ def log(arg1: Union["ColumnOrName", float], arg2: Optional["ColumnOrName"] = Non Examples -------- + Example 1: Specify both base number and the input value + >>> from pyspark.sql import functions as sf >>> df = spark.sql("SELECT * FROM VALUES (1), (2), (4) AS t(value)") - >>> df.select(sf.log(2.0, df.value).alias('log2_value')).show() - +----------+ - |log2_value| - +----------+ - | 0.0| - | 1.0| - | 2.0| - +----------+ + >>> df.select("*", sf.log(2.0, df.value)).show() + +-----+---------------+ + |value|LOG(2.0, value)| + +-----+---------------+ + | 1| 0.0| + | 2| 1.0| + | 4| 2.0| + +-----+---------------+ - And Natural logarithm + Example 2: Return NULL for invalid input values - >>> df.select(sf.log(df.value).alias('ln_value')).show() - +------------------+ - | ln_value| - +------------------+ - | 0.0| - |0.6931471805599453| - |1.3862943611198906| - +------------------+ + >>> from pyspark.sql import functions as sf + >>> df = spark.sql("SELECT * FROM VALUES (1), (2), (0), (-1), (NULL) AS t(value)") + >>> df.select("*", sf.log(3.0, df.value)).show() + +-----+------------------+ + |value| LOG(3.0, value)| + +-----+------------------+ + | 1| 0.0| + | 2|0.6309297535714...| + | 0| NULL| + | -1| NULL| + | NULL| NULL| + +-----+------------------+ + + Example 3: Specify only the input value (Natural logarithm) + + >>> from pyspark.sql import functions as sf + >>> df = spark.sql("SELECT * FROM VALUES (1), (2), (4) AS t(value)") + >>> df.select("*", sf.log(df.value)).show() + +-----+------------------+ + |value| ln(value)| + +-----+------------------+ + | 1| 0.0| + | 2|0.6931471805599...| + | 4|1.3862943611198...| + +-----+------------------+ """ from pyspark.sql.classic.column import _to_java_column @@ -7205,13 +7224,22 @@ def ln(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(4,)], ['a']) - >>> df.select(ln('a')).show() - +------------------+ - | ln(a)| - +------------------+ - |1.3862943611198906| - +------------------+ + >>> from pyspark.sql import functions as sf + >>> spark.range(10).select("*", sf.ln('id')).show() + +---+------------------+ + | id| ln(id)| + +---+------------------+ + | 0| NULL| + | 1| 0.0| + | 2|0.6931471805599...| + | 3|1.0986122886681...| + | 4|1.3862943611198...| + | 5|1.6094379124341...| + | 6| 1.791759469228...| + | 7|1.9459101490553...| + | 8|2.0794415416798...| + | 9|2.1972245773362...| + +---+------------------+ """ return _invoke_function_over_columns("ln", col) @@ -7237,13 +7265,22 @@ def log2(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(4,)], ['a']) - >>> df.select(log2('a').alias('log2')).show() - +----+ - |log2| - +----+ - | 2.0| - +----+ + >>> from pyspark.sql import functions as sf + >>> spark.range(10).select("*", sf.log2('id')).show() + +---+------------------+ + | id| LOG2(id)| + +---+------------------+ + | 0| NULL| + | 1| 0.0| + | 2| 1.0| + | 3| 1.584962500721...| + | 4| 2.0| + | 5| 2.321928094887...| + | 6| 2.584962500721...| + | 7| 2.807354922057...| + | 8| 3.0| + | 9|3.1699250014423...| + +---+------------------+ """ return 
_invoke_function_over_columns("log2", col) @@ -7274,9 +7311,16 @@ def conv(col: "ColumnOrName", fromBase: int, toBase: int) -> Column: Examples -------- - >>> df = spark.createDataFrame([("010101",)], ['n']) - >>> df.select(conv(df.n, 2, 16).alias('hex')).collect() - [Row(hex='15')] + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([("010101",), ( "101",), ("001",)], ['n']) + >>> df.select("*", sf.conv(df.n, 2, 16)).show() + +------+--------------+ + | n|conv(n, 2, 16)| + +------+--------------+ + |010101| 15| + | 101| 5| + | 001| 1| + +------+--------------+ """ from pyspark.sql.classic.column import _to_java_column @@ -7307,9 +7351,22 @@ def factorial(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(5,)], ['n']) - >>> df.select(factorial(df.n).alias('f')).collect() - [Row(f=120)] + >>> from pyspark.sql import functions as sf + >>> spark.range(10).select("*", sf.factorial('id')).show() + +---+-------------+ + | id|factorial(id)| + +---+-------------+ + | 0| 1| + | 1| 1| + | 2| 2| + | 3| 6| + | 4| 24| + | 5| 120| + | 6| 720| + | 7| 5040| + | 8| 40320| + | 9| 362880| + +---+-------------+ """ return _invoke_function_over_columns("factorial", col) From d2e322314c786b892f4d8b37f383fae8e8827ca9 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 21 Oct 2024 11:57:30 +0800 Subject: [PATCH 073/108] [SPARK-50001][PYTHON][PS][CONNECT] Adjust "precision" to be part of kwargs for box plots MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Adjust "precision" to be kwargs for box plots in both Pandas on Spark and PySpark. ### Why are the changes needed? Per discussion here (https://github.com/apache/spark/pull/48445#discussion_r1804042377), precision is Spark-specific implementation detail, so we wanted to keep “precision” as part of kwargs for box plots. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48513 from xinrong-meng/precision. Authored-by: Xinrong Meng Signed-off-by: Xinrong Meng --- python/pyspark/pandas/plot/core.py | 15 +++++++-------- python/pyspark/sql/plot/core.py | 13 +++++-------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 12c17a06f153b..f5652177fe4a5 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -841,7 +841,7 @@ def barh(self, x=None, y=None, **kwargs): elif isinstance(self.data, DataFrame): return self(kind="barh", x=x, y=y, **kwargs) - def box(self, precision=0.01, **kwds): + def box(self, **kwds): """ Make a box plot of the DataFrame columns. @@ -857,12 +857,11 @@ def box(self, precision=0.01, **kwds): Parameters ---------- - precision: scalar, default = 0.01 - This argument is used by pandas-on-Spark to compute approximate statistics - for building a boxplot. Use *smaller* values to get more precise - statistics. - **kwds : optional - Additional keyword arguments are documented in + **kwds : dict, optional + Extra arguments to `precision`: refer to a float that is used by + pandas-on-Spark to compute approximate statistics for building a + boxplot. The default value is 0.01. Use smaller values to get more + precise statistics. Additional keyword arguments are documented in :meth:`pyspark.pandas.Series.plot`. 
Returns @@ -901,7 +900,7 @@ def box(self, precision=0.01, **kwds): from pyspark.pandas import DataFrame, Series if isinstance(self.data, (Series, DataFrame)): - return self(kind="box", precision=precision, **kwds) + return self(kind="box", **kwds) def hist(self, bins=10, **kwds): """ diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index f44c0768d4337..178411e5c5ef8 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -359,9 +359,7 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": ) return self(kind="pie", x=x, y=y, **kwargs) - def box( - self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Any - ) -> "Figure": + def box(self, column: Union[str, List[str]], **kwargs: Any) -> "Figure": """ Make a box plot of the DataFrame columns. @@ -377,11 +375,10 @@ def box( ---------- column: str or list of str Column name or list of names to be used for creating the boxplot. - precision: float, default = 0.01 - This argument is used by pyspark to compute approximate statistics - for building a boxplot. **kwargs - Additional keyword arguments. + Extra arguments to `precision`: refer to a float that is used by + pyspark to compute approximate statistics for building a boxplot. + The default value is 0.01. Use smaller values to get more precise statistics. Returns ------- @@ -404,7 +401,7 @@ def box( >>> df.plot.box(column="math_score") # doctest: +SKIP >>> df.plot.box(column=["math_score", "english_score"]) # doctest: +SKIP """ - return self(kind="box", column=column, precision=precision, **kwargs) + return self(kind="box", column=column, **kwargs) def kde( self, From f581bad902586d3a28ac95e06ac7f46a3b527bd7 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Mon, 21 Oct 2024 12:52:36 +0800 Subject: [PATCH 074/108] [SPARK-50037][SQL] Refactor AttributeSeq.resolve(...) ### What changes were proposed in this pull request? Refactor `AttributeSeq.resolve(...)`: - Introduce scaladoc - Refactor into two methods ### Why are the changes needed? To make the development experience nicer. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? copilot.vim. Closes #48556 from vladimirg-db/vladimirg-db/refactoring-in-attribute-seq. Authored-by: Vladimir Golubev Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/package.scala | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 36fde4da2628b..20105b87004f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -331,8 +331,27 @@ package object expressions { (candidates, nestedFields) } - /** Perform attribute resolution given a name and a resolver. */ + /** + * Resolve `nameParts` into a specific [[NamedExpression]] using the provided `resolver`. + * + * This method finds all suitable candidates for the resolution based on the name matches and + * checks if the nested fields are requested. + * - If there's only one match and nested fields are requested, wrap the matched attribute with + * [[ExtractValue]], and recursively wrap that with additional [[ExtractValue]]s + * for each nested field. 
In the end, alias the final expression with the last nested field + * name. + * - If there's only one match and no nested fields are requested, return the matched attribute. + * - If there are no matches, return None. + * - If there is more than one match, throw [[QueryCompilationErrors.ambiguousReferenceError]]. + */ def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + val (candidates, nestedFields) = getCandidatesForResolution(nameParts, resolver) + resolveCandidates(nameParts, resolver, candidates, nestedFields) + } + + def getCandidatesForResolution( + nameParts: Seq[String], + resolver: Resolver): (Seq[Attribute], Seq[String]) = { val (candidates, nestedFields) = if (hasThreeOrLessQualifierParts) { matchWithThreeOrLessQualifierParts(nameParts, resolver) } else { @@ -345,13 +364,21 @@ package object expressions { candidates } + (prunedCandidates, nestedFields) + } + + def resolveCandidates( + nameParts: Seq[String], + resolver: Resolver, + candidates: Seq[Attribute], + nestedFields: Seq[String]): Option[NamedExpression] = { def name = UnresolvedAttribute(nameParts).name // We may have resolved the attributes from metadata columns. The resolved attributes will be // put in a logical plan node and becomes normal attributes. They can still keep the special // attribute metadata to indicate that they are from metadata columns, but they should not // keep any restrictions that may break column resolution for normal attributes. // See SPARK-42084 for more details. - prunedCandidates.distinct.map(_.markAsAllowAnyAccess()) match { + candidates.distinct.map(_.markAsAllowAnyAccess()) match { case Seq(a) if nestedFields.nonEmpty => // One match, but we also need to extract the requested nested field. // The foldLeft adds ExtractValues for every remaining parts of the identifier, From 537ca919c87a492bb69bebb5d3cbcc725f49dfe7 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 21 Oct 2024 14:32:26 +0900 Subject: [PATCH 075/108] [SPARK-49850][CONNECT][PYTHON] API compatibility check for Avro ### What changes were proposed in this pull request? This PR proposes to add API compatibility check for Spark SQL Avro functions ### Why are the changes needed? To guarantee of the same behavior between Spark Classic and Spark Connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48571 from itholic/compat_avro. 
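The public Avro surface covered by this check is just `from_avro` and `to_avro`; the remaining differences are private helpers, as listed in the test below. A hedged usage sketch, assuming a DataFrame `df` with an Avro-encoded binary column `value` and an Avro schema in JSON form in `avro_schema` (both hypothetical here); the point of the parity test is that the same code resolves on Spark Classic and Spark Connect alike:

```python
from pyspark.sql.avro.functions import from_avro, to_avro

# Decode the Avro payload into a struct column, then re-encode it.
decoded = df.select(from_avro("value", avro_schema).alias("event"))
reencoded = decoded.select(to_avro("event").alias("value"))
```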
Authored-by: Haejoon Lee Signed-off-by: Haejoon Lee --- .../sql/tests/test_connect_compatibility.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index 72f139ac4768c..193de8e1b6b6a 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -32,6 +32,7 @@ from pyspark.sql.window import WindowSpec as ClassicWindowSpec import pyspark.sql.functions as ClassicFunctions from pyspark.sql.group import GroupedData as ClassicGroupedData +import pyspark.sql.avro.functions as ClassicAvro if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame @@ -45,6 +46,7 @@ from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec import pyspark.sql.connect.functions as ConnectFunctions from pyspark.sql.connect.group import GroupedData as ConnectGroupedData + import pyspark.sql.connect.avro.functions as ConnectAvro class ConnectCompatibilityTestsMixin: @@ -375,6 +377,28 @@ def test_grouping_compatibility(self): expected_missing_classic_methods, ) + def test_avro_compatibility(self): + """Test Avro compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + # The current supported Avro functions are only `from_avro` and `to_avro`. + # The missing methods belows are just util functions that imported to implement them. + expected_missing_connect_methods = { + "try_remote_avro_functions", + "cast", + "get_active_spark_context", + } + expected_missing_classic_methods = {"lit", "check_dependencies"} + self.check_compatibility( + ClassicAvro, + ConnectAvro, + "Avro", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): From c1198fa2b45c7314712aa2601349e6666c37f71f Mon Sep 17 00:00:00 2001 From: bogao007 Date: Mon, 21 Oct 2024 15:57:54 +0900 Subject: [PATCH 076/108] [SPARK-49821][SS][PYTHON] Implement MapState and TTL support for TransformWithStateInPandas ### What changes were proposed in this pull request? - Implement MapState and TTL support for TransformWithStateInPandas - Fixed an issue to properly closes/cleans up resources after arrow batch writes are completed in `TransformWithStateInPandasStateServer`. Since we use the same arrow batch write logic for both listState and mapState, this fix also applies to listState. ### Why are the changes needed? Bring parity to Scala on supported state variables ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? Added new unit test. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48290 from bogao007/map-state. 
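The new Python `MapState` mirrors the Scala state variable: the handle method added in `stateful_processor.py` exposes `getMapState(state_name, user_key_schema, value_schema, ttl_duration_ms)`, and the returned state supports `exists`, `get_value`, `contains_key`, `update_value`, `iterator`, `keys`, `values`, `remove_key`, and `clear`, with user keys and values passed as plain tuples. A hedged sketch of a processor using it (this is not the test added in this patch; schemas, field names, and the output layout are illustrative):

```python
from typing import Iterator

import pandas as pd

from pyspark.sql.streaming.stateful_processor import (
    StatefulProcessor,
    StatefulProcessorHandle,
)


class CountPerNameProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Optional TTL in milliseconds, as introduced by this patch.
        self.counts = handle.getMapState("counts", "name string", "count long", 10000)

    def handleInputRows(self, key, rows, timer_values=None) -> Iterator[pd.DataFrame]:
        # `rows` is an iterator of pandas DataFrames for the current grouping key;
        # timer support is orthogonal to this sketch, hence the defaulted parameter.
        for pdf in rows:
            for name in pdf["name"]:
                old = self.counts.get_value((name,))
                new_count = old[0] + 1 if old is not None else 1
                self.counts.update_value((name,), (new_count,))
        total = sum(v[0] for v in self.counts.values())
        yield pd.DataFrame({"id": [key[0]], "total": [total]})

    def close(self) -> None:
        pass
```

As with the list state client touched in this patch, map iterators (`keys`, `values`, `iterator`) are backed by Arrow batches fetched from the JVM state server, which is why the client drains each batch eagerly before handing rows back to user code.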
Authored-by: bogao007 Signed-off-by: Jungtaek Lim --- .../pyspark/sql/streaming/StateMessage_pb2.py | 89 +- .../sql/streaming/StateMessage_pb2.pyi | 102 +- .../sql/streaming/list_state_client.py | 15 +- .../pyspark/sql/streaming/map_state_client.py | 309 + .../sql/streaming/stateful_processor.py | 112 + .../stateful_processor_api_client.py | 30 + .../test_pandas_transform_with_state.py | 95 +- .../execution/streaming/StateMessage.proto | 48 +- .../streaming/state/StateMessage.java | 8359 +++++++++++++++-- ...ransformWithStateInPandasStateServer.scala | 281 +- ...ormWithStateInPandasStateServerSuite.scala | 214 +- 11 files changed, 8751 insertions(+), 903 deletions(-) create mode 100644 python/pyspark/sql/streaming/map_state_client.py diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.py b/python/pyspark/sql/streaming/StateMessage_pb2.py index e75d0394ea0f5..9c7740c0a223a 100644 --- a/python/pyspark/sql/streaming/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/StateMessage_pb2.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE @@ -31,7 +30,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"\xd2\x01\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x12V\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 
\x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x90\x04\n\rListStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12T\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00\x12T\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00\x12R\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00\x12P\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00\x12\x46\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear""\n\x0cListStateGet\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x0e\n\x0cListStatePut"\x1c\n\x0b\x41ppendValue\x12\r\n\x05value\x18\x01 \x01(\x0c"\x0c\n\nAppendList"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 + b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"\xa8\x02\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x12V\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00\x12T\n\x0cmapStateCall\x18\x03 
\x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"\x9a\x01\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x1b\n\x13mapStateValueSchema\x18\x03 \x01(\t\x12\x46\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x90\x04\n\rListStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12T\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00\x12T\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00\x12R\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00\x12P\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00\x12\x46\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\xe1\x05\n\x0cMapStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12L\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00\x12R\n\x0b\x63ontainsKey\x18\x04 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00\x12R\n\x0bupdateValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00\x12L\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00\x12\x44\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00\x12H\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00\x12N\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00\x12\x46\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear""\n\x0cListStateGet\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x0e\n\x0cListStatePut"\x1c\n\x0b\x41ppendValue\x12\r\n\x05value\x18\x01 \x01(\x0c"\x0c\n\nAppendList"\x1b\n\x08GetValue\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"\x1e\n\x0b\x43ontainsKey\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"-\n\x0bUpdateValue\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 
\x01(\x0c"\x1e\n\x08Iterator\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1a\n\x04Keys\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1c\n\x06Values\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1c\n\tRemoveKey\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 ) _globals = globals() @@ -39,8 +38,8 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_HANDLESTATE"]._serialized_start = 2694 - _globals["_HANDLESTATE"]._serialized_end = 2769 + _globals["_HANDLESTATE"]._serialized_start = 3778 + _globals["_HANDLESTATE"]._serialized_end = 3853 _globals["_STATEREQUEST"]._serialized_start = 71 _globals["_STATEREQUEST"]._serialized_end = 432 _globals["_STATERESPONSE"]._serialized_start = 434 @@ -48,37 +47,53 @@ _globals["_STATEFULPROCESSORCALL"]._serialized_start = 509 _globals["_STATEFULPROCESSORCALL"]._serialized_end = 902 _globals["_STATEVARIABLEREQUEST"]._serialized_start = 905 - _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1115 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1118 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1342 - _globals["_STATECALLCOMMAND"]._serialized_start = 1344 - _globals["_STATECALLCOMMAND"]._serialized_end = 1469 - _globals["_VALUESTATECALL"]._serialized_start = 1472 - _globals["_VALUESTATECALL"]._serialized_end = 1825 - _globals["_LISTSTATECALL"]._serialized_start = 1828 - _globals["_LISTSTATECALL"]._serialized_end = 2356 - _globals["_SETIMPLICITKEY"]._serialized_start = 2358 - _globals["_SETIMPLICITKEY"]._serialized_end = 2387 - _globals["_REMOVEIMPLICITKEY"]._serialized_start = 2389 - _globals["_REMOVEIMPLICITKEY"]._serialized_end = 2408 - _globals["_EXISTS"]._serialized_start = 2410 - _globals["_EXISTS"]._serialized_end = 2418 - _globals["_GET"]._serialized_start = 2420 - _globals["_GET"]._serialized_end = 2425 - _globals["_VALUESTATEUPDATE"]._serialized_start = 2427 - _globals["_VALUESTATEUPDATE"]._serialized_end = 2460 - _globals["_CLEAR"]._serialized_start = 2462 - _globals["_CLEAR"]._serialized_end = 2469 - _globals["_LISTSTATEGET"]._serialized_start = 2471 - _globals["_LISTSTATEGET"]._serialized_end = 2505 - _globals["_LISTSTATEPUT"]._serialized_start = 2507 - _globals["_LISTSTATEPUT"]._serialized_end = 2521 - _globals["_APPENDVALUE"]._serialized_start = 2523 - _globals["_APPENDVALUE"]._serialized_end = 2551 - _globals["_APPENDLIST"]._serialized_start = 2553 - _globals["_APPENDLIST"]._serialized_end = 2565 - _globals["_SETHANDLESTATE"]._serialized_start = 2567 - _globals["_SETHANDLESTATE"]._serialized_end = 2659 - _globals["_TTLCONFIG"]._serialized_start = 2661 - _globals["_TTLCONFIG"]._serialized_end = 2692 + _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1201 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1204 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1428 + _globals["_STATECALLCOMMAND"]._serialized_start = 1431 + _globals["_STATECALLCOMMAND"]._serialized_end = 1585 + _globals["_VALUESTATECALL"]._serialized_start = 1588 + _globals["_VALUESTATECALL"]._serialized_end = 1941 + _globals["_LISTSTATECALL"]._serialized_start = 1944 + 
_globals["_LISTSTATECALL"]._serialized_end = 2472 + _globals["_MAPSTATECALL"]._serialized_start = 2475 + _globals["_MAPSTATECALL"]._serialized_end = 3212 + _globals["_SETIMPLICITKEY"]._serialized_start = 3214 + _globals["_SETIMPLICITKEY"]._serialized_end = 3243 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 3245 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 3264 + _globals["_EXISTS"]._serialized_start = 3266 + _globals["_EXISTS"]._serialized_end = 3274 + _globals["_GET"]._serialized_start = 3276 + _globals["_GET"]._serialized_end = 3281 + _globals["_VALUESTATEUPDATE"]._serialized_start = 3283 + _globals["_VALUESTATEUPDATE"]._serialized_end = 3316 + _globals["_CLEAR"]._serialized_start = 3318 + _globals["_CLEAR"]._serialized_end = 3325 + _globals["_LISTSTATEGET"]._serialized_start = 3327 + _globals["_LISTSTATEGET"]._serialized_end = 3361 + _globals["_LISTSTATEPUT"]._serialized_start = 3363 + _globals["_LISTSTATEPUT"]._serialized_end = 3377 + _globals["_APPENDVALUE"]._serialized_start = 3379 + _globals["_APPENDVALUE"]._serialized_end = 3407 + _globals["_APPENDLIST"]._serialized_start = 3409 + _globals["_APPENDLIST"]._serialized_end = 3421 + _globals["_GETVALUE"]._serialized_start = 3423 + _globals["_GETVALUE"]._serialized_end = 3450 + _globals["_CONTAINSKEY"]._serialized_start = 3452 + _globals["_CONTAINSKEY"]._serialized_end = 3482 + _globals["_UPDATEVALUE"]._serialized_start = 3484 + _globals["_UPDATEVALUE"]._serialized_end = 3529 + _globals["_ITERATOR"]._serialized_start = 3531 + _globals["_ITERATOR"]._serialized_end = 3561 + _globals["_KEYS"]._serialized_start = 3563 + _globals["_KEYS"]._serialized_end = 3589 + _globals["_VALUES"]._serialized_start = 3591 + _globals["_VALUES"]._serialized_end = 3619 + _globals["_REMOVEKEY"]._serialized_start = 3621 + _globals["_REMOVEKEY"]._serialized_end = 3649 + _globals["_SETHANDLESTATE"]._serialized_start = 3651 + _globals["_SETHANDLESTATE"]._serialized_end = 3743 + _globals["_TTLCONFIG"]._serialized_start = 3745 + _globals["_TTLCONFIG"]._serialized_end = 3776 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/StateMessage_pb2.pyi index b1f5f0f7d2a1e..791a221d96f35 100644 --- a/python/pyspark/sql/streaming/StateMessage_pb2.pyi +++ b/python/pyspark/sql/streaming/StateMessage_pb2.pyi @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message @@ -95,15 +94,18 @@ class StatefulProcessorCall(_message.Message): ) -> None: ... class StateVariableRequest(_message.Message): - __slots__ = ("valueStateCall", "listStateCall") + __slots__ = ("valueStateCall", "listStateCall", "mapStateCall") VALUESTATECALL_FIELD_NUMBER: _ClassVar[int] LISTSTATECALL_FIELD_NUMBER: _ClassVar[int] + MAPSTATECALL_FIELD_NUMBER: _ClassVar[int] valueStateCall: ValueStateCall listStateCall: ListStateCall + mapStateCall: MapStateCall def __init__( self, valueStateCall: _Optional[_Union[ValueStateCall, _Mapping]] = ..., listStateCall: _Optional[_Union[ListStateCall, _Mapping]] = ..., + mapStateCall: _Optional[_Union[MapStateCall, _Mapping]] = ..., ) -> None: ... 
class ImplicitGroupingKeyRequest(_message.Message): @@ -119,17 +121,20 @@ class ImplicitGroupingKeyRequest(_message.Message): ) -> None: ... class StateCallCommand(_message.Message): - __slots__ = ("stateName", "schema", "ttl") + __slots__ = ("stateName", "schema", "mapStateValueSchema", "ttl") STATENAME_FIELD_NUMBER: _ClassVar[int] SCHEMA_FIELD_NUMBER: _ClassVar[int] + MAPSTATEVALUESCHEMA_FIELD_NUMBER: _ClassVar[int] TTL_FIELD_NUMBER: _ClassVar[int] stateName: str schema: str + mapStateValueSchema: str ttl: TTLConfig def __init__( self, stateName: _Optional[str] = ..., schema: _Optional[str] = ..., + mapStateValueSchema: _Optional[str] = ..., ttl: _Optional[_Union[TTLConfig, _Mapping]] = ..., ) -> None: ... @@ -189,6 +194,53 @@ class ListStateCall(_message.Message): clear: _Optional[_Union[Clear, _Mapping]] = ..., ) -> None: ... +class MapStateCall(_message.Message): + __slots__ = ( + "stateName", + "exists", + "getValue", + "containsKey", + "updateValue", + "iterator", + "keys", + "values", + "removeKey", + "clear", + ) + STATENAME_FIELD_NUMBER: _ClassVar[int] + EXISTS_FIELD_NUMBER: _ClassVar[int] + GETVALUE_FIELD_NUMBER: _ClassVar[int] + CONTAINSKEY_FIELD_NUMBER: _ClassVar[int] + UPDATEVALUE_FIELD_NUMBER: _ClassVar[int] + ITERATOR_FIELD_NUMBER: _ClassVar[int] + KEYS_FIELD_NUMBER: _ClassVar[int] + VALUES_FIELD_NUMBER: _ClassVar[int] + REMOVEKEY_FIELD_NUMBER: _ClassVar[int] + CLEAR_FIELD_NUMBER: _ClassVar[int] + stateName: str + exists: Exists + getValue: GetValue + containsKey: ContainsKey + updateValue: UpdateValue + iterator: Iterator + keys: Keys + values: Values + removeKey: RemoveKey + clear: Clear + def __init__( + self, + stateName: _Optional[str] = ..., + exists: _Optional[_Union[Exists, _Mapping]] = ..., + getValue: _Optional[_Union[GetValue, _Mapping]] = ..., + containsKey: _Optional[_Union[ContainsKey, _Mapping]] = ..., + updateValue: _Optional[_Union[UpdateValue, _Mapping]] = ..., + iterator: _Optional[_Union[Iterator, _Mapping]] = ..., + keys: _Optional[_Union[Keys, _Mapping]] = ..., + values: _Optional[_Union[Values, _Mapping]] = ..., + removeKey: _Optional[_Union[RemoveKey, _Mapping]] = ..., + clear: _Optional[_Union[Clear, _Mapping]] = ..., + ) -> None: ... + class SetImplicitKey(_message.Message): __slots__ = ("key",) KEY_FIELD_NUMBER: _ClassVar[int] @@ -237,6 +289,50 @@ class AppendList(_message.Message): __slots__ = () def __init__(self) -> None: ... +class GetValue(_message.Message): + __slots__ = ("userKey",) + USERKEY_FIELD_NUMBER: _ClassVar[int] + userKey: bytes + def __init__(self, userKey: _Optional[bytes] = ...) -> None: ... + +class ContainsKey(_message.Message): + __slots__ = ("userKey",) + USERKEY_FIELD_NUMBER: _ClassVar[int] + userKey: bytes + def __init__(self, userKey: _Optional[bytes] = ...) -> None: ... + +class UpdateValue(_message.Message): + __slots__ = ("userKey", "value") + USERKEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + userKey: bytes + value: bytes + def __init__(self, userKey: _Optional[bytes] = ..., value: _Optional[bytes] = ...) -> None: ... + +class Iterator(_message.Message): + __slots__ = ("iteratorId",) + ITERATORID_FIELD_NUMBER: _ClassVar[int] + iteratorId: str + def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... + +class Keys(_message.Message): + __slots__ = ("iteratorId",) + ITERATORID_FIELD_NUMBER: _ClassVar[int] + iteratorId: str + def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... 
+ +class Values(_message.Message): + __slots__ = ("iteratorId",) + ITERATORID_FIELD_NUMBER: _ClassVar[int] + iteratorId: str + def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... + +class RemoveKey(_message.Message): + __slots__ = ("userKey",) + USERKEY_FIELD_NUMBER: _ClassVar[int] + userKey: bytes + def __init__(self, userKey: _Optional[bytes] = ...) -> None: ... + class SetHandleState(_message.Message): __slots__ = ("state",) STATE_FIELD_NUMBER: _ClassVar[int] diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py index 93306eca425eb..3615f6819bdd2 100644 --- a/python/pyspark/sql/streaming/list_state_client.py +++ b/python/pyspark/sql/streaming/list_state_client.py @@ -78,8 +78,19 @@ def get(self, state_name: str, iterator_id: str) -> Tuple: status = response_message[0] if status == 0: iterator = self._stateful_processor_api_client._read_arrow_state() - batch = next(iterator) - pandas_df = batch.to_pandas() + # We need to exhaust the iterator here to make sure all the arrow batches are read, + # even though there is only one batch in the iterator. Otherwise, the stream might + # block further reads since it thinks there might still be some arrow batches left. + # We only need to read the first batch in the iterator because it's guaranteed that + # there would only be one batch sent from the JVM side. + data_batch = None + for batch in iterator: + if data_batch is None: + data_batch = batch + if data_batch is None: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError("Error getting next list state row.") + pandas_df = data_batch.to_pandas() index = 0 else: raise StopIteration() diff --git a/python/pyspark/sql/streaming/map_state_client.py b/python/pyspark/sql/streaming/map_state_client.py new file mode 100644 index 0000000000000..54a5ba1bbffa0 --- /dev/null +++ b/python/pyspark/sql/streaming/map_state_client.py @@ -0,0 +1,309 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from typing import Dict, Iterator, Union, cast, Tuple, Optional + +from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string +from pyspark.errors import PySparkRuntimeError +import uuid + +if TYPE_CHECKING: + from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike + +__all__ = ["MapStateClient"] + + +class MapStateClient: + def __init__( + self, + stateful_processor_api_client: StatefulProcessorApiClient, + user_key_schema: Union[StructType, str], + value_schema: Union[StructType, str], + ) -> None: + self._stateful_processor_api_client = stateful_processor_api_client + if isinstance(user_key_schema, str): + self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema)) + else: + self.user_key_schema = user_key_schema + if isinstance(value_schema, str): + self.value_schema = cast(StructType, _parse_datatype_string(value_schema)) + else: + self.value_schema = value_schema + # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame + # and the index of the last row that was read. + self.user_key_value_pair_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {} + self.user_key_or_value_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {} + + def exists(self, state_name: str) -> bool: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + exists_call = stateMessage.Exists() + map_state_call = stateMessage.MapStateCall(stateName=state_name, exists=exists_call) + state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + return True + elif status == 2: + # Expect status code is 2 when state variable doesn't have a value. + return False + else: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error checking map state exists: {response_message[1]}") + + def get_value(self, state_name: str, user_key: Tuple) -> Optional[Tuple]: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + bytes = self._stateful_processor_api_client._serialize_to_bytes( + self.user_key_schema, user_key + ) + get_value_call = stateMessage.GetValue(userKey=bytes) + map_state_call = stateMessage.MapStateCall(stateName=state_name, getValue=get_value_call) + state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + if len(response_message[2]) == 0: + return None + row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2]) + return row + else: + # TODO(SPARK-49233): Classify user facing errors. 
+ raise PySparkRuntimeError(f"Error getting value: {response_message[1]}") + + def contains_key(self, state_name: str, user_key: Tuple) -> bool: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + bytes = self._stateful_processor_api_client._serialize_to_bytes( + self.user_key_schema, user_key + ) + contains_key_call = stateMessage.ContainsKey(userKey=bytes) + map_state_call = stateMessage.MapStateCall( + stateName=state_name, containsKey=contains_key_call + ) + state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + return True + elif status == 2: + # Expect status code is 2 when the given key doesn't exist in the map state. + return False + else: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError( + f"Error checking if map state contains key: {response_message[1]}" + ) + + def update_value( + self, + state_name: str, + user_key: Tuple, + value: Tuple, + ) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + key_bytes = self._stateful_processor_api_client._serialize_to_bytes( + self.user_key_schema, user_key + ) + value_bytes = self._stateful_processor_api_client._serialize_to_bytes( + self.value_schema, value + ) + update_value_call = stateMessage.UpdateValue(userKey=key_bytes, value=value_bytes) + map_state_call = stateMessage.MapStateCall( + stateName=state_name, updateValue=update_value_call + ) + state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error updating map state value: {response_message[1]}") + + def get_key_value_pair(self, state_name: str, iterator_id: str) -> Tuple[Tuple, Tuple]: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if iterator_id in self.user_key_value_pair_iterator_cursors: + # If the state is already in the dictionary, return the next row. + pandas_df, index = self.user_key_value_pair_iterator_cursors[iterator_id] + else: + # If the state is not in the dictionary, fetch the state from the server. + iterator_call = stateMessage.Iterator(iteratorId=iterator_id) + map_state_call = stateMessage.MapStateCall(stateName=state_name, iterator=iterator_call) + state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + iterator = self._stateful_processor_api_client._read_arrow_state() + # We need to exhaust the iterator here to make sure all the arrow batches are read, + # even though there is only one batch in the iterator. 
Otherwise, the stream might + # block further reads since it thinks there might still be some arrow batches left. + # We only need to read the first batch in the iterator because it's guaranteed that + # there would only be one batch sent from the JVM side. + data_batch = None + for batch in iterator: + if data_batch is None: + data_batch = batch + if data_batch is None: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError("Error getting map state entry.") + pandas_df = data_batch.to_pandas() + index = 0 + else: + raise StopIteration() + + new_index = index + 1 + if new_index < len(pandas_df): + # Update the index in the dictionary. + self.user_key_value_pair_iterator_cursors[iterator_id] = (pandas_df, new_index) + else: + # If the index is at the end of the DataFrame, remove the state from the dictionary. + self.user_key_value_pair_iterator_cursors.pop(iterator_id, None) + key_row_bytes = pandas_df.iloc[index, 0] + value_row_bytes = pandas_df.iloc[index, 1] + key_row = self._stateful_processor_api_client._deserialize_from_bytes(key_row_bytes) + value_row = self._stateful_processor_api_client._deserialize_from_bytes(value_row_bytes) + return tuple(key_row), tuple(value_row) + + def get_row(self, state_name: str, iterator_id: str, is_key: bool) -> Tuple: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if iterator_id in self.user_key_or_value_iterator_cursors: + # If the state is already in the dictionary, return the next row. + pandas_df, index = self.user_key_or_value_iterator_cursors[iterator_id] + else: + # If the state is not in the dictionary, fetch the state from the server. + if is_key: + keys_call = stateMessage.Keys(iteratorId=iterator_id) + map_state_call = stateMessage.MapStateCall(stateName=state_name, keys=keys_call) + else: + values_call = stateMessage.Values(iteratorId=iterator_id) + map_state_call = stateMessage.MapStateCall(stateName=state_name, values=values_call) + state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + iterator = self._stateful_processor_api_client._read_arrow_state() + # We need to exhaust the iterator here to make sure all the arrow batches are read, + # even though there is only one batch in the iterator. Otherwise, the stream might + # block further reads since it thinks there might still be some arrow batches left. + # We only need to read the first batch in the iterator because it's guaranteed that + # there would only be one batch sent from the JVM side. + data_batch = None + for batch in iterator: + if data_batch is None: + data_batch = batch + if data_batch is None: + entry_name = "key" + if not is_key: + entry_name = "value" + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error getting map state {entry_name}.") + pandas_df = data_batch.to_pandas() + index = 0 + else: + raise StopIteration() + + new_index = index + 1 + if new_index < len(pandas_df): + # Update the index in the dictionary. + self.user_key_or_value_iterator_cursors[iterator_id] = (pandas_df, new_index) + else: + # If the index is at the end of the DataFrame, remove the state from the dictionary. 
+ self.user_key_or_value_iterator_cursors.pop(iterator_id, None) + pandas_row = pandas_df.iloc[index] + return tuple(pandas_row) + + def remove_key(self, state_name: str, key: Tuple) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key) + remove_key_call = stateMessage.RemoveKey(userKey=bytes) + map_state_call = stateMessage.MapStateCall(stateName=state_name, removeKey=remove_key_call) + state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error removing key from map state: {response_message[1]}") + + def clear(self, state_name: str) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + clear_call = stateMessage.Clear() + map_state_call = stateMessage.MapStateCall(stateName=state_name, clear=clear_call) + state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error clearing map state: " f"{response_message[1]}") + + +class MapStateIterator: + def __init__(self, map_state_client: MapStateClient, state_name: str, is_key: bool): + self.map_state_client = map_state_client + self.state_name = state_name + # Generate a unique identifier for the iterator to make sure iterators from the same + # map state do not interfere with each other. + self.iterator_id = str(uuid.uuid4()) + self.is_key = is_key + + def __iter__(self) -> Iterator[Tuple]: + return self + + def __next__(self) -> Tuple: + return self.map_state_client.get_row(self.state_name, self.iterator_id, self.is_key) + + +class MapStateKeyValuePairIterator: + def __init__(self, map_state_client: MapStateClient, state_name: str): + self.map_state_client = map_state_client + self.state_name = state_name + # Generate a unique identifier for the iterator to make sure iterators from the same + # map state do not interfere with each other. 
+ self.iterator_id = str(uuid.uuid4()) + + def __iter__(self) -> Iterator[Tuple[Tuple, Tuple]]: + return self + + def __next__(self) -> Tuple[Tuple, Tuple]: + return self.map_state_client.get_key_value_pair(self.state_name, self.iterator_id) diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 6b8de0f8ac4ec..bac762e8addd8 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -20,6 +20,11 @@ from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient from pyspark.sql.streaming.list_state_client import ListStateClient, ListStateIterator +from pyspark.sql.streaming.map_state_client import ( + MapStateClient, + MapStateIterator, + MapStateKeyValuePairIterator, +) from pyspark.sql.streaming.value_state_client import ValueStateClient from pyspark.sql.types import StructType @@ -121,6 +126,77 @@ def clear(self) -> None: self._list_state_client.clear(self._state_name) +class MapState: + """ + Class used for arbitrary stateful operations with transformWithState to capture single map + state. + + .. versionadded:: 4.0.0 + """ + + def __init__( + self, + map_state_client: MapStateClient, + state_name: str, + ) -> None: + self._map_state_client = map_state_client + self._state_name = state_name + + def exists(self) -> bool: + """ + Whether state exists or not. + """ + return self._map_state_client.exists(self._state_name) + + def get_value(self, key: Tuple) -> Optional[Tuple]: + """ + Get the state value for given user key if it exists. + """ + return self._map_state_client.get_value(self._state_name, key) + + def contains_key(self, key: Tuple) -> bool: + """ + Check if the user key is contained in the map. + """ + return self._map_state_client.contains_key(self._state_name, key) + + def update_value(self, key: Tuple, value: Tuple) -> None: + """ + Update value for given user key. + """ + return self._map_state_client.update_value(self._state_name, key, value) + + def iterator(self) -> Iterator[Tuple[Tuple, Tuple]]: + """ + Get the map associated with grouping key. + """ + return MapStateKeyValuePairIterator(self._map_state_client, self._state_name) + + def keys(self) -> Iterator[Tuple]: + """ + Get the list of keys present in map associated with grouping key. + """ + return MapStateIterator(self._map_state_client, self._state_name, True) + + def values(self) -> Iterator[Tuple]: + """ + Get the list of values present in map associated with grouping key. + """ + return MapStateIterator(self._map_state_client, self._state_name, False) + + def remove_key(self, key: Tuple) -> None: + """ + Remove user key from map state. + """ + return self._map_state_client.remove_key(self._state_name, key) + + def clear(self) -> None: + """ + Remove this state. + """ + self._map_state_client.clear(self._state_name) + + class StatefulProcessorHandle: """ Represents the operation handle provided to the stateful processor used in transformWithState @@ -180,6 +256,42 @@ def getListState( self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms) return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema) + def getMapState( + self, + state_name: str, + user_key_schema: Union[StructType, str], + value_schema: Union[StructType, str], + ttl_duration_ms: Optional[int] = None, + ) -> MapState: + """ + Function to create new or return existing single map state variable of given type. 
+ The user must ensure to call this function only within the `init()` method of the + :class:`StatefulProcessor`. + + Parameters + ---------- + state_name : str + name of the state variable + user_key_schema : :class:`pyspark.sql.types.DataType` or str + The schema of the key of map state. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + value_schema : :class:`pyspark.sql.types.DataType` or str + The schema of the value of map state. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + ttl_duration_ms: int + Time to live duration of the state in milliseconds. State values will not be returned + past ttlDuration and will be eventually removed from the state store. Any state update + resets the expiration time to current processing time plus ttlDuration. + If ttl is not specified, the state will never expire. + """ + self.stateful_processor_api_client.get_map_state( + state_name, user_key_schema, value_schema, ttl_duration_ms + ) + return MapState( + MapStateClient(self.stateful_processor_api_client, user_key_schema, value_schema), + state_name, + ) + class StatefulProcessor(ABC): """ diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 449d5a2ad55dc..552ab44d1ddf4 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -154,6 +154,36 @@ def get_list_state( # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}") + def get_map_state( + self, + state_name: str, + user_key_schema: Union[StructType, str], + value_schema: Union[StructType, str], + ttl_duration_ms: Optional[int], + ) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if isinstance(user_key_schema, str): + user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema)) + if isinstance(value_schema, str): + value_schema = cast(StructType, _parse_datatype_string(value_schema)) + + state_call_command = stateMessage.StateCallCommand() + state_call_command.stateName = state_name + state_call_command.schema = user_key_schema.json() + state_call_command.mapStateValueSchema = value_schema.json() + if ttl_duration_ms is not None: + state_call_command.ttl.durationMs = ttl_duration_ms + call = stateMessage.StatefulProcessorCall(getMapState=state_call_command) + message = stateMessage.StateRequest(statefulProcessorCall=call) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error initializing map state: " f"{response_message[1]}") + def _send_proto_message(self, message: bytes) -> None: + # Writing zero here to indicate message version. This allows us to evolve the message + # format or even changing the message protocol in the future. 
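For context, the sketch below shows how the new map state variable is intended to be used from a `StatefulProcessor`. It is a minimal illustration only, not part of the patch: the processor name, state name, schemas, and the word-counting logic are made up, and it assumes the `getMapState`/`MapState` API added above. The real end-to-end behavior is exercised by `MapStateProcessor` in the test file that follows.

```python
import pandas as pd

from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


# Hypothetical processor keeping a per-grouping-key map of word -> count.
class WordCountProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        key_schema = StructType([StructField("word", StringType(), True)])
        value_schema = StructType([StructField("count", IntegerType(), True)])
        # An optional TTL in milliseconds can be passed as the fourth argument,
        # mirroring getValueState/getListState.
        self.word_counts = handle.getMapState("wordCounts", key_schema, value_schema)

    def handleInputRows(self, key, rows):
        # `rows` is an iterator of pandas DataFrames for the current grouping key;
        # a "word" column is assumed here purely for illustration.
        for pdf in rows:
            for word in pdf["word"]:
                user_key = (word,)
                current = 0
                if self.word_counts.contains_key(user_key):
                    current = self.word_counts.get_value(user_key)[0]
                self.word_counts.update_value(user_key, (current + 1,))
        # iterator() yields (key_tuple, value_tuple) pairs for the whole map.
        distinct = sum(1 for _ in self.word_counts.iterator())
        yield pd.DataFrame({"id": [key[0]], "distinctWords": [str(distinct)]})

    def close(self) -> None:
        pass
```

A TTL-backed variant would only differ in passing a duration, e.g. `handle.getMapState("wordCounts", key_schema, value_schema, 30000)`, as the `MapStateLargeTTLProcessor` test processor below does.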
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 01cd441941d93..7339897cb2cc2 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -233,6 +233,27 @@ def check_results(batch_df, _): ListStateLargeTTLProcessor(), check_results, True, "processingTime" ) + def test_transform_with_state_in_pandas_map_state(self): + def check_results(batch_df, _): + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="1", countAsString="2"), + } + + self._test_transform_with_state_in_pandas_basic(MapStateProcessor(), check_results, True) + + # test map state with ttl has the same behavior as map state when state doesn't expire. + def test_transform_with_state_in_pandas_map_state_large_ttl(self): + def check_results(batch_df, _): + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="1", countAsString="2"), + } + + self._test_transform_with_state_in_pandas_basic( + MapStateLargeTTLProcessor(), check_results, True, "processingTime" + ) + # test value state with ttl has the same behavior as value state when # state doesn't expire. def test_value_state_ttl_basic(self): @@ -261,9 +282,11 @@ def check_results(batch_df, batch_id): Row(id="ttl-count-0", count=1), Row(id="count-0", count=1), Row(id="ttl-list-state-count-0", count=1), + Row(id="ttl-map-state-count-0", count=1), Row(id="ttl-count-1", count=1), Row(id="count-1", count=1), Row(id="ttl-list-state-count-1", count=1), + Row(id="ttl-map-state-count-1", count=1), ], ) elif batch_id == 1: @@ -273,9 +296,11 @@ def check_results(batch_df, batch_id): Row(id="ttl-count-0", count=2), Row(id="count-0", count=2), Row(id="ttl-list-state-count-0", count=3), + Row(id="ttl-map-state-count-0", count=2), Row(id="ttl-count-1", count=2), Row(id="count-1", count=2), Row(id="ttl-list-state-count-1", count=3), + Row(id="ttl-map-state-count-1", count=2), ], ) elif batch_id == 2: @@ -292,9 +317,11 @@ def check_results(batch_df, batch_id): Row(id="ttl-count-0", count=1), Row(id="count-0", count=3), Row(id="ttl-list-state-count-0", count=1), + Row(id="ttl-map-state-count-0", count=1), Row(id="ttl-count-1", count=3), Row(id="count-1", count=3), Row(id="ttl-list-state-count-1", count=7), + Row(id="ttl-map-state-count-1", count=3), ], ) if batch_id == 0 or batch_id == 1: @@ -382,14 +409,19 @@ def init(self, handle: StatefulProcessorHandle) -> None: class TTLStatefulProcessor(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) + user_key_schema = StructType([StructField("id", StringType(), True)]) self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000) self.count_state = handle.getValueState("state", state_schema) self.ttl_list_state = handle.getListState("ttl-list-state", state_schema, 10000) + self.ttl_map_state = handle.getMapState( + "ttl-map-state", user_key_schema, state_schema, 10000 + ) def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: count = 0 ttl_count = 0 ttl_list_state_count = 0 + ttl_map_state_count = 0 id = key[0] if self.count_state.exists(): count = self.count_state.get()[0] @@ -399,21 +431,30 @@ def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: iter = self.ttl_list_state.get() for s in iter: ttl_list_state_count += s[0] + if 
self.ttl_map_state.exists(): + ttl_map_state_count = self.ttl_map_state.get_value(key)[0] for pdf in rows: pdf_count = pdf.count().get("temperature") count += pdf_count ttl_count += pdf_count ttl_list_state_count += pdf_count + ttl_map_state_count += pdf_count self.count_state.update((count,)) # skip updating state for the 2nd batch so that ttl state expire if not (ttl_count == 2 and id == "0"): self.ttl_count_state.update((ttl_count,)) self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)]) + self.ttl_map_state.update_value(key, (ttl_map_state_count,)) yield pd.DataFrame( { - "id": [f"ttl-count-{id}", f"count-{id}", f"ttl-list-state-count-{id}"], - "count": [ttl_count, count, ttl_list_state_count], + "id": [ + f"ttl-count-{id}", + f"count-{id}", + f"ttl-list-state-count-{id}", + f"ttl-map-state-count-{id}", + ], + "count": [ttl_count, count, ttl_list_state_count, ttl_map_state_count], } ) @@ -492,8 +533,6 @@ def close(self) -> None: pass -# A stateful processor that inherit all behavior of ListStateProcessor except that it use -# ttl state with a large timeout. class ListStateLargeTTLProcessor(ListStateProcessor): def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("temperature", IntegerType(), True)]) @@ -501,6 +540,54 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.list_state2 = handle.getListState("listState2", state_schema, 30000) +class MapStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle): + key_schema = StructType([StructField("name", StringType(), True)]) + value_schema = StructType([StructField("count", IntegerType(), True)]) + self.map_state = handle.getMapState("mapState", key_schema, value_schema) + + def handleInputRows(self, key, rows): + count = 0 + key1 = ("key1",) + key2 = ("key2",) + for pdf in rows: + pdf_count = pdf.count() + count += pdf_count.get("temperature") + value1 = count + value2 = count + if self.map_state.exists(): + if self.map_state.contains_key(key1): + value1 += self.map_state.get_value(key1)[0] + if self.map_state.contains_key(key2): + value2 += self.map_state.get_value(key2)[0] + self.map_state.update_value(key1, (value1,)) + self.map_state.update_value(key2, (value2,)) + key_iter = self.map_state.keys() + assert next(key_iter)[0] == "key1" + assert next(key_iter)[0] == "key2" + value_iter = self.map_state.values() + assert next(value_iter)[0] == value1 + assert next(value_iter)[0] == value2 + map_iter = self.map_state.iterator() + assert next(map_iter)[0] == key1 + assert next(map_iter)[1] == (value2,) + self.map_state.remove_key(key1) + assert not self.map_state.contains_key(key1) + yield pd.DataFrame({"id": key, "countAsString": str(count)}) + + def close(self) -> None: + pass + + +# A stateful processor that inherits all behavior of MapStateProcessor except that it uses +# ttl state with a large timeout. 
+class MapStateLargeTTLProcessor(MapStateProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + key_schema = StructType([StructField("name", StringType(), True)]) + value_schema = StructType([StructField("count", IntegerType(), True)]) + self.map_state = handle.getMapState("mapState", key_schema, value_schema, 30000) + + class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): pass diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto index 63728216ded1e..bb1c4c4f8e6ca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto @@ -47,6 +47,7 @@ message StateVariableRequest { oneof method { ValueStateCall valueStateCall = 1; ListStateCall listStateCall = 2; + MapStateCall mapStateCall = 3; } } @@ -60,7 +61,8 @@ message ImplicitGroupingKeyRequest { message StateCallCommand { string stateName = 1; string schema = 2; - TTLConfig ttl = 3; + string mapStateValueSchema = 3; + TTLConfig ttl = 4; } message ValueStateCall { @@ -85,6 +87,21 @@ message ListStateCall { } } +message MapStateCall { + string stateName = 1; + oneof method { + Exists exists = 2; + GetValue getValue = 3; + ContainsKey containsKey = 4; + UpdateValue updateValue = 5; + Iterator iterator = 6; + Keys keys = 7; + Values values = 8; + RemoveKey removeKey = 9; + Clear clear = 10; + } +} + message SetImplicitKey { bytes key = 1; } @@ -119,6 +136,35 @@ message AppendValue { message AppendList { } +message GetValue { + bytes userKey = 1; +} + +message ContainsKey { + bytes userKey = 1; +} + +message UpdateValue { + bytes userKey = 1; + bytes value = 2; +} + +message Iterator { + string iteratorId = 1; +} + +message Keys { + string iteratorId = 1; +} + +message Values { + string iteratorId = 1; +} + +message RemoveKey { + bytes userKey = 1; +} + enum HandleState { CREATED = 0; INITIALIZED = 1; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java index d6d56dd732775..852f820173f45 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java @@ -3477,6 +3477,21 @@ public interface StateVariableRequestOrBuilder extends */ org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder getListStateCallOrBuilder(); + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + * @return Whether the mapStateCall field is set. + */ + boolean hasMapStateCall(); + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + * @return The mapStateCall. 
+ */ + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall getMapStateCall(); + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCallOrBuilder getMapStateCallOrBuilder(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest.MethodCase getMethodCase(); } /** @@ -3526,6 +3541,7 @@ public enum MethodCase com.google.protobuf.AbstractMessage.InternalOneOfEnum { VALUESTATECALL(1), LISTSTATECALL(2), + MAPSTATECALL(3), METHOD_NOT_SET(0); private final int value; private MethodCase(int value) { @@ -3545,6 +3561,7 @@ public static MethodCase forNumber(int value) { switch (value) { case 1: return VALUESTATECALL; case 2: return LISTSTATECALL; + case 3: return MAPSTATECALL; case 0: return METHOD_NOT_SET; default: return null; } @@ -3622,6 +3639,37 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); } + public static final int MAPSTATECALL_FIELD_NUMBER = 3; + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + * @return Whether the mapStateCall field is set. + */ + @java.lang.Override + public boolean hasMapStateCall() { + return methodCase_ == 3; + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + * @return The mapStateCall. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall getMapStateCall() { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCallOrBuilder getMapStateCallOrBuilder() { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance(); + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -3642,6 +3690,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (methodCase_ == 2) { output.writeMessage(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_); } + if (methodCase_ == 3) { + output.writeMessage(3, (org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) method_); + } getUnknownFields().writeTo(output); } @@ -3659,6 +3710,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_); } + if (methodCase_ == 3) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(3, (org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) method_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -3684,6 +3739,10 @@ public boolean equals(final java.lang.Object obj) { if (!getListStateCall() .equals(other.getListStateCall())) return false; break; + case 3: + if (!getMapStateCall() + .equals(other.getMapStateCall())) 
return false; + break; case 0: default: } @@ -3707,6 +3766,10 @@ public int hashCode() { hash = (37 * hash) + LISTSTATECALL_FIELD_NUMBER; hash = (53 * hash) + getListStateCall().hashCode(); break; + case 3: + hash = (37 * hash) + MAPSTATECALL_FIELD_NUMBER; + hash = (53 * hash) + getMapStateCall().hashCode(); + break; case 0: default: } @@ -3844,6 +3907,9 @@ public Builder clear() { if (listStateCallBuilder_ != null) { listStateCallBuilder_.clear(); } + if (mapStateCallBuilder_ != null) { + mapStateCallBuilder_.clear(); + } methodCase_ = 0; method_ = null; return this; @@ -3886,6 +3952,13 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariable result.method_ = listStateCallBuilder_.build(); } } + if (methodCase_ == 3) { + if (mapStateCallBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = mapStateCallBuilder_.build(); + } + } result.methodCase_ = methodCase_; onBuilt(); return result; @@ -3944,6 +4017,10 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes mergeListStateCall(other.getListStateCall()); break; } + case MAPSTATECALL: { + mergeMapStateCall(other.getMapStateCall()); + break; + } case METHOD_NOT_SET: { break; } @@ -3988,6 +4065,13 @@ public Builder mergeFrom( methodCase_ = 2; break; } // case 18 + case 26: { + input.readMessage( + getMapStateCallFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 3; + break; + } // case 26 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -4302,6 +4386,148 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall onChanged();; return listStateCallBuilder_; } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCallOrBuilder> mapStateCallBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + * @return Whether the mapStateCall field is set. + */ + @java.lang.Override + public boolean hasMapStateCall() { + return methodCase_ == 3; + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + * @return The mapStateCall. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall getMapStateCall() { + if (mapStateCallBuilder_ == null) { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance(); + } else { + if (methodCase_ == 3) { + return mapStateCallBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + public Builder setMapStateCall(org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall value) { + if (mapStateCallBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + mapStateCallBuilder_.setMessage(value); + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + public Builder setMapStateCall( + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.Builder builderForValue) { + if (mapStateCallBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + mapStateCallBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + public Builder mergeMapStateCall(org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall value) { + if (mapStateCallBuilder_ == null) { + if (methodCase_ == 3 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 3) { + mapStateCallBuilder_.mergeFrom(value); + } else { + mapStateCallBuilder_.setMessage(value); + } + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + public Builder clearMapStateCall() { + if (mapStateCallBuilder_ == null) { + if (methodCase_ == 3) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 3) { + methodCase_ = 0; + method_ = null; + } + mapStateCallBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.Builder getMapStateCallBuilder() { + return getMapStateCallFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCallOrBuilder getMapStateCallOrBuilder() { + if ((methodCase_ == 3) && (mapStateCallBuilder_ != null)) { + return mapStateCallBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance(); + } + } + /** + * 
.org.apache.spark.sql.execution.streaming.state.MapStateCall mapStateCall = 3; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCallOrBuilder> + getMapStateCallFieldBuilder() { + if (mapStateCallBuilder_ == null) { + if (!(methodCase_ == 3)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance(); + } + mapStateCallBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCallOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 3; + onChanged();; + return mapStateCallBuilder_; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -5318,17 +5544,29 @@ public interface StateCallCommandOrBuilder extends getSchemaBytes(); /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * string mapStateValueSchema = 3; + * @return The mapStateValueSchema. + */ + java.lang.String getMapStateValueSchema(); + /** + * string mapStateValueSchema = 3; + * @return The bytes for mapStateValueSchema. + */ + com.google.protobuf.ByteString + getMapStateValueSchemaBytes(); + + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; * @return Whether the ttl field is set. */ boolean hasTtl(); /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; * @return The ttl. */ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getTtl(); /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder(); } @@ -5347,6 +5585,7 @@ private StateCallCommand(com.google.protobuf.GeneratedMessageV3.Builder build private StateCallCommand() { stateName_ = ""; schema_ = ""; + mapStateValueSchema_ = ""; } @java.lang.Override @@ -5450,10 +5689,48 @@ public java.lang.String getSchema() { } } - public static final int TTL_FIELD_NUMBER = 3; + public static final int MAPSTATEVALUESCHEMA_FIELD_NUMBER = 3; + private volatile java.lang.Object mapStateValueSchema_; + /** + * string mapStateValueSchema = 3; + * @return The mapStateValueSchema. + */ + @java.lang.Override + public java.lang.String getMapStateValueSchema() { + java.lang.Object ref = mapStateValueSchema_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + mapStateValueSchema_ = s; + return s; + } + } + /** + * string mapStateValueSchema = 3; + * @return The bytes for mapStateValueSchema. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString + getMapStateValueSchemaBytes() { + java.lang.Object ref = mapStateValueSchema_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + mapStateValueSchema_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int TTL_FIELD_NUMBER = 4; private org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig ttl_; /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; * @return Whether the ttl field is set. */ @java.lang.Override @@ -5461,7 +5738,7 @@ public boolean hasTtl() { return ttl_ != null; } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; * @return The ttl. */ @java.lang.Override @@ -5469,7 +5746,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get return ttl_ == null ? org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance() : ttl_; } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder() { @@ -5496,8 +5773,11 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(schema_)) { com.google.protobuf.GeneratedMessageV3.writeString(output, 2, schema_); } + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(mapStateValueSchema_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 3, mapStateValueSchema_); + } if (ttl_ != null) { - output.writeMessage(3, getTtl()); + output.writeMessage(4, getTtl()); } getUnknownFields().writeTo(output); } @@ -5514,9 +5794,12 @@ public int getSerializedSize() { if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(schema_)) { size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, schema_); } + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(mapStateValueSchema_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, mapStateValueSchema_); + } if (ttl_ != null) { size += com.google.protobuf.CodedOutputStream - .computeMessageSize(3, getTtl()); + .computeMessageSize(4, getTtl()); } size += getUnknownFields().getSerializedSize(); memoizedSize = size; @@ -5537,6 +5820,8 @@ public boolean equals(final java.lang.Object obj) { .equals(other.getStateName())) return false; if (!getSchema() .equals(other.getSchema())) return false; + if (!getMapStateValueSchema() + .equals(other.getMapStateValueSchema())) return false; if (hasTtl() != other.hasTtl()) return false; if (hasTtl()) { if (!getTtl() @@ -5557,6 +5842,8 @@ public int hashCode() { hash = (53 * hash) + getStateName().hashCode(); hash = (37 * hash) + SCHEMA_FIELD_NUMBER; hash = (53 * hash) + getSchema().hashCode(); + hash = (37 * hash) + MAPSTATEVALUESCHEMA_FIELD_NUMBER; + hash = (53 * hash) + getMapStateValueSchema().hashCode(); if (hasTtl()) { hash = (37 * hash) + TTL_FIELD_NUMBER; hash = (53 * hash) + getTtl().hashCode(); @@ -5693,6 +5980,8 @@ public Builder clear() { schema_ = ""; + mapStateValueSchema_ = ""; + if (ttlBuilder_ == null) { ttl_ = null; } else { @@ -5727,6 +6016,7 @@ public 
org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand(this); result.stateName_ = stateName_; result.schema_ = schema_; + result.mapStateValueSchema_ = mapStateValueSchema_; if (ttlBuilder_ == null) { result.ttl_ = ttl_; } else { @@ -5788,6 +6078,10 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes schema_ = other.schema_; onChanged(); } + if (!other.getMapStateValueSchema().isEmpty()) { + mapStateValueSchema_ = other.mapStateValueSchema_; + onChanged(); + } if (other.hasTtl()) { mergeTtl(other.getTtl()); } @@ -5828,12 +6122,17 @@ public Builder mergeFrom( break; } // case 18 case 26: { + mapStateValueSchema_ = input.readStringRequireUtf8(); + + break; + } // case 26 + case 34: { input.readMessage( getTtlFieldBuilder().getBuilder(), extensionRegistry); break; - } // case 26 + } // case 34 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -6002,18 +6301,94 @@ public Builder setSchemaBytes( return this; } + private java.lang.Object mapStateValueSchema_ = ""; + /** + * string mapStateValueSchema = 3; + * @return The mapStateValueSchema. + */ + public java.lang.String getMapStateValueSchema() { + java.lang.Object ref = mapStateValueSchema_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + mapStateValueSchema_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string mapStateValueSchema = 3; + * @return The bytes for mapStateValueSchema. + */ + public com.google.protobuf.ByteString + getMapStateValueSchemaBytes() { + java.lang.Object ref = mapStateValueSchema_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + mapStateValueSchema_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string mapStateValueSchema = 3; + * @param value The mapStateValueSchema to set. + * @return This builder for chaining. + */ + public Builder setMapStateValueSchema( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + mapStateValueSchema_ = value; + onChanged(); + return this; + } + /** + * string mapStateValueSchema = 3; + * @return This builder for chaining. + */ + public Builder clearMapStateValueSchema() { + + mapStateValueSchema_ = getDefaultInstance().getMapStateValueSchema(); + onChanged(); + return this; + } + /** + * string mapStateValueSchema = 3; + * @param value The bytes for mapStateValueSchema to set. + * @return This builder for chaining. 
+ */ + public Builder setMapStateValueSchemaBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + mapStateValueSchema_ = value; + onChanged(); + return this; + } + private org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig ttl_; private com.google.protobuf.SingleFieldBuilderV3< org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder> ttlBuilder_; /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; * @return Whether the ttl field is set. */ public boolean hasTtl() { return ttlBuilder_ != null || ttl_ != null; } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; * @return The ttl. */ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getTtl() { @@ -6024,7 +6399,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get } } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ public Builder setTtl(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig value) { if (ttlBuilder_ == null) { @@ -6040,7 +6415,7 @@ public Builder setTtl(org.apache.spark.sql.execution.streaming.state.StateMessag return this; } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ public Builder setTtl( org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder builderForValue) { @@ -6054,7 +6429,7 @@ public Builder setTtl( return this; } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ public Builder mergeTtl(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig value) { if (ttlBuilder_ == null) { @@ -6072,7 +6447,7 @@ public Builder mergeTtl(org.apache.spark.sql.execution.streaming.state.StateMess return this; } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ public Builder clearTtl() { if (ttlBuilder_ == null) { @@ -6086,7 +6461,7 @@ public Builder clearTtl() { return this; } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder getTtlBuilder() { @@ -6094,7 +6469,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Bui return getTtlFieldBuilder().getBuilder(); } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder() { if (ttlBuilder_ != null) { @@ -6105,7 +6480,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBu } } /** - * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * 
.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 4; */ private com.google.protobuf.SingleFieldBuilderV3< org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder> @@ -9684,37 +10059,5781 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall } - public interface SetImplicitKeyOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + public interface MapStateCallOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.MapStateCall) com.google.protobuf.MessageOrBuilder { /** - * bytes key = 1; - * @return The key. + * string stateName = 1; + * @return The stateName. */ - com.google.protobuf.ByteString getKey(); - } - /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} - */ - public static final class SetImplicitKey extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - SetImplicitKeyOrBuilder { - private static final long serialVersionUID = 0L; - // Use SetImplicitKey.newBuilder() to construct. - private SetImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); + java.lang.String getStateName(); + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + com.google.protobuf.ByteString + getStateNameBytes(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + boolean hasExists(); + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists(); + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + * @return Whether the getValue field is set. + */ + boolean hasGetValue(); + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + * @return The getValue. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue getGetValue(); + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValueOrBuilder getGetValueOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + * @return Whether the containsKey field is set. + */ + boolean hasContainsKey(); + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + * @return The containsKey. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey getContainsKey(); + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKeyOrBuilder getContainsKeyOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + * @return Whether the updateValue field is set. 
+ */ + boolean hasUpdateValue(); + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + * @return The updateValue. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue getUpdateValue(); + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValueOrBuilder getUpdateValueOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + * @return Whether the iterator field is set. + */ + boolean hasIterator(); + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + * @return The iterator. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator getIterator(); + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.IteratorOrBuilder getIteratorOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + * @return Whether the keys field is set. + */ + boolean hasKeys(); + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + * @return The keys. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys getKeys(); + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.KeysOrBuilder getKeysOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + * @return Whether the values field is set. + */ + boolean hasValues(); + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + * @return The values. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Values getValues(); + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ValuesOrBuilder getValuesOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + * @return Whether the removeKey field is set. + */ + boolean hasRemoveKey(); + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + * @return The removeKey. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey getRemoveKey(); + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKeyOrBuilder getRemoveKeyOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + * @return Whether the clear field is set. + */ + boolean hasClear(); + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + * @return The clear. 
+ */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear(); + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder(); + + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.MethodCase getMethodCase(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.MapStateCall} + */ + public static final class MapStateCall extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.MapStateCall) + MapStateCallOrBuilder { + private static final long serialVersionUID = 0L; + // Use MapStateCall.newBuilder() to construct. + private MapStateCall(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private MapStateCall() { + stateName_ = ""; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new MapStateCall(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.class, org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.Builder.class); + } + + private int methodCase_ = 0; + private java.lang.Object method_; + public enum MethodCase + implements com.google.protobuf.Internal.EnumLite, + com.google.protobuf.AbstractMessage.InternalOneOfEnum { + EXISTS(2), + GETVALUE(3), + CONTAINSKEY(4), + UPDATEVALUE(5), + ITERATOR(6), + KEYS(7), + VALUES(8), + REMOVEKEY(9), + CLEAR(10), + METHOD_NOT_SET(0); + private final int value; + private MethodCase(int value) { + this.value = value; + } + /** + * @param value The number of the enum to look for. + * @return The enum associated with the given number. + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static MethodCase valueOf(int value) { + return forNumber(value); + } + + public static MethodCase forNumber(int value) { + switch (value) { + case 2: return EXISTS; + case 3: return GETVALUE; + case 4: return CONTAINSKEY; + case 5: return UPDATEVALUE; + case 6: return ITERATOR; + case 7: return KEYS; + case 8: return VALUES; + case 9: return REMOVEKEY; + case 10: return CLEAR; + case 0: return METHOD_NOT_SET; + default: return null; + } + } + public int getNumber() { + return this.value; + } + }; + + public MethodCase + getMethodCase() { + return MethodCase.forNumber( + methodCase_); + } + + public static final int STATENAME_FIELD_NUMBER = 1; + private volatile java.lang.Object stateName_; + /** + * string stateName = 1; + * @return The stateName. 
+ */ + @java.lang.Override + public java.lang.String getStateName() { + java.lang.Object ref = stateName_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + stateName_ = s; + return s; + } + } + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + @java.lang.Override + public com.google.protobuf.ByteString + getStateNameBytes() { + java.lang.Object ref = stateName_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + stateName_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int EXISTS_FIELD_NUMBER = 2; + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + @java.lang.Override + public boolean hasExists() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + + public static final int GETVALUE_FIELD_NUMBER = 3; + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + * @return Whether the getValue field is set. + */ + @java.lang.Override + public boolean hasGetValue() { + return methodCase_ == 3; + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + * @return The getValue. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue getGetValue() { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.GetValueOrBuilder getGetValueOrBuilder() { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance(); + } + + public static final int CONTAINSKEY_FIELD_NUMBER = 4; + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + * @return Whether the containsKey field is set. + */ + @java.lang.Override + public boolean hasContainsKey() { + return methodCase_ == 4; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + * @return The containsKey. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey getContainsKey() { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKeyOrBuilder getContainsKeyOrBuilder() { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance(); + } + + public static final int UPDATEVALUE_FIELD_NUMBER = 5; + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + * @return Whether the updateValue field is set. + */ + @java.lang.Override + public boolean hasUpdateValue() { + return methodCase_ == 5; + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + * @return The updateValue. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue getUpdateValue() { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValueOrBuilder getUpdateValueOrBuilder() { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance(); + } + + public static final int ITERATOR_FIELD_NUMBER = 6; + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + * @return Whether the iterator field is set. + */ + @java.lang.Override + public boolean hasIterator() { + return methodCase_ == 6; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + * @return The iterator. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator getIterator() { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.IteratorOrBuilder getIteratorOrBuilder() { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance(); + } + + public static final int KEYS_FIELD_NUMBER = 7; + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + * @return Whether the keys field is set. 
+ */ + @java.lang.Override + public boolean hasKeys() { + return methodCase_ == 7; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + * @return The keys. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Keys getKeys() { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.KeysOrBuilder getKeysOrBuilder() { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance(); + } + + public static final int VALUES_FIELD_NUMBER = 8; + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + * @return Whether the values field is set. + */ + @java.lang.Override + public boolean hasValues() { + return methodCase_ == 8; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + * @return The values. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Values getValues() { + if (methodCase_ == 8) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Values) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValuesOrBuilder getValuesOrBuilder() { + if (methodCase_ == 8) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Values) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance(); + } + + public static final int REMOVEKEY_FIELD_NUMBER = 9; + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + * @return Whether the removeKey field is set. + */ + @java.lang.Override + public boolean hasRemoveKey() { + return methodCase_ == 9; + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + * @return The removeKey. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey getRemoveKey() { + if (methodCase_ == 9) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKeyOrBuilder getRemoveKeyOrBuilder() { + if (methodCase_ == 9) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance(); + } + + public static final int CLEAR_FIELD_NUMBER = 10; + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + * @return Whether the clear field is set. 
+ */ + @java.lang.Override + public boolean hasClear() { + return methodCase_ == 10; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + * @return The clear. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear() { + if (methodCase_ == 10) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder() { + if (methodCase_ == 10) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(stateName_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, stateName_); + } + if (methodCase_ == 2) { + output.writeMessage(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_); + } + if (methodCase_ == 3) { + output.writeMessage(3, (org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) method_); + } + if (methodCase_ == 4) { + output.writeMessage(4, (org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) method_); + } + if (methodCase_ == 5) { + output.writeMessage(5, (org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) method_); + } + if (methodCase_ == 6) { + output.writeMessage(6, (org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) method_); + } + if (methodCase_ == 7) { + output.writeMessage(7, (org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) method_); + } + if (methodCase_ == 8) { + output.writeMessage(8, (org.apache.spark.sql.execution.streaming.state.StateMessage.Values) method_); + } + if (methodCase_ == 9) { + output.writeMessage(9, (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) method_); + } + if (methodCase_ == 10) { + output.writeMessage(10, (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(stateName_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, stateName_); + } + if (methodCase_ == 2) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_); + } + if (methodCase_ == 3) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(3, (org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) method_); + } + if (methodCase_ == 4) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(4, 
(org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) method_); + } + if (methodCase_ == 5) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(5, (org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) method_); + } + if (methodCase_ == 6) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(6, (org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) method_); + } + if (methodCase_ == 7) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(7, (org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) method_); + } + if (methodCase_ == 8) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(8, (org.apache.spark.sql.execution.streaming.state.StateMessage.Values) method_); + } + if (methodCase_ == 9) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(9, (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) method_); + } + if (methodCase_ == 10) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(10, (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall other = (org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) obj; + + if (!getStateName() + .equals(other.getStateName())) return false; + if (!getMethodCase().equals(other.getMethodCase())) return false; + switch (methodCase_) { + case 2: + if (!getExists() + .equals(other.getExists())) return false; + break; + case 3: + if (!getGetValue() + .equals(other.getGetValue())) return false; + break; + case 4: + if (!getContainsKey() + .equals(other.getContainsKey())) return false; + break; + case 5: + if (!getUpdateValue() + .equals(other.getUpdateValue())) return false; + break; + case 6: + if (!getIterator() + .equals(other.getIterator())) return false; + break; + case 7: + if (!getKeys() + .equals(other.getKeys())) return false; + break; + case 8: + if (!getValues() + .equals(other.getValues())) return false; + break; + case 9: + if (!getRemoveKey() + .equals(other.getRemoveKey())) return false; + break; + case 10: + if (!getClear() + .equals(other.getClear())) return false; + break; + case 0: + default: + } + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + STATENAME_FIELD_NUMBER; + hash = (53 * hash) + getStateName().hashCode(); + switch (methodCase_) { + case 2: + hash = (37 * hash) + EXISTS_FIELD_NUMBER; + hash = (53 * hash) + getExists().hashCode(); + break; + case 3: + hash = (37 * hash) + GETVALUE_FIELD_NUMBER; + hash = (53 * hash) + getGetValue().hashCode(); + break; + case 4: + hash = (37 * hash) + CONTAINSKEY_FIELD_NUMBER; + hash = (53 * hash) + getContainsKey().hashCode(); + break; + case 5: + hash = (37 * hash) + UPDATEVALUE_FIELD_NUMBER; + hash = (53 * hash) + getUpdateValue().hashCode(); + break; + case 6: + hash = (37 * hash) + 
ITERATOR_FIELD_NUMBER; + hash = (53 * hash) + getIterator().hashCode(); + break; + case 7: + hash = (37 * hash) + KEYS_FIELD_NUMBER; + hash = (53 * hash) + getKeys().hashCode(); + break; + case 8: + hash = (37 * hash) + VALUES_FIELD_NUMBER; + hash = (53 * hash) + getValues().hashCode(); + break; + case 9: + hash = (37 * hash) + REMOVEKEY_FIELD_NUMBER; + hash = (53 * hash) + getRemoveKey().hashCode(); + break; + case 10: + hash = (37 * hash) + CLEAR_FIELD_NUMBER; + hash = (53 * hash) + getClear().hashCode(); + break; + case 0: + default: + } + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom( + 
com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.MapStateCall} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.MapStateCall) + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCallOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.class, org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + stateName_ = ""; + + if (existsBuilder_ != null) { + existsBuilder_.clear(); + } + if (getValueBuilder_ != null) { + getValueBuilder_.clear(); + } + if (containsKeyBuilder_ != null) { + containsKeyBuilder_.clear(); + } + if (updateValueBuilder_ != null) { + updateValueBuilder_.clear(); + } + if (iteratorBuilder_ != null) { + iteratorBuilder_.clear(); + } + if (keysBuilder_ != null) { + keysBuilder_.clear(); + } + if (valuesBuilder_ != null) { + valuesBuilder_.clear(); + } + if (removeKeyBuilder_ != null) { + removeKeyBuilder_.clear(); + } + if (clearBuilder_ != null) { + clearBuilder_.clear(); + } + methodCase_ = 0; + method_ = null; + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return 
org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall result = new org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall(this); + result.stateName_ = stateName_; + if (methodCase_ == 2) { + if (existsBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = existsBuilder_.build(); + } + } + if (methodCase_ == 3) { + if (getValueBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = getValueBuilder_.build(); + } + } + if (methodCase_ == 4) { + if (containsKeyBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = containsKeyBuilder_.build(); + } + } + if (methodCase_ == 5) { + if (updateValueBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = updateValueBuilder_.build(); + } + } + if (methodCase_ == 6) { + if (iteratorBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = iteratorBuilder_.build(); + } + } + if (methodCase_ == 7) { + if (keysBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = keysBuilder_.build(); + } + } + if (methodCase_ == 8) { + if (valuesBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = valuesBuilder_.build(); + } + } + if (methodCase_ == 9) { + if (removeKeyBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = removeKeyBuilder_.build(); + } + } + if (methodCase_ == 10) { + if (clearBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = clearBuilder_.build(); + } + } + result.methodCase_ = methodCase_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof 
org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall.getDefaultInstance()) return this; + if (!other.getStateName().isEmpty()) { + stateName_ = other.stateName_; + onChanged(); + } + switch (other.getMethodCase()) { + case EXISTS: { + mergeExists(other.getExists()); + break; + } + case GETVALUE: { + mergeGetValue(other.getGetValue()); + break; + } + case CONTAINSKEY: { + mergeContainsKey(other.getContainsKey()); + break; + } + case UPDATEVALUE: { + mergeUpdateValue(other.getUpdateValue()); + break; + } + case ITERATOR: { + mergeIterator(other.getIterator()); + break; + } + case KEYS: { + mergeKeys(other.getKeys()); + break; + } + case VALUES: { + mergeValues(other.getValues()); + break; + } + case REMOVEKEY: { + mergeRemoveKey(other.getRemoveKey()); + break; + } + case CLEAR: { + mergeClear(other.getClear()); + break; + } + case METHOD_NOT_SET: { + break; + } + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + stateName_ = input.readStringRequireUtf8(); + + break; + } // case 10 + case 18: { + input.readMessage( + getExistsFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 2; + break; + } // case 18 + case 26: { + input.readMessage( + getGetValueFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 3; + break; + } // case 26 + case 34: { + input.readMessage( + getContainsKeyFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 4; + break; + } // case 34 + case 42: { + input.readMessage( + getUpdateValueFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 5; + break; + } // case 42 + case 50: { + input.readMessage( + getIteratorFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 6; + break; + } // case 50 + case 58: { + input.readMessage( + getKeysFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 7; + break; + } // case 58 + case 66: { + input.readMessage( + getValuesFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 8; + break; + } // case 66 + case 74: { + input.readMessage( + getRemoveKeyFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 9; + break; + } // case 74 + case 82: { + input.readMessage( + getClearFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 10; + break; + } // case 82 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + private int 
methodCase_ = 0; + private java.lang.Object method_; + public MethodCase + getMethodCase() { + return MethodCase.forNumber( + methodCase_); + } + + public Builder clearMethod() { + methodCase_ = 0; + method_ = null; + onChanged(); + return this; + } + + + private java.lang.Object stateName_ = ""; + /** + * string stateName = 1; + * @return The stateName. + */ + public java.lang.String getStateName() { + java.lang.Object ref = stateName_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + stateName_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + public com.google.protobuf.ByteString + getStateNameBytes() { + java.lang.Object ref = stateName_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + stateName_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string stateName = 1; + * @param value The stateName to set. + * @return This builder for chaining. + */ + public Builder setStateName( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + stateName_ = value; + onChanged(); + return this; + } + /** + * string stateName = 1; + * @return This builder for chaining. + */ + public Builder clearStateName() { + + stateName_ = getDefaultInstance().getStateName(); + onChanged(); + return this; + } + /** + * string stateName = 1; + * @param value The bytes for stateName to set. + * @return This builder for chaining. + */ + public Builder setStateNameBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + stateName_ = value; + onChanged(); + return this; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder> existsBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + @java.lang.Override + public boolean hasExists() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists() { + if (existsBuilder_ == null) { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } else { + if (methodCase_ == 2) { + return existsBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder setExists(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists value) { + if (existsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + existsBuilder_.setMessage(value); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder setExists( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder builderForValue) { + if (existsBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + existsBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder mergeExists(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists value) { + if (existsBuilder_ == null) { + if (methodCase_ == 2 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 2) { + existsBuilder_.mergeFrom(value); + } else { + existsBuilder_.setMessage(value); + } + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder clearExists() { + if (existsBuilder_ == null) { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + } + existsBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder getExistsBuilder() { + return getExistsFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder() { + if ((methodCase_ == 2) && (existsBuilder_ != null)) { + return existsBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, 
org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder> + getExistsFieldBuilder() { + if (existsBuilder_ == null) { + if (!(methodCase_ == 2)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + existsBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 2; + onChanged();; + return existsBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue, org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.GetValueOrBuilder> getValueBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + * @return Whether the getValue field is set. + */ + @java.lang.Override + public boolean hasGetValue() { + return methodCase_ == 3; + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + * @return The getValue. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue getGetValue() { + if (getValueBuilder_ == null) { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance(); + } else { + if (methodCase_ == 3) { + return getValueBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + public Builder setGetValue(org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue value) { + if (getValueBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + getValueBuilder_.setMessage(value); + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + public Builder setGetValue( + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.Builder builderForValue) { + if (getValueBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + getValueBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + public Builder mergeGetValue(org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue value) { + if (getValueBuilder_ == null) { + if (methodCase_ == 3 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 3) { + getValueBuilder_.mergeFrom(value); + } else { + getValueBuilder_.setMessage(value); + } + } + methodCase_ = 3; + 
return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + public Builder clearGetValue() { + if (getValueBuilder_ == null) { + if (methodCase_ == 3) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 3) { + methodCase_ = 0; + method_ = null; + } + getValueBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.Builder getGetValueBuilder() { + return getGetValueFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.GetValueOrBuilder getGetValueOrBuilder() { + if ((methodCase_ == 3) && (getValueBuilder_ != null)) { + return getValueBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.GetValue getValue = 3; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue, org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.GetValueOrBuilder> + getGetValueFieldBuilder() { + if (getValueBuilder_ == null) { + if (!(methodCase_ == 3)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance(); + } + getValueBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue, org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.GetValueOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 3; + onChanged();; + return getValueBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey, org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKeyOrBuilder> containsKeyBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + * @return Whether the containsKey field is set. + */ + @java.lang.Override + public boolean hasContainsKey() { + return methodCase_ == 4; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + * @return The containsKey. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey getContainsKey() { + if (containsKeyBuilder_ == null) { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance(); + } else { + if (methodCase_ == 4) { + return containsKeyBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + public Builder setContainsKey(org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey value) { + if (containsKeyBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + containsKeyBuilder_.setMessage(value); + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + public Builder setContainsKey( + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.Builder builderForValue) { + if (containsKeyBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + containsKeyBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + public Builder mergeContainsKey(org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey value) { + if (containsKeyBuilder_ == null) { + if (methodCase_ == 4 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 4) { + containsKeyBuilder_.mergeFrom(value); + } else { + containsKeyBuilder_.setMessage(value); + } + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + public Builder clearContainsKey() { + if (containsKeyBuilder_ == null) { + if (methodCase_ == 4) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 4) { + methodCase_ = 0; + method_ = null; + } + containsKeyBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.Builder getContainsKeyBuilder() { + return getContainsKeyFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKeyOrBuilder getContainsKeyOrBuilder() { + if ((methodCase_ == 4) && (containsKeyBuilder_ != null)) { + return containsKeyBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance(); + } + } + /** + * 
.org.apache.spark.sql.execution.streaming.state.ContainsKey containsKey = 4; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey, org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKeyOrBuilder> + getContainsKeyFieldBuilder() { + if (containsKeyBuilder_ == null) { + if (!(methodCase_ == 4)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance(); + } + containsKeyBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey, org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKeyOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 4; + onChanged();; + return containsKeyBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue, org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValueOrBuilder> updateValueBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + * @return Whether the updateValue field is set. + */ + @java.lang.Override + public boolean hasUpdateValue() { + return methodCase_ == 5; + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + * @return The updateValue. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue getUpdateValue() { + if (updateValueBuilder_ == null) { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance(); + } else { + if (methodCase_ == 5) { + return updateValueBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + public Builder setUpdateValue(org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue value) { + if (updateValueBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + updateValueBuilder_.setMessage(value); + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + public Builder setUpdateValue( + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.Builder builderForValue) { + if (updateValueBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + updateValueBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + public Builder mergeUpdateValue(org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue value) { + if (updateValueBuilder_ == null) { + if (methodCase_ == 5 && + method_ != 
org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 5) { + updateValueBuilder_.mergeFrom(value); + } else { + updateValueBuilder_.setMessage(value); + } + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + public Builder clearUpdateValue() { + if (updateValueBuilder_ == null) { + if (methodCase_ == 5) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 5) { + methodCase_ = 0; + method_ = null; + } + updateValueBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.Builder getUpdateValueBuilder() { + return getUpdateValueFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValueOrBuilder getUpdateValueOrBuilder() { + if ((methodCase_ == 5) && (updateValueBuilder_ != null)) { + return updateValueBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.UpdateValue updateValue = 5; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue, org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValueOrBuilder> + getUpdateValueFieldBuilder() { + if (updateValueBuilder_ == null) { + if (!(methodCase_ == 5)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance(); + } + updateValueBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue, org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValueOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 5; + onChanged();; + return updateValueBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator, org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.IteratorOrBuilder> iteratorBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + * @return Whether the iterator field is set. + */ + @java.lang.Override + public boolean hasIterator() { + return methodCase_ == 6; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + * @return The iterator. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator getIterator() { + if (iteratorBuilder_ == null) { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance(); + } else { + if (methodCase_ == 6) { + return iteratorBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + public Builder setIterator(org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator value) { + if (iteratorBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + iteratorBuilder_.setMessage(value); + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + public Builder setIterator( + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.Builder builderForValue) { + if (iteratorBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + iteratorBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + public Builder mergeIterator(org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator value) { + if (iteratorBuilder_ == null) { + if (methodCase_ == 6 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 6) { + iteratorBuilder_.mergeFrom(value); + } else { + iteratorBuilder_.setMessage(value); + } + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + public Builder clearIterator() { + if (iteratorBuilder_ == null) { + if (methodCase_ == 6) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 6) { + methodCase_ = 0; + method_ = null; + } + iteratorBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.Builder getIteratorBuilder() { + return getIteratorFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.IteratorOrBuilder getIteratorOrBuilder() { + if ((methodCase_ == 6) && (iteratorBuilder_ != null)) { + return iteratorBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Iterator iterator = 6; + */ + private com.google.protobuf.SingleFieldBuilderV3< + 
org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator, org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.IteratorOrBuilder> + getIteratorFieldBuilder() { + if (iteratorBuilder_ == null) { + if (!(methodCase_ == 6)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance(); + } + iteratorBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator, org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.IteratorOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 6; + onChanged();; + return iteratorBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys, org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.KeysOrBuilder> keysBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + * @return Whether the keys field is set. + */ + @java.lang.Override + public boolean hasKeys() { + return methodCase_ == 7; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + * @return The keys. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Keys getKeys() { + if (keysBuilder_ == null) { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance(); + } else { + if (methodCase_ == 7) { + return keysBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + public Builder setKeys(org.apache.spark.sql.execution.streaming.state.StateMessage.Keys value) { + if (keysBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + keysBuilder_.setMessage(value); + } + methodCase_ = 7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + public Builder setKeys( + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.Builder builderForValue) { + if (keysBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + keysBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + public Builder mergeKeys(org.apache.spark.sql.execution.streaming.state.StateMessage.Keys value) { + if (keysBuilder_ == null) { + if (methodCase_ == 7 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 7) { + keysBuilder_.mergeFrom(value); + } else { + keysBuilder_.setMessage(value); + } + } + methodCase_ = 
7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + public Builder clearKeys() { + if (keysBuilder_ == null) { + if (methodCase_ == 7) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 7) { + methodCase_ = 0; + method_ = null; + } + keysBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.Builder getKeysBuilder() { + return getKeysFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.KeysOrBuilder getKeysOrBuilder() { + if ((methodCase_ == 7) && (keysBuilder_ != null)) { + return keysBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Keys keys = 7; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys, org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.KeysOrBuilder> + getKeysFieldBuilder() { + if (keysBuilder_ == null) { + if (!(methodCase_ == 7)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance(); + } + keysBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys, org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.KeysOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 7; + onChanged();; + return keysBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Values, org.apache.spark.sql.execution.streaming.state.StateMessage.Values.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ValuesOrBuilder> valuesBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + * @return Whether the values field is set. + */ + @java.lang.Override + public boolean hasValues() { + return methodCase_ == 8; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + * @return The values. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Values getValues() { + if (valuesBuilder_ == null) { + if (methodCase_ == 8) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Values) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance(); + } else { + if (methodCase_ == 8) { + return valuesBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + public Builder setValues(org.apache.spark.sql.execution.streaming.state.StateMessage.Values value) { + if (valuesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + valuesBuilder_.setMessage(value); + } + methodCase_ = 8; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + public Builder setValues( + org.apache.spark.sql.execution.streaming.state.StateMessage.Values.Builder builderForValue) { + if (valuesBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + valuesBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 8; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + public Builder mergeValues(org.apache.spark.sql.execution.streaming.state.StateMessage.Values value) { + if (valuesBuilder_ == null) { + if (methodCase_ == 8 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Values.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Values) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 8) { + valuesBuilder_.mergeFrom(value); + } else { + valuesBuilder_.setMessage(value); + } + } + methodCase_ = 8; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + public Builder clearValues() { + if (valuesBuilder_ == null) { + if (methodCase_ == 8) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 8) { + methodCase_ = 0; + method_ = null; + } + valuesBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Values.Builder getValuesBuilder() { + return getValuesFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValuesOrBuilder getValuesOrBuilder() { + if ((methodCase_ == 8) && (valuesBuilder_ != null)) { + return valuesBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 8) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Values) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Values values = 8; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Values, org.apache.spark.sql.execution.streaming.state.StateMessage.Values.Builder, 
org.apache.spark.sql.execution.streaming.state.StateMessage.ValuesOrBuilder> + getValuesFieldBuilder() { + if (valuesBuilder_ == null) { + if (!(methodCase_ == 8)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance(); + } + valuesBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Values, org.apache.spark.sql.execution.streaming.state.StateMessage.Values.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ValuesOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Values) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 8; + onChanged();; + return valuesBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKeyOrBuilder> removeKeyBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + * @return Whether the removeKey field is set. + */ + @java.lang.Override + public boolean hasRemoveKey() { + return methodCase_ == 9; + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + * @return The removeKey. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey getRemoveKey() { + if (removeKeyBuilder_ == null) { + if (methodCase_ == 9) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance(); + } else { + if (methodCase_ == 9) { + return removeKeyBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + public Builder setRemoveKey(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey value) { + if (removeKeyBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + removeKeyBuilder_.setMessage(value); + } + methodCase_ = 9; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + public Builder setRemoveKey( + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.Builder builderForValue) { + if (removeKeyBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + removeKeyBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 9; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + public Builder mergeRemoveKey(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey value) { + if (removeKeyBuilder_ == null) { + if (methodCase_ == 9 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 9) { + removeKeyBuilder_.mergeFrom(value); + } else { + 
removeKeyBuilder_.setMessage(value); + } + } + methodCase_ = 9; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + public Builder clearRemoveKey() { + if (removeKeyBuilder_ == null) { + if (methodCase_ == 9) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 9) { + methodCase_ = 0; + method_ = null; + } + removeKeyBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.Builder getRemoveKeyBuilder() { + return getRemoveKeyFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKeyOrBuilder getRemoveKeyOrBuilder() { + if ((methodCase_ == 9) && (removeKeyBuilder_ != null)) { + return removeKeyBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 9) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.RemoveKey removeKey = 9; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKeyOrBuilder> + getRemoveKeyFieldBuilder() { + if (removeKeyBuilder_ == null) { + if (!(methodCase_ == 9)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance(); + } + removeKeyBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKeyOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 9; + onChanged();; + return removeKeyBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder> clearBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + * @return Whether the clear field is set. + */ + @java.lang.Override + public boolean hasClear() { + return methodCase_ == 10; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + * @return The clear. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear() { + if (clearBuilder_ == null) { + if (methodCase_ == 10) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } else { + if (methodCase_ == 10) { + return clearBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + public Builder setClear(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear value) { + if (clearBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + clearBuilder_.setMessage(value); + } + methodCase_ = 10; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + public Builder setClear( + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder builderForValue) { + if (clearBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + clearBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 10; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + public Builder mergeClear(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear value) { + if (clearBuilder_ == null) { + if (methodCase_ == 10 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 10) { + clearBuilder_.mergeFrom(value); + } else { + clearBuilder_.setMessage(value); + } + } + methodCase_ = 10; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + public Builder clearClear() { + if (clearBuilder_ == null) { + if (methodCase_ == 10) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 10) { + methodCase_ = 0; + method_ = null; + } + clearBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder getClearBuilder() { + return getClearFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder() { + if ((methodCase_ == 10) && (clearBuilder_ != null)) { + return clearBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 10) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 10; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, 
org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder> + getClearFieldBuilder() { + if (clearBuilder_ == null) { + if (!(methodCase_ == 10)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + clearBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 10; + onChanged();; + return clearBuilder_; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.MapStateCall) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.MapStateCall) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public MapStateCall parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.MapStateCall getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface SetImplicitKeyOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + com.google.protobuf.MessageOrBuilder { + + /** + * bytes key = 1; + * @return The key. 
+ */ + com.google.protobuf.ByteString getKey(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + */ + public static final class SetImplicitKey extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + SetImplicitKeyOrBuilder { + private static final long serialVersionUID = 0L; + // Use SetImplicitKey.newBuilder() to construct. + private SetImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private SetImplicitKey() { + key_ = com.google.protobuf.ByteString.EMPTY; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new SetImplicitKey(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + } + + public static final int KEY_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString key_; + /** + * bytes key = 1; + * @return The key. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getKey() { + return key_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!key_.isEmpty()) { + output.writeBytes(1, key_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!key_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, key_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) obj; + + if (!getKey() + .equals(other.getKey())) return false; + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + KEY_FIELD_NUMBER; + hash = (53 * hash) + getKey().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKeyOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + key_ = com.google.protobuf.ByteString.EMPTY; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(this); + result.key_ = key_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override 
+ public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance()) return this; + if (other.getKey() != com.google.protobuf.ByteString.EMPTY) { + setKey(other.getKey()); + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + key_ = input.readBytes(); + + break; + } // case 10 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + + private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes key = 1; + * @return The key. + */ + @java.lang.Override + public com.google.protobuf.ByteString getKey() { + return key_; + } + /** + * bytes key = 1; + * @param value The key to set. + * @return This builder for chaining. + */ + public Builder setKey(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + key_ = value; + onChanged(); + return this; + } + /** + * bytes key = 1; + * @return This builder for chaining. 
+ */ + public Builder clearKey() { + + key_ = getDefaultInstance().getKey(); + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public SetImplicitKey parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface RemoveImplicitKeyOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + */ + public static final class RemoveImplicitKey extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + RemoveImplicitKeyOrBuilder { + private static final long serialVersionUID = 0L; + // Use RemoveImplicitKey.newBuilder() to construct. 
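+    // Illustrative note, not part of the protoc output: the classes in this hunk follow
+    // the standard protobuf-java builder pattern. As a minimal sketch of how a caller
+    // might exercise the SetImplicitKey message defined above (assuming the usual
+    // com.google.protobuf runtime on the classpath, and a hypothetical key value):
+    //
+    //   com.google.protobuf.ByteString key =
+    //       com.google.protobuf.ByteString.copyFromUtf8("grouping-key");
+    //   StateMessage.SetImplicitKey set =
+    //       StateMessage.SetImplicitKey.newBuilder().setKey(key).build();
+    //   byte[] wire = set.toByteArray();   // toByteArray() comes from the protobuf base classes
+    //   StateMessage.SetImplicitKey parsed = StateMessage.SetImplicitKey.parseFrom(wire);
+    //   // parsed.getKey().equals(key) holds after the round trip
+    //
+    // RemoveImplicitKey, declared next, carries no fields at all; it is a marker message,
+    // so RemoveImplicitKey.getDefaultInstance() (or newBuilder().build()) is the only
+    // value it ever takes.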
+ private RemoveImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private RemoveImplicitKey() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new RemoveImplicitKey(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKeyOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder 
setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public RemoveImplicitKey parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw 
e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ExistsOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Exists) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + */ + public static final class Exists extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Exists) + ExistsOrBuilder { + private static final long serialVersionUID = 0L; + // Use Exists.newBuilder() to construct. + private Exists(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Exists() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Exists(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + 
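+    // Illustrative note, not part of the protoc output: Exists, like Get below and
+    // RemoveImplicitKey above, declares no fields. It is a marker command, so
+    // equals()/hashCode() reduce to the unknown-field set and every freshly built
+    // instance is equivalent to the default one. A minimal sketch, assuming the
+    // standard protobuf-java runtime:
+    //
+    //   StateMessage.Exists e1 = StateMessage.Exists.getDefaultInstance();
+    //   StateMessage.Exists e2 = StateMessage.Exists.newBuilder().build();
+    //   // e1.equals(e2) is true, and e1.toByteArray().length == 0
+    //   // (an empty message serializes to zero bytes).
+    //
+    // Marker messages like this are typically carried as one case of a oneof on a
+    // wrapping call message, much as the MapStateCall builder earlier in this hunk
+    // tracks its selected case through methodCase_, so setting one case replaces
+    // whichever case was set before.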
@java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return 
com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Exists) + org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return 
super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Exists) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Exists) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Exists DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Exists parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, 
extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface GetOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Get) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + */ + public static final class Get extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Get) + GetOrBuilder { + private static final long serialVersionUID = 0L; + // Use Get.newBuilder() to construct. + private Get(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Get() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Get(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.Get other = 
(org.apache.spark.sql.execution.streaming.state.StateMessage.Get) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.CodedInputStream 
input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Get prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Get) + org.apache.spark.sql.execution.streaming.state.StateMessage.GetOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Get.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + 
com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Get)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Get other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Get) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Get) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Get DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Get parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = 
newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ValueStateUpdateOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + com.google.protobuf.MessageOrBuilder { + + /** + * bytes value = 1; + * @return The value. + */ + com.google.protobuf.ByteString getValue(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + */ + public static final class ValueStateUpdate extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + ValueStateUpdateOrBuilder { + private static final long serialVersionUID = 0L; + // Use ValueStateUpdate.newBuilder() to construct. + private ValueStateUpdate(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private ValueStateUpdate() { + value_ = com.google.protobuf.ByteString.EMPTY; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new ValueStateUpdate(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + } + + public static final int VALUE_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString value_; + /** + * bytes value = 1; + * @return The value. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getValue() { + return value_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!value_.isEmpty()) { + output.writeBytes(1, value_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!value_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, value_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) obj; + + if (!getValue() + .equals(other.getValue())) return false; + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + VALUE_FIELD_NUMBER; + hash = (53 * hash) + getValue().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdateOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + value_ = com.google.protobuf.ByteString.EMPTY; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(this); + result.value_ = value_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return 
super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance()) return this; + if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { + setValue(other.getValue()); + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + value_ = input.readBytes(); + + break; + } // case 10 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + + private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes value = 1; + * @return The value. + */ + @java.lang.Override + public com.google.protobuf.ByteString getValue() { + return value_; + } + /** + * bytes value = 1; + * @param value The value to set. + * @return This builder for chaining. + */ + public Builder setValue(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + value_ = value; + onChanged(); + return this; + } + /** + * bytes value = 1; + * @return This builder for chaining. 
+ */ + public Builder clearValue() { + + value_ = getDefaultInstance().getValue(); + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public ValueStateUpdate parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ClearOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Clear) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} + */ + public static final class Clear extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Clear) + ClearOrBuilder { + private static final long serialVersionUID = 0L; + // Use Clear.newBuilder() to construct. 
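+    // A minimal usage sketch (an assumption for illustration, not part of the protoc
+    // output): a state-server client could build, serialize, and re-parse one of the
+    // request messages defined above roughly as follows.
+    //
+    //   StateMessage.ValueStateUpdate update = StateMessage.ValueStateUpdate.newBuilder()
+    //       .setValue(com.google.protobuf.ByteString.copyFromUtf8("serialized value"))
+    //       .build();
+    //   byte[] bytes = update.toByteArray();
+    //   StateMessage.ValueStateUpdate parsed = StateMessage.ValueStateUpdate.parseFrom(bytes);
+    //
+    // newBuilder(), setValue(), build(), and parseFrom(byte[]) appear in the generated
+    // code above; toByteArray() is inherited from the protobuf MessageLite base class.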
+ private Clear(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Clear() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Clear(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws 
com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Clear) + org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + 
public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Clear) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Clear) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Clear DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Clear parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + 
@java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ListStateGetOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStateGet) + com.google.protobuf.MessageOrBuilder { + + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + java.lang.String getIteratorId(); + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + com.google.protobuf.ByteString + getIteratorIdBytes(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateGet} + */ + public static final class ListStateGet extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStateGet) + ListStateGetOrBuilder { + private static final long serialVersionUID = 0L; + // Use ListStateGet.newBuilder() to construct. + private ListStateGet(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private ListStateGet() { + iteratorId_ = ""; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new ListStateGet(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder.class); + } + + public static final int ITERATORID_FIELD_NUMBER = 1; + private volatile java.lang.Object iteratorId_; + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + @java.lang.Override + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, iteratorId_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, iteratorId_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) obj; + + if (!getIteratorId() + .equals(other.getIteratorId())) return false; + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + ITERATORID_FIELD_NUMBER; + hash = (53 * hash) + getIteratorId().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom(byte[] data) + throws 
com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateGet} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStateGet) + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + iteratorId_ = ""; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet(this); + result.iteratorId_ = iteratorId_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + 
com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance()) return this; + if (!other.getIteratorId().isEmpty()) { + iteratorId_ = other.iteratorId_; + onChanged(); + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + iteratorId_ = input.readStringRequireUtf8(); + + break; + } // case 10 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + + private java.lang.Object iteratorId_ = ""; + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string iteratorId = 1; + * @param value The iteratorId to set. + * @return This builder for chaining. + */ + public Builder setIteratorId( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + iteratorId_ = value; + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @return This builder for chaining. + */ + public Builder clearIteratorId() { + + iteratorId_ = getDefaultInstance().getIteratorId(); + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @param value The bytes for iteratorId to set. + * @return This builder for chaining. 
+ */ + public Builder setIteratorIdBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + iteratorId_ = value; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStateGet) } - private SetImplicitKey() { - key_ = com.google.protobuf.ByteString.EMPTY; + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStateGet) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public ListStateGet parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ListStatePutOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStatePut) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStatePut} + */ + public static final class ListStatePut extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStatePut) + ListStatePutOrBuilder { + private static final long serialVersionUID = 0L; + // Use ListStatePut.newBuilder() to construct. 
+ private ListStatePut(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private ListStatePut() { } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new SetImplicitKey(); + return new ListStatePut(); } @java.lang.Override @@ -9724,26 +15843,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); - } - - public static final int KEY_FIELD_NUMBER = 1; - private com.google.protobuf.ByteString key_; - /** - * bytes key = 1; - * @return The key. - */ - @java.lang.Override - public com.google.protobuf.ByteString getKey() { - return key_; + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder.class); } private byte memoizedIsInitialized = -1; @@ -9760,9 +15868,6 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { - if (!key_.isEmpty()) { - output.writeBytes(1, key_); - } getUnknownFields().writeTo(output); } @@ -9772,10 +15877,6 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; - if (!key_.isEmpty()) { - size += com.google.protobuf.CodedOutputStream - .computeBytesSize(1, key_); - } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -9786,13 +15887,11 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) obj; - if (!getKey() - .equals(other.getKey())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -9804,76 +15903,74 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + KEY_FIELD_NUMBER; - hash = (53 * hash) + getKey().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); 
memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom( + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -9886,7 +15983,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImp public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -9902,26 +15999,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStatePut} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKeyOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStatePut) + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder.class); } - // Construct using 
org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.newBuilder() private Builder() { } @@ -9934,25 +16031,23 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - key_ = com.google.protobuf.ByteString.EMPTY; - return this; } @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -9960,9 +16055,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKe } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(this); - result.key_ = key_; + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut(this); onBuilt(); return result; } @@ -10001,19 +16095,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance()) return this; - if (other.getKey() != com.google.protobuf.ByteString.EMPTY) { - setKey(other.getKey()); - } + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut other) { + if (other == 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance()) return this; this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -10040,11 +16131,6 @@ public Builder mergeFrom( case 0: done = true; break; - case 10: { - key_ = input.readBytes(); - - break; - } // case 10 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -10060,40 +16146,6 @@ public Builder mergeFrom( } // finally return this; } - - private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY; - /** - * bytes key = 1; - * @return The key. - */ - @java.lang.Override - public com.google.protobuf.ByteString getKey() { - return key_; - } - /** - * bytes key = 1; - * @param value The key to set. - * @return This builder for chaining. - */ - public Builder setKey(com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - - key_ = value; - onChanged(); - return this; - } - /** - * bytes key = 1; - * @return This builder for chaining. - */ - public Builder clearKey() { - - key_ = getDefaultInstance().getKey(); - onChanged(); - return this; - } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -10107,23 +16159,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStatePut) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStatePut) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public SetImplicitKey parsePartialFrom( + public ListStatePut parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -10142,46 +16194,53 @@ public SetImplicitKey parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut 
getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface RemoveImplicitKeyOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + public interface AppendValueOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.AppendValue) com.google.protobuf.MessageOrBuilder { + + /** + * bytes value = 1; + * @return The value. + */ + com.google.protobuf.ByteString getValue(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendValue} */ - public static final class RemoveImplicitKey extends + public static final class AppendValue extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - RemoveImplicitKeyOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.AppendValue) + AppendValueOrBuilder { private static final long serialVersionUID = 0L; - // Use RemoveImplicitKey.newBuilder() to construct. - private RemoveImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use AppendValue.newBuilder() to construct. + private AppendValue(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private RemoveImplicitKey() { + private AppendValue() { + value_ = com.google.protobuf.ByteString.EMPTY; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new RemoveImplicitKey(); + return new AppendValue(); } @java.lang.Override @@ -10191,15 +16250,26 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder.class); + } + + public static final int VALUE_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString value_; + /** + * bytes value = 1; + * @return The value. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getValue() { + return value_; } private byte memoizedIsInitialized = -1; @@ -10216,6 +16286,9 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (!value_.isEmpty()) { + output.writeBytes(1, value_); + } getUnknownFields().writeTo(output); } @@ -10225,6 +16298,10 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; + if (!value_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, value_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -10235,11 +16312,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue other = (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) obj; + if (!getValue() + .equals(other.getValue())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -10251,74 +16330,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + VALUE_FIELD_NUMBER; + hash = (53 * hash) + getValue().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom(byte[] data) throws 
com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -10331,7 +16412,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Remove public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -10347,26 +16428,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + * Protobuf type {@code 
org.apache.spark.sql.execution.streaming.state.AppendValue} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKeyOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.AppendValue) + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.newBuilder() private Builder() { } @@ -10379,23 +16460,25 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + value_ = com.google.protobuf.ByteString.EMPTY; + return this; } @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue result = buildPartial(); if (!result.isInitialized()) { throw 
newUninitializedMessageException(result); } @@ -10403,8 +16486,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplici } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue result = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue(this); + result.value_ = value_; onBuilt(); return result; } @@ -10443,16 +16527,19 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance()) return this; + if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { + setValue(other.getValue()); + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -10479,6 +16566,11 @@ public Builder mergeFrom( case 0: done = true; break; + case 10: { + value_ = input.readBytes(); + + break; + } // case 10 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -10494,6 +16586,40 @@ public Builder mergeFrom( } // finally return this; } + + private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes value = 1; + * @return The value. + */ + @java.lang.Override + public com.google.protobuf.ByteString getValue() { + return value_; + } + /** + * bytes value = 1; + * @param value The value to set. + * @return This builder for chaining. + */ + public Builder setValue(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + value_ = value; + onChanged(); + return this; + } + /** + * bytes value = 1; + * @return This builder for chaining. 
+ */ + public Builder clearValue() { + + value_ = getDefaultInstance().getValue(); + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -10507,23 +16633,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.AppendValue) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.AppendValue) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public RemoveImplicitKey parsePartialFrom( + public AppendValue parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -10542,46 +16668,46 @@ public RemoveImplicitKey parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ExistsOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Exists) + public interface AppendListOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.AppendList) com.google.protobuf.MessageOrBuilder { } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendList} */ - public static final class Exists extends + public static final class AppendList extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Exists) - ExistsOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.AppendList) + AppendListOrBuilder { private static final long serialVersionUID = 0L; - // Use Exists.newBuilder() to construct. 
- private Exists(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use AppendList.newBuilder() to construct. + private AppendList(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Exists() { + private AppendList() { } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Exists(); + return new AppendList(); } @java.lang.Override @@ -10591,15 +16717,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder.class); } private byte memoizedIsInitialized = -1; @@ -10635,10 +16761,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList other = (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) obj; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; @@ -10656,69 +16782,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return 
PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -10731,7 +16857,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists public static Builder newBuilder() { return 
DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -10747,26 +16873,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendList} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Exists) - org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.AppendList) + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.newBuilder() private Builder() { } @@ -10785,17 +16911,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } - @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build() { - 
org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -10803,8 +16929,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build( } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList result = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList(this); onBuilt(); return result; } @@ -10843,16 +16969,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance()) return this; this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -10907,23 +17033,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Exists) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.AppendList) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Exists) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Exists DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.AppendList) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Exists parsePartialFrom( + public AppendList 
parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -10942,46 +17068,53 @@ public Exists parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface GetOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Get) + public interface GetValueOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.GetValue) com.google.protobuf.MessageOrBuilder { + + /** + * bytes userKey = 1; + * @return The userKey. + */ + com.google.protobuf.ByteString getUserKey(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.GetValue} */ - public static final class Get extends + public static final class GetValue extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Get) - GetOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.GetValue) + GetValueOrBuilder { private static final long serialVersionUID = 0L; - // Use Get.newBuilder() to construct. - private Get(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use GetValue.newBuilder() to construct. 
+ private GetValue(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Get() { + private GetValue() { + userKey_ = com.google.protobuf.ByteString.EMPTY; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Get(); + return new GetValue(); } @java.lang.Override @@ -10991,15 +17124,26 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.Builder.class); + } + + public static final int USERKEY_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString userKey_; + /** + * bytes userKey = 1; + * @return The userKey. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getUserKey() { + return userKey_; } private byte memoizedIsInitialized = -1; @@ -11016,6 +17160,9 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (!userKey_.isEmpty()) { + output.writeBytes(1, userKey_); + } getUnknownFields().writeTo(output); } @@ -11025,6 +17172,10 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; + if (!userKey_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, userKey_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -11035,11 +17186,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Get other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Get) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue other = (org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) obj; + if (!getUserKey() + .equals(other.getUserKey())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -11051,74 +17204,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + USERKEY_FIELD_NUMBER; + hash = (53 * hash) + getUserKey().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -11131,7 +17286,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get pa public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Get prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -11147,26 +17302,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.GetValue} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Get) - 
org.apache.spark.sql.execution.streaming.state.StateMessage.GetOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.GetValue) + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValueOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Get.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.newBuilder() private Builder() { } @@ -11179,23 +17334,25 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + userKey_ = com.google.protobuf.ByteString.EMPTY; + return this; } @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -11203,8 +17360,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(this); + public 
org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue result = new org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue(this); + result.userKey_ = userKey_; onBuilt(); return result; } @@ -11243,16 +17401,19 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Get)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Get other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue.getDefaultInstance()) return this; + if (other.getUserKey() != com.google.protobuf.ByteString.EMPTY) { + setUserKey(other.getUserKey()); + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -11279,6 +17440,11 @@ public Builder mergeFrom( case 0: done = true; break; + case 10: { + userKey_ = input.readBytes(); + + break; + } // case 10 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -11294,6 +17460,40 @@ public Builder mergeFrom( } // finally return this; } + + private com.google.protobuf.ByteString userKey_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes userKey = 1; + * @return The userKey. + */ + @java.lang.Override + public com.google.protobuf.ByteString getUserKey() { + return userKey_; + } + /** + * bytes userKey = 1; + * @param value The userKey to set. + * @return This builder for chaining. + */ + public Builder setUserKey(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + userKey_ = value; + onChanged(); + return this; + } + /** + * bytes userKey = 1; + * @return This builder for chaining. 
+ */ + public Builder clearUserKey() { + + userKey_ = getDefaultInstance().getUserKey(); + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -11307,23 +17507,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Get) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.GetValue) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Get) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Get DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.GetValue) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Get parsePartialFrom( + public GetValue parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -11342,53 +17542,53 @@ public Get parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.GetValue getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ValueStateUpdateOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + public interface ContainsKeyOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ContainsKey) com.google.protobuf.MessageOrBuilder { /** - * bytes value = 1; - * @return The value. + * bytes userKey = 1; + * @return The userKey. 
*/ - com.google.protobuf.ByteString getValue(); + com.google.protobuf.ByteString getUserKey(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ContainsKey} */ - public static final class ValueStateUpdate extends + public static final class ContainsKey extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - ValueStateUpdateOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ContainsKey) + ContainsKeyOrBuilder { private static final long serialVersionUID = 0L; - // Use ValueStateUpdate.newBuilder() to construct. - private ValueStateUpdate(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use ContainsKey.newBuilder() to construct. + private ContainsKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private ValueStateUpdate() { - value_ = com.google.protobuf.ByteString.EMPTY; + private ContainsKey() { + userKey_ = com.google.protobuf.ByteString.EMPTY; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new ValueStateUpdate(); + return new ContainsKey(); } @java.lang.Override @@ -11398,26 +17598,26 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.Builder.class); } - public static final int VALUE_FIELD_NUMBER = 1; - private com.google.protobuf.ByteString value_; + public static final int USERKEY_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString userKey_; /** - * bytes value = 1; - * @return The value. + * bytes userKey = 1; + * @return The userKey. 
*/ @java.lang.Override - public com.google.protobuf.ByteString getValue() { - return value_; + public com.google.protobuf.ByteString getUserKey() { + return userKey_; } private byte memoizedIsInitialized = -1; @@ -11434,8 +17634,8 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { - if (!value_.isEmpty()) { - output.writeBytes(1, value_); + if (!userKey_.isEmpty()) { + output.writeBytes(1, userKey_); } getUnknownFields().writeTo(output); } @@ -11446,9 +17646,9 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; - if (!value_.isEmpty()) { + if (!userKey_.isEmpty()) { size += com.google.protobuf.CodedOutputStream - .computeBytesSize(1, value_); + .computeBytesSize(1, userKey_); } size += getUnknownFields().getSerializedSize(); memoizedSize = size; @@ -11460,13 +17660,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) obj; - if (!getValue() - .equals(other.getValue())) return false; + if (!getUserKey() + .equals(other.getUserKey())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -11478,76 +17678,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + VALUE_FIELD_NUMBER; - hash = (53 * hash) + getValue().hashCode(); + hash = (37 * hash) + USERKEY_FIELD_NUMBER; + hash = (53 * hash) + getUserKey().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws 
com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -11560,7 +17760,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueS public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate prototype) { + public static Builder 
newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -11576,26 +17776,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ContainsKey} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdateOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ContainsKey) + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKeyOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.newBuilder() private Builder() { } @@ -11608,7 +17808,7 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - value_ = com.google.protobuf.ByteString.EMPTY; + userKey_ = com.google.protobuf.ByteString.EMPTY; return this; } @@ -11616,17 +17816,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance(); } 
@java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -11634,9 +17834,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpd } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(this); - result.value_ = value_; + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey(this); + result.userKey_ = userKey_; onBuilt(); return result; } @@ -11675,18 +17875,18 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance()) return this; - if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { - setValue(other.getValue()); + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey.getDefaultInstance()) return this; + if (other.getUserKey() != com.google.protobuf.ByteString.EMPTY) { + setUserKey(other.getUserKey()); } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); @@ -11715,7 +17915,7 @@ public Builder mergeFrom( done = true; break; case 10: { - value_ = input.readBytes(); + userKey_ = input.readBytes(); break; } // case 10 @@ -11735,36 +17935,36 @@ public Builder mergeFrom( return this; } - private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; + private com.google.protobuf.ByteString userKey_ = com.google.protobuf.ByteString.EMPTY; /** - * bytes value = 1; - * @return The value. + * bytes userKey = 1; + * @return The userKey. */ @java.lang.Override - public com.google.protobuf.ByteString getValue() { - return value_; + public com.google.protobuf.ByteString getUserKey() { + return userKey_; } /** - * bytes value = 1; - * @param value The value to set. + * bytes userKey = 1; + * @param value The userKey to set. * @return This builder for chaining. 
*/ - public Builder setValue(com.google.protobuf.ByteString value) { + public Builder setUserKey(com.google.protobuf.ByteString value) { if (value == null) { throw new NullPointerException(); } - value_ = value; + userKey_ = value; onChanged(); return this; } /** - * bytes value = 1; + * bytes userKey = 1; * @return This builder for chaining. */ - public Builder clearValue() { + public Builder clearUserKey() { - value_ = getDefaultInstance().getValue(); + userKey_ = getDefaultInstance().getUserKey(); onChanged(); return this; } @@ -11781,23 +17981,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ContainsKey) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ContainsKey) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public ValueStateUpdate parsePartialFrom( + public ContainsKey parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -11816,46 +18016,60 @@ public ValueStateUpdate parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.ContainsKey getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ClearOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Clear) + public interface UpdateValueOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.UpdateValue) com.google.protobuf.MessageOrBuilder { + + /** + * bytes userKey = 1; + * @return The userKey. + */ + com.google.protobuf.ByteString getUserKey(); + + /** + * bytes value = 2; + * @return The value. 
+ */ + com.google.protobuf.ByteString getValue(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.UpdateValue} */ - public static final class Clear extends + public static final class UpdateValue extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Clear) - ClearOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.UpdateValue) + UpdateValueOrBuilder { private static final long serialVersionUID = 0L; - // Use Clear.newBuilder() to construct. - private Clear(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use UpdateValue.newBuilder() to construct. + private UpdateValue(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Clear() { + private UpdateValue() { + userKey_ = com.google.protobuf.ByteString.EMPTY; + value_ = com.google.protobuf.ByteString.EMPTY; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Clear(); + return new UpdateValue(); } @java.lang.Override @@ -11865,15 +18079,37 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.Builder.class); + } + + public static final int USERKEY_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString userKey_; + /** + * bytes userKey = 1; + * @return The userKey. + */ + @java.lang.Override + public com.google.protobuf.ByteString getUserKey() { + return userKey_; } + public static final int VALUE_FIELD_NUMBER = 2; + private com.google.protobuf.ByteString value_; + /** + * bytes value = 2; + * @return The value. 
+ */ @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable - .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); + public com.google.protobuf.ByteString getValue() { + return value_; } private byte memoizedIsInitialized = -1; @@ -11890,6 +18126,12 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (!userKey_.isEmpty()) { + output.writeBytes(1, userKey_); + } + if (!value_.isEmpty()) { + output.writeBytes(2, value_); + } getUnknownFields().writeTo(output); } @@ -11899,6 +18141,14 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; + if (!userKey_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, userKey_); + } + if (!value_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(2, value_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -11909,11 +18159,15 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue other = (org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) obj; + if (!getUserKey() + .equals(other.getUserKey())) return false; + if (!getValue() + .equals(other.getValue())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -11925,74 +18179,78 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + USERKEY_FIELD_NUMBER; + hash = (53 * hash) + getUserKey().hashCode(); + hash = (37 * hash) + VALUE_FIELD_NUMBER; + hash = (53 * hash) + getValue().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom( com.google.protobuf.ByteString data) throws 
com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -12005,7 +18263,7 @@ public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.Clear public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -12021,26 +18279,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.UpdateValue} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Clear) - org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.UpdateValue) + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValueOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.newBuilder() private Builder() { } @@ -12053,23 +18311,27 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + userKey_ = com.google.protobuf.ByteString.EMPTY; + + value_ = com.google.protobuf.ByteString.EMPTY; + return this; } @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue 
getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -12077,8 +18339,10 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue result = new org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue(this); + result.userKey_ = userKey_; + result.value_ = value_; onBuilt(); return result; } @@ -12117,16 +18381,22 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue.getDefaultInstance()) return this; + if (other.getUserKey() != com.google.protobuf.ByteString.EMPTY) { + setUserKey(other.getUserKey()); + } + if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { + setValue(other.getValue()); + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -12153,6 +18423,16 @@ public Builder mergeFrom( case 0: done = true; break; + case 10: { + userKey_ = input.readBytes(); + + break; + } // case 10 + case 18: { + value_ = input.readBytes(); + + break; + } // case 18 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -12168,6 +18448,74 @@ public Builder mergeFrom( } // finally return this; } + + private com.google.protobuf.ByteString userKey_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes userKey = 1; + * @return The userKey. + */ + @java.lang.Override + public com.google.protobuf.ByteString getUserKey() { + return userKey_; + } + /** + * bytes userKey = 1; + * @param value The userKey to set. + * @return This builder for chaining. 
+ */ + public Builder setUserKey(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + userKey_ = value; + onChanged(); + return this; + } + /** + * bytes userKey = 1; + * @return This builder for chaining. + */ + public Builder clearUserKey() { + + userKey_ = getDefaultInstance().getUserKey(); + onChanged(); + return this; + } + + private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes value = 2; + * @return The value. + */ + @java.lang.Override + public com.google.protobuf.ByteString getValue() { + return value_; + } + /** + * bytes value = 2; + * @param value The value to set. + * @return This builder for chaining. + */ + public Builder setValue(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + value_ = value; + onChanged(); + return this; + } + /** + * bytes value = 2; + * @return This builder for chaining. + */ + public Builder clearValue() { + + value_ = getDefaultInstance().getValue(); + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -12181,23 +18529,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Clear) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.UpdateValue) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Clear) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Clear DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.UpdateValue) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Clear parsePartialFrom( + public UpdateValue parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -12216,24 +18564,24 @@ public Clear parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.UpdateValue getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ListStateGetOrBuilder extends - // 
@@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStateGet) + public interface IteratorOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Iterator) com.google.protobuf.MessageOrBuilder { /** @@ -12249,18 +18597,18 @@ public interface ListStateGetOrBuilder extends getIteratorIdBytes(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateGet} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Iterator} */ - public static final class ListStateGet extends + public static final class Iterator extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStateGet) - ListStateGetOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Iterator) + IteratorOrBuilder { private static final long serialVersionUID = 0L; - // Use ListStateGet.newBuilder() to construct. - private ListStateGet(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use Iterator.newBuilder() to construct. + private Iterator(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private ListStateGet() { + private Iterator() { iteratorId_ = ""; } @@ -12268,7 +18616,7 @@ private ListStateGet() { @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new ListStateGet(); + return new Iterator(); } @java.lang.Override @@ -12278,15 +18626,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.Builder.class); } public static final int ITERATORID_FIELD_NUMBER = 1; @@ -12366,10 +18714,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator other = 
(org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) obj; if (!getIteratorId() .equals(other.getIteratorId())) return false; @@ -12391,69 +18739,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -12466,7 +18814,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListSt public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -12482,26 +18830,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateGet} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Iterator} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStateGet) - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Iterator) + org.apache.spark.sql.execution.streaming.state.StateMessage.IteratorOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.class, 
org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.newBuilder() private Builder() { } @@ -12522,17 +18870,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -12540,8 +18888,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator(this); result.iteratorId_ = iteratorId_; onBuilt(); return result; @@ -12581,16 +18929,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator.getDefaultInstance()) return this; if (!other.getIteratorId().isEmpty()) { iteratorId_ = other.iteratorId_; onChanged(); @@ -12730,23 
+19078,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStateGet) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Iterator) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStateGet) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Iterator) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public ListStateGet parsePartialFrom( + public Iterator parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -12765,46 +19113,59 @@ public ListStateGet parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.Iterator getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ListStatePutOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStatePut) + public interface KeysOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Keys) com.google.protobuf.MessageOrBuilder { + + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + java.lang.String getIteratorId(); + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + com.google.protobuf.ByteString + getIteratorIdBytes(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStatePut} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Keys} */ - public static final class ListStatePut extends + public static final class Keys extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStatePut) - ListStatePutOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Keys) + KeysOrBuilder { private static final long serialVersionUID = 0L; - // Use ListStatePut.newBuilder() to construct. 
- private ListStatePut(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use Keys.newBuilder() to construct. + private Keys(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private ListStatePut() { + private Keys() { + iteratorId_ = ""; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new ListStatePut(); + return new Keys(); } @java.lang.Override @@ -12814,15 +19175,53 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Keys_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Keys_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.Builder.class); + } + + public static final int ITERATORID_FIELD_NUMBER = 1; + private volatile java.lang.Object iteratorId_; + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + @java.lang.Override + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } } private byte memoizedIsInitialized = -1; @@ -12839,6 +19238,9 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, iteratorId_); + } getUnknownFields().writeTo(output); } @@ -12848,6 +19250,9 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, iteratorId_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -12858,11 +19263,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Keys)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) obj; + if (!getIteratorId() + .equals(other.getIteratorId())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -12874,74 +19281,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + ITERATORID_FIELD_NUMBER; + hash = (53 * hash) + getIteratorId().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) 
throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -12954,7 +19363,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListSt public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Keys prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } 
@java.lang.Override @@ -12970,26 +19379,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStatePut} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Keys} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStatePut) - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Keys) + org.apache.spark.sql.execution.streaming.state.StateMessage.KeysOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Keys_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Keys_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.newBuilder() private Builder() { } @@ -13002,23 +19411,25 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + iteratorId_ = ""; + return this; } @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Keys_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Keys getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Keys build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys result = buildPartial(); if 
(!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -13026,8 +19437,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Keys buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Keys result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Keys(this); + result.iteratorId_ = iteratorId_; onBuilt(); return result; } @@ -13066,16 +19478,20 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Keys) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Keys)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Keys other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Keys.getDefaultInstance()) return this; + if (!other.getIteratorId().isEmpty()) { + iteratorId_ = other.iteratorId_; + onChanged(); + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -13102,6 +19518,11 @@ public Builder mergeFrom( case 0: done = true; break; + case 10: { + iteratorId_ = input.readStringRequireUtf8(); + + break; + } // case 10 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -13117,6 +19538,82 @@ public Builder mergeFrom( } // finally return this; } + + private java.lang.Object iteratorId_ = ""; + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string iteratorId = 1; + * @param value The iteratorId to set. + * @return This builder for chaining. 
+ */ + public Builder setIteratorId( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + iteratorId_ = value; + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @return This builder for chaining. + */ + public Builder clearIteratorId() { + + iteratorId_ = getDefaultInstance().getIteratorId(); + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @param value The bytes for iteratorId to set. + * @return This builder for chaining. + */ + public Builder setIteratorIdBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + iteratorId_ = value; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -13130,23 +19627,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStatePut) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Keys) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStatePut) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Keys) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Keys DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Keys(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Keys getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public ListStatePut parsePartialFrom( + public Keys parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -13165,53 +19662,59 @@ public ListStatePut parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.Keys getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface AppendValueOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.AppendValue) + public interface ValuesOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Values) com.google.protobuf.MessageOrBuilder { /** - * bytes value = 1; - * @return The value. + * string iteratorId = 1; + * @return The iteratorId. 
*/ - com.google.protobuf.ByteString getValue(); + java.lang.String getIteratorId(); + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + com.google.protobuf.ByteString + getIteratorIdBytes(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendValue} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Values} */ - public static final class AppendValue extends + public static final class Values extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.AppendValue) - AppendValueOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Values) + ValuesOrBuilder { private static final long serialVersionUID = 0L; - // Use AppendValue.newBuilder() to construct. - private AppendValue(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use Values.newBuilder() to construct. + private Values(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private AppendValue() { - value_ = com.google.protobuf.ByteString.EMPTY; + private Values() { + iteratorId_ = ""; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new AppendValue(); + return new Values(); } @java.lang.Override @@ -13221,26 +19724,53 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Values_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Values_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Values.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Values.Builder.class); } - public static final int VALUE_FIELD_NUMBER = 1; - private com.google.protobuf.ByteString value_; + public static final int ITERATORID_FIELD_NUMBER = 1; + private volatile java.lang.Object iteratorId_; /** - * bytes value = 1; - * @return The value. + * string iteratorId = 1; + * @return The iteratorId. */ @java.lang.Override - public com.google.protobuf.ByteString getValue() { - return value_; + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } } private byte memoizedIsInitialized = -1; @@ -13257,8 +19787,8 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { - if (!value_.isEmpty()) { - output.writeBytes(1, value_); + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, iteratorId_); } getUnknownFields().writeTo(output); } @@ -13269,9 +19799,8 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; - if (!value_.isEmpty()) { - size += com.google.protobuf.CodedOutputStream - .computeBytesSize(1, value_); + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, iteratorId_); } size += getUnknownFields().getSerializedSize(); memoizedSize = size; @@ -13283,13 +19812,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Values)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue other = (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.Values other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Values) obj; - if (!getValue() - .equals(other.getValue())) return false; + if (!getIteratorId() + .equals(other.getIteratorId())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -13301,76 +19830,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + VALUE_FIELD_NUMBER; - hash = (53 * hash) + getValue().hashCode(); + hash = (37 * hash) + ITERATORID_FIELD_NUMBER; + hash = (53 * hash) + getIteratorId().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return 
PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -13383,7 +19912,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Append public static Builder newBuilder() { return 
DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Values prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -13399,26 +19928,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendValue} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Values} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.AppendValue) - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Values) + org.apache.spark.sql.execution.streaming.state.StateMessage.ValuesOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Values_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Values_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Values.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Values.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Values.newBuilder() private Builder() { } @@ -13431,7 +19960,7 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - value_ = com.google.protobuf.ByteString.EMPTY; + iteratorId_ = ""; return this; } @@ -13439,17 +19968,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Values_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Values getDefaultInstanceForType() { + return 
org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Values build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Values result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -13457,9 +19986,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue b } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue result = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue(this); - result.value_ = value_; + public org.apache.spark.sql.execution.streaming.state.StateMessage.Values buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Values result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Values(this); + result.iteratorId_ = iteratorId_; onBuilt(); return result; } @@ -13498,18 +20027,19 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Values) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Values)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance()) return this; - if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { - setValue(other.getValue()); + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Values other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Values.getDefaultInstance()) return this; + if (!other.getIteratorId().isEmpty()) { + iteratorId_ = other.iteratorId_; + onChanged(); } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); @@ -13538,7 +20068,7 @@ public Builder mergeFrom( done = true; break; case 10: { - value_ = input.readBytes(); + iteratorId_ = input.readStringRequireUtf8(); break; } // case 10 @@ -13558,36 +20088,78 @@ public Builder mergeFrom( return this; } - private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; + private java.lang.Object iteratorId_ = ""; /** - * bytes value = 1; - * @return The value. + * string iteratorId = 1; + * @return The iteratorId. */ - @java.lang.Override - public com.google.protobuf.ByteString getValue() { - return value_; + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } else { + return (java.lang.String) ref; + } } /** - * bytes value = 1; - * @param value The value to set. 
+ * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string iteratorId = 1; + * @param value The iteratorId to set. * @return This builder for chaining. */ - public Builder setValue(com.google.protobuf.ByteString value) { + public Builder setIteratorId( + java.lang.String value) { if (value == null) { throw new NullPointerException(); } - value_ = value; + iteratorId_ = value; onChanged(); return this; } /** - * bytes value = 1; + * string iteratorId = 1; * @return This builder for chaining. */ - public Builder clearValue() { + public Builder clearIteratorId() { - value_ = getDefaultInstance().getValue(); + iteratorId_ = getDefaultInstance().getIteratorId(); + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @param value The bytes for iteratorId to set. + * @return This builder for chaining. + */ + public Builder setIteratorIdBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + iteratorId_ = value; onChanged(); return this; } @@ -13604,23 +20176,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.AppendValue) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Values) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.AppendValue) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Values) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Values DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Values(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Values getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public AppendValue parsePartialFrom( + public Values parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -13639,46 +20211,53 @@ public AppendValue parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstanceForType() { + public 
org.apache.spark.sql.execution.streaming.state.StateMessage.Values getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface AppendListOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.AppendList) + public interface RemoveKeyOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.RemoveKey) com.google.protobuf.MessageOrBuilder { + + /** + * bytes userKey = 1; + * @return The userKey. + */ + com.google.protobuf.ByteString getUserKey(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendList} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveKey} */ - public static final class AppendList extends + public static final class RemoveKey extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.AppendList) - AppendListOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.RemoveKey) + RemoveKeyOrBuilder { private static final long serialVersionUID = 0L; - // Use AppendList.newBuilder() to construct. - private AppendList(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use RemoveKey.newBuilder() to construct. + private RemoveKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private AppendList() { + private RemoveKey() { + userKey_ = com.google.protobuf.ByteString.EMPTY; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new AppendList(); + return new RemoveKey(); } @java.lang.Override @@ -13688,15 +20267,26 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.Builder.class); + } + + public static final int USERKEY_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString userKey_; + /** + * bytes userKey = 1; + * @return The userKey. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getUserKey() { + return userKey_; } private byte memoizedIsInitialized = -1; @@ -13713,6 +20303,9 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (!userKey_.isEmpty()) { + output.writeBytes(1, userKey_); + } getUnknownFields().writeTo(output); } @@ -13722,6 +20315,10 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; + if (!userKey_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, userKey_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -13732,11 +20329,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList other = (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) obj; + if (!getUserKey() + .equals(other.getUserKey())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -13748,74 +20347,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + USERKEY_FIELD_NUMBER; + hash = (53 * hash) + getUserKey().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return 
PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -13828,7 +20429,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Append public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -13844,26 +20445,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendList} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveKey} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements 
- // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.AppendList) - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.RemoveKey) + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKeyOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.newBuilder() private Builder() { } @@ -13876,23 +20477,25 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + userKey_ = com.google.protobuf.ByteString.EMPTY; + return this; } @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -13900,8 +20503,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList bu } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList buildPartial() { - 
org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList result = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey(this); + result.userKey_ = userKey_; onBuilt(); return result; } @@ -13940,16 +20544,19 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey.getDefaultInstance()) return this; + if (other.getUserKey() != com.google.protobuf.ByteString.EMPTY) { + setUserKey(other.getUserKey()); + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -13976,6 +20583,11 @@ public Builder mergeFrom( case 0: done = true; break; + case 10: { + userKey_ = input.readBytes(); + + break; + } // case 10 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -13991,6 +20603,40 @@ public Builder mergeFrom( } // finally return this; } + + private com.google.protobuf.ByteString userKey_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes userKey = 1; + * @return The userKey. + */ + @java.lang.Override + public com.google.protobuf.ByteString getUserKey() { + return userKey_; + } + /** + * bytes userKey = 1; + * @param value The userKey to set. + * @return This builder for chaining. + */ + public Builder setUserKey(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + userKey_ = value; + onChanged(); + return this; + } + /** + * bytes userKey = 1; + * @return This builder for chaining. 
+ */ + public Builder clearUserKey() { + + userKey_ = getDefaultInstance().getUserKey(); + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -14004,23 +20650,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.AppendList) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveKey) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.AppendList) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.RemoveKey) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public AppendList parsePartialFrom( + public RemoveKey parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -14039,17 +20685,17 @@ public AppendList parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveKey getDefaultInstanceForType() { return DEFAULT_INSTANCE; } @@ -15071,6 +21717,11 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; private static final @@ -15121,6 +21772,41 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable; + private 
static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_Keys_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_Keys_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_Values_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_Values_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor; private static final @@ -15162,54 +21848,81 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get "xecution.streaming.state.StateCallComman" + "dH\000\022W\n\013getMapState\030\004 \001(\0132@.org.apache.sp" + "ark.sql.execution.streaming.state.StateC" + - "allCommandH\000B\010\n\006method\"\322\001\n\024StateVariable" + + "allCommandH\000B\010\n\006method\"\250\002\n\024StateVariable" + "Request\022X\n\016valueStateCall\030\001 \001(\0132>.org.ap" + "ache.spark.sql.execution.streaming.state" + ".ValueStateCallH\000\022V\n\rlistStateCall\030\002 \001(\013" + "2=.org.apache.spark.sql.execution.stream" + - "ing.state.ListStateCallH\000B\010\n\006method\"\340\001\n\032" + - "ImplicitGroupingKeyRequest\022X\n\016setImplici" + - "tKey\030\001 \001(\0132>.org.apache.spark.sql.execut" + - "ion.streaming.state.SetImplicitKeyH\000\022^\n\021" + - "removeImplicitKey\030\002 \001(\0132A.org.apache.spa" + - "rk.sql.execution.streaming.state.RemoveI" + - "mplicitKeyH\000B\010\n\006method\"}\n\020StateCallComma" + - "nd\022\021\n\tstateName\030\001 
\001(\t\022\016\n\006schema\030\002 \001(\t\022F\n" + - "\003ttl\030\003 \001(\01329.org.apache.spark.sql.execut" + - "ion.streaming.state.TTLConfig\"\341\002\n\016ValueS" + - "tateCall\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 " + - "\001(\01326.org.apache.spark.sql.execution.str" + - "eaming.state.ExistsH\000\022B\n\003get\030\003 \001(\01323.org" + + "ing.state.ListStateCallH\000\022T\n\014mapStateCal" + + "l\030\003 \001(\0132<.org.apache.spark.sql.execution" + + ".streaming.state.MapStateCallH\000B\010\n\006metho" + + "d\"\340\001\n\032ImplicitGroupingKeyRequest\022X\n\016setI" + + "mplicitKey\030\001 \001(\0132>.org.apache.spark.sql." + + "execution.streaming.state.SetImplicitKey" + + "H\000\022^\n\021removeImplicitKey\030\002 \001(\0132A.org.apac" + + "he.spark.sql.execution.streaming.state.R" + + "emoveImplicitKeyH\000B\010\n\006method\"\232\001\n\020StateCa" + + "llCommand\022\021\n\tstateName\030\001 \001(\t\022\016\n\006schema\030\002" + + " \001(\t\022\033\n\023mapStateValueSchema\030\003 \001(\t\022F\n\003ttl" + + "\030\004 \001(\01329.org.apache.spark.sql.execution." + + "streaming.state.TTLConfig\"\341\002\n\016ValueState" + + "Call\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\0132" + + "6.org.apache.spark.sql.execution.streami" + + "ng.state.ExistsH\000\022B\n\003get\030\003 \001(\01323.org.apa" + + "che.spark.sql.execution.streaming.state." + + "GetH\000\022\\\n\020valueStateUpdate\030\004 \001(\0132@.org.ap" + + "ache.spark.sql.execution.streaming.state" + + ".ValueStateUpdateH\000\022F\n\005clear\030\005 \001(\01325.org" + + ".apache.spark.sql.execution.streaming.st" + + "ate.ClearH\000B\010\n\006method\"\220\004\n\rListStateCall\022" + + "\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\01326.org" + ".apache.spark.sql.execution.streaming.st" + - "ate.GetH\000\022\\\n\020valueStateUpdate\030\004 \001(\0132@.or" + - "g.apache.spark.sql.execution.streaming.s" + - "tate.ValueStateUpdateH\000\022F\n\005clear\030\005 \001(\01325" + - ".org.apache.spark.sql.execution.streamin" + - "g.state.ClearH\000B\010\n\006method\"\220\004\n\rListStateC" + - "all\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\01326" + - ".org.apache.spark.sql.execution.streamin" + - "g.state.ExistsH\000\022T\n\014listStateGet\030\003 \001(\0132<" + - ".org.apache.spark.sql.execution.streamin" + - "g.state.ListStateGetH\000\022T\n\014listStatePut\030\004" + - " \001(\0132<.org.apache.spark.sql.execution.st" + - "reaming.state.ListStatePutH\000\022R\n\013appendVa" + - "lue\030\005 \001(\0132;.org.apache.spark.sql.executi" + - "on.streaming.state.AppendValueH\000\022P\n\nappe" + - "ndList\030\006 \001(\0132:.org.apache.spark.sql.exec" + - "ution.streaming.state.AppendListH\000\022F\n\005cl" + - "ear\030\007 \001(\01325.org.apache.spark.sql.executi" + - "on.streaming.state.ClearH\000B\010\n\006method\"\035\n\016" + - "SetImplicitKey\022\013\n\003key\030\001 \001(\014\"\023\n\021RemoveImp" + - "licitKey\"\010\n\006Exists\"\005\n\003Get\"!\n\020ValueStateU" + - "pdate\022\r\n\005value\030\001 \001(\014\"\007\n\005Clear\"\"\n\014ListSta" + - "teGet\022\022\n\niteratorId\030\001 \001(\t\"\016\n\014ListStatePu" + - "t\"\034\n\013AppendValue\022\r\n\005value\030\001 \001(\014\"\014\n\nAppen" + - "dList\"\\\n\016SetHandleState\022J\n\005state\030\001 \001(\0162;" + - ".org.apache.spark.sql.execution.streamin" + - "g.state.HandleState\"\037\n\tTTLConfig\022\022\n\ndura" + - "tionMs\030\001 
\001(\005*K\n\013HandleState\022\013\n\007CREATED\020\000" + - "\022\017\n\013INITIALIZED\020\001\022\022\n\016DATA_PROCESSED\020\002\022\n\n" + - "\006CLOSED\020\003b\006proto3" + "ate.ExistsH\000\022T\n\014listStateGet\030\003 \001(\0132<.org" + + ".apache.spark.sql.execution.streaming.st" + + "ate.ListStateGetH\000\022T\n\014listStatePut\030\004 \001(\013" + + "2<.org.apache.spark.sql.execution.stream" + + "ing.state.ListStatePutH\000\022R\n\013appendValue\030" + + "\005 \001(\0132;.org.apache.spark.sql.execution.s" + + "treaming.state.AppendValueH\000\022P\n\nappendLi" + + "st\030\006 \001(\0132:.org.apache.spark.sql.executio" + + "n.streaming.state.AppendListH\000\022F\n\005clear\030" + + "\007 \001(\01325.org.apache.spark.sql.execution.s" + + "treaming.state.ClearH\000B\010\n\006method\"\341\005\n\014Map" + + "StateCall\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002" + + " \001(\01326.org.apache.spark.sql.execution.st" + + "reaming.state.ExistsH\000\022L\n\010getValue\030\003 \001(\013" + + "28.org.apache.spark.sql.execution.stream" + + "ing.state.GetValueH\000\022R\n\013containsKey\030\004 \001(" + + "\0132;.org.apache.spark.sql.execution.strea" + + "ming.state.ContainsKeyH\000\022R\n\013updateValue\030" + + "\005 \001(\0132;.org.apache.spark.sql.execution.s" + + "treaming.state.UpdateValueH\000\022L\n\010iterator" + + "\030\006 \001(\01328.org.apache.spark.sql.execution." + + "streaming.state.IteratorH\000\022D\n\004keys\030\007 \001(\013" + + "24.org.apache.spark.sql.execution.stream" + + "ing.state.KeysH\000\022H\n\006values\030\010 \001(\01326.org.a" + + "pache.spark.sql.execution.streaming.stat" + + "e.ValuesH\000\022N\n\tremoveKey\030\t \001(\01329.org.apac" + + "he.spark.sql.execution.streaming.state.R" + + "emoveKeyH\000\022F\n\005clear\030\n \001(\01325.org.apache.s" + + "park.sql.execution.streaming.state.Clear" + + "H\000B\010\n\006method\"\035\n\016SetImplicitKey\022\013\n\003key\030\001 " + + "\001(\014\"\023\n\021RemoveImplicitKey\"\010\n\006Exists\"\005\n\003Ge" + + "t\"!\n\020ValueStateUpdate\022\r\n\005value\030\001 \001(\014\"\007\n\005" + + "Clear\"\"\n\014ListStateGet\022\022\n\niteratorId\030\001 \001(" + + "\t\"\016\n\014ListStatePut\"\034\n\013AppendValue\022\r\n\005valu" + + "e\030\001 \001(\014\"\014\n\nAppendList\"\033\n\010GetValue\022\017\n\007use" + + "rKey\030\001 \001(\014\"\036\n\013ContainsKey\022\017\n\007userKey\030\001 \001" + + "(\014\"-\n\013UpdateValue\022\017\n\007userKey\030\001 \001(\014\022\r\n\005va" + + "lue\030\002 \001(\014\"\036\n\010Iterator\022\022\n\niteratorId\030\001 \001(" + + "\t\"\032\n\004Keys\022\022\n\niteratorId\030\001 \001(\t\"\034\n\006Values\022" + + "\022\n\niteratorId\030\001 \001(\t\"\034\n\tRemoveKey\022\017\n\007user" + + "Key\030\001 \001(\014\"\\\n\016SetHandleState\022J\n\005state\030\001 \001" + + "(\0162;.org.apache.spark.sql.execution.stre" + + "aming.state.HandleState\"\037\n\tTTLConfig\022\022\n\n" + + "durationMs\030\001 \001(\005*K\n\013HandleState\022\013\n\007CREAT" + + "ED\020\000\022\017\n\013INITIALIZED\020\001\022\022\n\016DATA_PROCESSED\020" + + "\002\022\n\n\006CLOSED\020\003b\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, @@ -15238,7 +21951,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get internal_static_org_apache_spark_sql_execution_streaming_state_StateVariableRequest_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( 
internal_static_org_apache_spark_sql_execution_streaming_state_StateVariableRequest_descriptor, - new java.lang.String[] { "ValueStateCall", "ListStateCall", "Method", }); + new java.lang.String[] { "ValueStateCall", "ListStateCall", "MapStateCall", "Method", }); internal_static_org_apache_spark_sql_execution_streaming_state_ImplicitGroupingKeyRequest_descriptor = getDescriptor().getMessageTypes().get(4); internal_static_org_apache_spark_sql_execution_streaming_state_ImplicitGroupingKeyRequest_fieldAccessorTable = new @@ -15250,7 +21963,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get internal_static_org_apache_spark_sql_execution_streaming_state_StateCallCommand_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_StateCallCommand_descriptor, - new java.lang.String[] { "StateName", "Schema", "Ttl", }); + new java.lang.String[] { "StateName", "Schema", "MapStateValueSchema", "Ttl", }); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_descriptor = getDescriptor().getMessageTypes().get(6); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_fieldAccessorTable = new @@ -15263,74 +21976,122 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor, new java.lang.String[] { "StateName", "Exists", "ListStateGet", "ListStatePut", "AppendValue", "AppendList", "Clear", "Method", }); - internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor = + internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_descriptor = getDescriptor().getMessageTypes().get(8); + internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_MapStateCall_descriptor, + new java.lang.String[] { "StateName", "Exists", "GetValue", "ContainsKey", "UpdateValue", "Iterator", "Keys", "Values", "RemoveKey", "Clear", "Method", }); + internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor = + getDescriptor().getMessageTypes().get(9); internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor, new java.lang.String[] { "Key", }); internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor = - getDescriptor().getMessageTypes().get(9); + getDescriptor().getMessageTypes().get(10); internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor = - getDescriptor().getMessageTypes().get(10); + getDescriptor().getMessageTypes().get(11); internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( 
internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor = - getDescriptor().getMessageTypes().get(11); + getDescriptor().getMessageTypes().get(12); internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor = - getDescriptor().getMessageTypes().get(12); + getDescriptor().getMessageTypes().get(13); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor, new java.lang.String[] { "Value", }); internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor = - getDescriptor().getMessageTypes().get(13); + getDescriptor().getMessageTypes().get(14); internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor = - getDescriptor().getMessageTypes().get(14); + getDescriptor().getMessageTypes().get(15); internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor, new java.lang.String[] { "IteratorId", }); internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor = - getDescriptor().getMessageTypes().get(15); + getDescriptor().getMessageTypes().get(16); internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor = - getDescriptor().getMessageTypes().get(16); + getDescriptor().getMessageTypes().get(17); internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor, new java.lang.String[] { "Value", }); internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor = - getDescriptor().getMessageTypes().get(17); + getDescriptor().getMessageTypes().get(18); internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor, new java.lang.String[] { }); + internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_descriptor = + getDescriptor().getMessageTypes().get(19); + internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_fieldAccessorTable = new + 
com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_GetValue_descriptor, + new java.lang.String[] { "UserKey", }); + internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_descriptor = + getDescriptor().getMessageTypes().get(20); + internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_ContainsKey_descriptor, + new java.lang.String[] { "UserKey", }); + internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_descriptor = + getDescriptor().getMessageTypes().get(21); + internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_UpdateValue_descriptor, + new java.lang.String[] { "UserKey", "Value", }); + internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_descriptor = + getDescriptor().getMessageTypes().get(22); + internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_Iterator_descriptor, + new java.lang.String[] { "IteratorId", }); + internal_static_org_apache_spark_sql_execution_streaming_state_Keys_descriptor = + getDescriptor().getMessageTypes().get(23); + internal_static_org_apache_spark_sql_execution_streaming_state_Keys_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_Keys_descriptor, + new java.lang.String[] { "IteratorId", }); + internal_static_org_apache_spark_sql_execution_streaming_state_Values_descriptor = + getDescriptor().getMessageTypes().get(24); + internal_static_org_apache_spark_sql_execution_streaming_state_Values_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_Values_descriptor, + new java.lang.String[] { "IteratorId", }); + internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_descriptor = + getDescriptor().getMessageTypes().get(25); + internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_RemoveKey_descriptor, + new java.lang.String[] { "UserKey", }); internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor = - getDescriptor().getMessageTypes().get(18); + getDescriptor().getMessageTypes().get(26); internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor, new java.lang.String[] { "State", }); internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor = - getDescriptor().getMessageTypes().get(19); + getDescriptor().getMessageTypes().get(27); internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( 
internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index fed1843acfa56..3aed0b2463f3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -30,12 +30,15 @@ import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState, StateVariableType} -import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, ListStateCall, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} -import org.apache.spark.sql.streaming.{ListState, TTLConfig, ValueState} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, ValueState} +import org.apache.spark.sql.types.{BinaryType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.Utils /** * This class is used to handle the state requests from the Python side. It runs on a separate @@ -60,7 +63,9 @@ class TransformWithStateInPandasStateServer( deserializerForTest: TransformWithStateInPandasDeserializer = null, arrowStreamWriterForTest: BaseStreamingArrowWriter = null, listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null, - listStateIteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null) + iteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null, + mapStatesMapForTest : mutable.HashMap[String, MapStateInfo] = null, + keyValueIteratorMapForTest: mutable.HashMap[String, Iterator[(Row, Row)]] = null) extends Runnable with Logging { private val keyRowDeserializer: ExpressionEncoder.Deserializer[Row] = ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer() @@ -79,14 +84,31 @@ class TransformWithStateInPandasStateServer( } else { new mutable.HashMap[String, ListStateInfo]() } - // A map to store the iterator id -> iterator mapping. This is to keep track of the - // current iterator position for each list state in a grouping key in case user tries to fetch - // another list state before the current iterator is exhausted. - private var listStateIterators = if (listStateIteratorMapForTest != null) { - listStateIteratorMapForTest + // A map to store the iterator id -> Iterator[Row] mapping. This is to keep track of the + // current iterator position for each iterator id in a state variable for a grouping key in case + // user tries to fetch another state variable before the current iterator is exhausted. 
This is + // mainly used for list state and map state. + private var iterators = if (iteratorMapForTest != null) { + iteratorMapForTest } else { new mutable.HashMap[String, Iterator[Row]]() } + // A map to store the map state name -> (map state, schema, map state row deserializer, + // map state row serializer) mapping. + private val mapStates = if (mapStatesMapForTest != null) { + mapStatesMapForTest + } else { + new mutable.HashMap[String, MapStateInfo]() + } + + // A map to store the iterator id -> Iterator[(Row, Row)] mapping. This is to keep track of the + // current key-value iterator position for each iterator id in a map state for a grouping key in + // case user tries to fetch another state variable before the current iterator is exhausted. + private var keyValueIterators = if (keyValueIteratorMapForTest != null) { + keyValueIteratorMapForTest + } else { + new mutable.HashMap[String, Iterator[(Row, Row)]]() + } def run(): Unit = { val listeningSocket = stateServerSocket.accept() @@ -149,13 +171,13 @@ class TransformWithStateInPandasStateServer( // The key row is serialized as a byte array, we need to convert it back to a Row val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer) ImplicitGroupingKeyTracker.setImplicitKey(keyRow) - // Reset the list state iterators for a new grouping key. - listStateIterators = new mutable.HashMap[String, Iterator[Row]]() + // Reset the list/map state iterators for a new grouping key. + iterators = new mutable.HashMap[String, Iterator[Row]]() sendResponse(0) case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => ImplicitGroupingKeyTracker.removeImplicitKey() - // Reset the list state iterators for a new grouping key. - listStateIterators = new mutable.HashMap[String, Iterator[Row]]() + // Reset the list/map state iterators for a new grouping key. 
+ iterators = new mutable.HashMap[String, Iterator[Row]]() sendResponse(0) case _ => throw new IllegalArgumentException("Invalid method call") @@ -195,6 +217,15 @@ class TransformWithStateInPandasStateServer( None } initializeStateVariable(stateName, schema, StateVariableType.ListState, ttlDurationMs) + case StatefulProcessorCall.MethodCase.GETMAPSTATE => + val stateName = message.getGetMapState.getStateName + val userKeySchema = message.getGetMapState.getSchema + val valueSchema = message.getGetMapState.getMapStateValueSchema + val ttlDurationMs = if (message.getGetMapState.hasTtl) { + Some(message.getGetMapState.getTtl.getDurationMs) + } else None + initializeStateVariable(stateName, userKeySchema, StateVariableType.MapState, ttlDurationMs, + valueSchema) case _ => throw new IllegalArgumentException("Invalid method call") } @@ -206,6 +237,8 @@ class TransformWithStateInPandasStateServer( handleValueStateRequest(message.getValueStateCall) case StateVariableRequest.MethodCase.LISTSTATECALL => handleListStateRequest(message.getListStateCall) + case StateVariableRequest.MethodCase.MAPSTATECALL => + handleMapStateRequest(message.getMapStateCall) case _ => throw new IllegalArgumentException("Invalid method call") } @@ -281,10 +314,10 @@ class TransformWithStateInPandasStateServer( sendResponse(0) case ListStateCall.MethodCase.LISTSTATEGET => val iteratorId = message.getListStateGet.getIteratorId - var iteratorOption = listStateIterators.get(iteratorId) + var iteratorOption = iterators.get(iteratorId) if (iteratorOption.isEmpty) { iteratorOption = Some(listStateInfo.listState.get()) - listStateIterators.put(iteratorId, iteratorOption.get) + iterators.put(iteratorId, iteratorOption.get) } if (!iteratorOption.get.hasNext) { sendResponse(2, s"List state $stateName doesn't contain any value.") @@ -292,32 +325,8 @@ class TransformWithStateInPandasStateServer( } else { sendResponse(0) } - outputStream.flush() - val arrowStreamWriter = if (arrowStreamWriterForTest != null) { - arrowStreamWriterForTest - } else { - val arrowSchema = ArrowUtils.toArrowSchema(listStateInfo.schema, timeZoneId, - errorOnDuplicatedFieldNames, largeVarTypes) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null, outputStream), - arrowTransformWithStateInPandasMaxRecordsPerBatch) - } - val listRowSerializer = listStateInfo.serializer - // Only write a single batch in each GET request. Stops writing row if rowCount reaches - // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to handle a case - // when there are multiple state variables, user tries to access a different state variable - // while the current state variable is not exhausted yet. 
- var rowCount = 0 - while (iteratorOption.get.hasNext && - rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) { - val row = iteratorOption.get.next() - val internalRow = listRowSerializer(row) - arrowStreamWriter.writeRow(internalRow) - rowCount += 1 - } - arrowStreamWriter.finalizeCurrentArrowBatch() + sendIteratorAsArrowBatches(iteratorOption.get, listStateInfo.schema, + arrowStreamWriterForTest) { data => listStateInfo.serializer(data)} case ListStateCall.MethodCase.APPENDVALUE => val byteArray = message.getAppendValue.getValue.toByteArray val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema, @@ -336,6 +345,124 @@ class TransformWithStateInPandasStateServer( } } + private[sql] def handleMapStateRequest(message: MapStateCall): Unit = { + val stateName = message.getStateName + if (!mapStates.contains(stateName)) { + logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.") + sendResponse(1, s"Map state $stateName is not initialized.") + return + } + val mapStateInfo = mapStates(stateName) + message.getMethodCase match { + case MapStateCall.MethodCase.EXISTS => + if (mapStateInfo.mapState.exists()) { + sendResponse(0) + } else { + // Send status code 2 to indicate that the list state doesn't have a value yet. + sendResponse(2, s"state $stateName doesn't exist") + } + case MapStateCall.MethodCase.GETVALUE => + val keyBytes = message.getGetValue.getUserKey.toByteArray + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema, + mapStateInfo.keyDeserializer) + val value = mapStateInfo.mapState.getValue(keyRow) + if (value != null) { + val valueBytes = PythonSQLUtils.toPyRow(value) + val byteString = ByteString.copyFrom(valueBytes) + sendResponse(0, null, byteString) + } else { + logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't contain" + + log" key ${MDC(LogKeys.KEY, keyRow.toString)}.") + sendResponse(0) + } + case MapStateCall.MethodCase.CONTAINSKEY => + val keyBytes = message.getContainsKey.getUserKey.toByteArray + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema, + mapStateInfo.keyDeserializer) + if (mapStateInfo.mapState.containsKey(keyRow)) { + sendResponse(0) + } else { + sendResponse(2, s"Map state $stateName doesn't contain key ${keyRow.toString()}") + } + case MapStateCall.MethodCase.UPDATEVALUE => + val keyBytes = message.getUpdateValue.getUserKey.toByteArray + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema, + mapStateInfo.keyDeserializer) + val valueBytes = message.getUpdateValue.getValue.toByteArray + val valueRow = PythonSQLUtils.toJVMRow(valueBytes, mapStateInfo.valueSchema, + mapStateInfo.valueDeserializer) + mapStateInfo.mapState.updateValue(keyRow, valueRow) + sendResponse(0) + case MapStateCall.MethodCase.ITERATOR => + val iteratorId = message.getIterator.getIteratorId + var iteratorOption = keyValueIterators.get(iteratorId) + if (iteratorOption.isEmpty) { + iteratorOption = Some(mapStateInfo.mapState.iterator()) + keyValueIterators.put(iteratorId, iteratorOption.get) + } + if (!iteratorOption.get.hasNext) { + sendResponse(2, s"Map state $stateName doesn't contain any entry.") + } else { + sendResponse(0) + val keyValueStateSchema: StructType = StructType( + Array( + // key row serialized as a byte array. + StructField("keyRow", BinaryType), + // value row serialized as a byte array. 
+ StructField("valueRow", BinaryType) + ) + ) + sendIteratorAsArrowBatches(iteratorOption.get, keyValueStateSchema, + arrowStreamWriterForTest) {tuple => + val keyBytes = PythonSQLUtils.toPyRow(tuple._1) + val valueBytes = PythonSQLUtils.toPyRow(tuple._2) + new GenericInternalRow( + Array[Any]( + keyBytes, + valueBytes + ) + ) + } + } + case MapStateCall.MethodCase.KEYS => + val iteratorId = message.getKeys.getIteratorId + var iteratorOption = iterators.get(iteratorId) + if (iteratorOption.isEmpty) { + iteratorOption = Some(mapStateInfo.mapState.keys()) + iterators.put(iteratorId, iteratorOption.get) + } + if (!iteratorOption.get.hasNext) { + sendResponse(2, s"Map state $stateName doesn't contain any key.") + } else { + sendResponse(0) + sendIteratorAsArrowBatches(iteratorOption.get, mapStateInfo.keySchema, + arrowStreamWriterForTest) {data => mapStateInfo.keySerializer(data)} + } + case MapStateCall.MethodCase.VALUES => + val iteratorId = message.getValues.getIteratorId + var iteratorOption = iterators.get(iteratorId) + if (iteratorOption.isEmpty) { + iteratorOption = Some(mapStateInfo.mapState.values()) + iterators.put(iteratorId, iteratorOption.get) + } + if (!iteratorOption.get.hasNext) { + sendResponse(2, s"Map state $stateName doesn't contain any value.") + } else { + sendResponse(0) + sendIteratorAsArrowBatches(iteratorOption.get, mapStateInfo.valueSchema, + arrowStreamWriterForTest) {data => mapStateInfo.valueSerializer(data)} + } + case MapStateCall.MethodCase.REMOVEKEY => + val keyBytes = message.getRemoveKey.getUserKey.toByteArray + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema, + mapStateInfo.keyDeserializer) + mapStateInfo.mapState.removeKey(keyRow) + sendResponse(0) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + private def sendResponse( status: Int, errorMessage: String = null, @@ -358,7 +485,8 @@ class TransformWithStateInPandasStateServer( stateName: String, schemaString: String, stateType: StateVariableType.StateVariableType, - ttlDurationMs: Option[Int]): Unit = { + ttlDurationMs: Option[Int], + mapStateValueSchemaString: String = null): Unit = { val schema = StructType.fromString(schemaString) val expressionEncoder = ExpressionEncoder(schema).resolveAndBind() stateType match { @@ -389,6 +517,63 @@ class TransformWithStateInPandasStateServer( } else { sendResponse(1, s"List state $stateName already exists") } + case StateVariableType.MapState => if (!mapStates.contains(stateName)) { + val valueSchema = StructType.fromString(mapStateValueSchemaString) + val valueExpressionEncoder = ExpressionEncoder(valueSchema).resolveAndBind() + val state = if (ttlDurationMs.isEmpty) { + statefulProcessorHandle.getMapState[Row, Row](stateName, + Encoders.row(schema), Encoders.row(valueSchema)) + } else { + statefulProcessorHandle.getMapState[Row, Row](stateName, Encoders.row(schema), + Encoders.row(valueSchema), TTLConfig(Duration.ofMillis(ttlDurationMs.get))) + } + mapStates.put(stateName, + MapStateInfo(state, schema, valueSchema, expressionEncoder.createDeserializer(), + expressionEncoder.createSerializer(), valueExpressionEncoder.createDeserializer(), + valueExpressionEncoder.createSerializer())) + sendResponse(0) + } else { + sendResponse(1, s"Map state $stateName already exists") + } + } + } + + private def sendIteratorAsArrowBatches[T]( + iter: Iterator[T], + outputSchema: StructType, + arrowStreamWriterForTest: BaseStreamingArrowWriter = null)(func: T => InternalRow): Unit = { + outputStream.flush() + val arrowSchema = 
ArrowUtils.toArrowSchema(outputSchema, timeZoneId, + errorOnDuplicatedFieldNames, largeVarTypes) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val writer = new ArrowStreamWriter(root, null, outputStream) + val arrowStreamWriter = if (arrowStreamWriterForTest != null) { + arrowStreamWriterForTest + } else { + new BaseStreamingArrowWriter(root, writer, arrowTransformWithStateInPandasMaxRecordsPerBatch) + } + // Only write a single batch in each GET request. Stops writing row if rowCount reaches + // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to handle a case + // when there are multiple state variables, user tries to access a different state variable + // while the current state variable is not exhausted yet. + var rowCount = 0 + while (iter.hasNext && rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) { + val data = iter.next() + val internalRow = func(data) + arrowStreamWriter.writeRow(internalRow) + rowCount += 1 + } + arrowStreamWriter.finalizeCurrentArrowBatch() + Utils.tryWithSafeFinally { + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. + writer.end() + } { + root.close() + allocator.close() } } } @@ -409,3 +594,15 @@ case class ListStateInfo( schema: StructType, deserializer: ExpressionEncoder.Deserializer[Row], serializer: ExpressionEncoder.Serializer[Row]) + +/** + * Case class to store the information of a map state. + */ +case class MapStateInfo( + mapState: MapState[Row, Row], + keySchema: StructType, + valueSchema: StructType, + keyDeserializer: ExpressionEncoder.Deserializer[Row], + keySerializer: ExpressionEncoder.Serializer[Row], + valueDeserializer: ExpressionEncoder.Deserializer[Row], + valueSerializer: ExpressionEncoder.Serializer[Row]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala index 776772bb51ca8..2a728dc81d0b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.streaming.{StatefulProcessorHandleImpl, StatefulProcessorHandleState} import org.apache.spark.sql.execution.streaming.state.StateMessage -import org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList, AppendValue, Clear, Exists, Get, HandleState, ListStateCall, ListStateGet, ListStatePut, SetHandleState, StateCallCommand, StatefulProcessorCall, ValueStateCall, ValueStateUpdate} -import org.apache.spark.sql.streaming.{ListState, TTLConfig, ValueState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList, AppendValue, Clear, ContainsKey, Exists, Get, GetValue, HandleState, Keys, ListStateCall, ListStateGet, ListStatePut, MapStateCall, RemoveKey, SetHandleState, StateCallCommand, StatefulProcessorCall, UpdateValue, Values, ValueStateCall, ValueStateUpdate} +import 
org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, ValueState} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -53,6 +53,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo var outputStream: DataOutputStream = _ var valueState: ValueState[Row] = _ var listState: ListState[Row] = _ + var mapState: MapState[Row, Row] = _ var stateServer: TransformWithStateInPandasStateServer = _ var stateDeserializer: ExpressionEncoder.Deserializer[Row] = _ var stateSerializer: ExpressionEncoder.Serializer[Row] = _ @@ -60,31 +61,39 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo var arrowStreamWriter: BaseStreamingArrowWriter = _ var valueStateMap: mutable.HashMap[String, ValueStateInfo] = mutable.HashMap() var listStateMap: mutable.HashMap[String, ListStateInfo] = mutable.HashMap() + var mapStateMap: mutable.HashMap[String, MapStateInfo] = mutable.HashMap() override def beforeEach(): Unit = { statefulProcessorHandle = mock(classOf[StatefulProcessorHandleImpl]) outputStream = mock(classOf[DataOutputStream]) valueState = mock(classOf[ValueState[Row]]) listState = mock(classOf[ListState[Row]]) + mapState = mock(classOf[MapState[Row, Row]]) stateDeserializer = ExpressionEncoder(stateSchema).resolveAndBind().createDeserializer() stateSerializer = ExpressionEncoder(stateSchema).resolveAndBind().createSerializer() valueStateMap = mutable.HashMap[String, ValueStateInfo](stateName -> ValueStateInfo(valueState, stateSchema, stateDeserializer)) listStateMap = mutable.HashMap[String, ListStateInfo](stateName -> ListStateInfo(listState, stateSchema, stateDeserializer, stateSerializer)) - // Iterator map for list state. Please note that `handleImplicitGroupingKeyRequest` would + mapStateMap = mutable.HashMap[String, MapStateInfo](stateName -> + MapStateInfo(mapState, stateSchema, stateSchema, stateDeserializer, + stateSerializer, stateDeserializer, stateSerializer)) + + // Iterator map for list/map state. Please note that `handleImplicitGroupingKeyRequest` would // reset the iterator map to empty so be careful to call it if you want to access the iterator // map later. 
- val listStateIteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> - Iterator(new GenericRowWithSchema(Array(1), stateSchema))) + val testRow = getIntegerRow(1) + val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> Iterator(testRow)) + val keyValueIteratorMap = mutable.HashMap[String, Iterator[(Row, Row)]](iteratorId -> + Iterator((testRow, testRow))) transformWithStateInPandasDeserializer = mock(classOf[TransformWithStateInPandasDeserializer]) arrowStreamWriter = mock(classOf[BaseStreamingArrowWriter]) stateServer = new TransformWithStateInPandasStateServer(serverSocket, statefulProcessorHandle, groupingKeySchema, "", false, false, 2, outputStream, valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter, - listStateMap, listStateIteratorMap) + listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap) when(transformWithStateInPandasDeserializer.readArrowBatches(any)) - .thenReturn(Seq(new GenericRowWithSchema(Array(1), stateSchema))) + .thenReturn(Seq(getIntegerRow(1))) } test("set handle state") { @@ -141,6 +150,31 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } } + Seq(true, false).foreach { useTTL => + test(s"get map state, useTTL=$useTTL") { + val stateCallCommandBuilder = StateCallCommand.newBuilder() + .setStateName("newName") + .setSchema("StructType(List(StructField(value,IntegerType,true)))") + .setMapStateValueSchema("StructType(List(StructField(value,IntegerType,true)))") + if (useTTL) { + stateCallCommandBuilder.setTtl(StateMessage.TTLConfig.newBuilder().setDurationMs(1000)) + } + val message = StatefulProcessorCall + .newBuilder() + .setGetMapState(stateCallCommandBuilder.build()) + .build() + stateServer.handleStatefulProcessorCall(message) + if (useTTL) { + verify(statefulProcessorHandle) + .getMapState[Row, Row](any[String], any[Encoder[Row]], any[Encoder[Row]], any[TTLConfig]) + } else { + verify(statefulProcessorHandle).getMapState[Row, Row](any[String], any[Encoder[Row]], + any[Encoder[Row]]) + } + verify(outputStream).writeInt(0) + } + } + test("value state exists") { val message = ValueStateCall.newBuilder().setStateName(stateName) .setExists(Exists.newBuilder().build()).build() @@ -152,7 +186,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo val message = ValueStateCall.newBuilder().setStateName(stateName) .setGet(Get.newBuilder().build()).build() val schema = new StructType().add("value", "int") - when(valueState.getOption()).thenReturn(Some(new GenericRowWithSchema(Array(1), schema))) + when(valueState.getOption()).thenReturn(Some(getIntegerRow(1))) stateServer.handleValueStateRequest(message) verify(valueState).getOption() verify(outputStream).writeInt(argThat((x: Int) => x > 0)) @@ -214,10 +248,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo val message = ListStateCall.newBuilder().setStateName(stateName) .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build() val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> - Iterator(new GenericRowWithSchema(Array(1), stateSchema), - new GenericRowWithSchema(Array(2), stateSchema), - new GenericRowWithSchema(Array(3), stateSchema), - new GenericRowWithSchema(Array(4), stateSchema))) + Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3), getIntegerRow(4))) stateServer = new TransformWithStateInPandasStateServer(serverSocket, statefulProcessorHandle, groupingKeySchema, "", false, false, maxRecordsPerBatch, 
outputStream, valueStateMap, @@ -245,9 +276,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo statefulProcessorHandle, groupingKeySchema, "", false, false, maxRecordsPerBatch, outputStream, valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap) - when(listState.get()).thenReturn(Iterator(new GenericRowWithSchema(Array(1), stateSchema), - new GenericRowWithSchema(Array(2), stateSchema), - new GenericRowWithSchema(Array(3), stateSchema))) + when(listState.get()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3))) stateServer.handleListStateRequest(message) verify(listState).get() // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still @@ -279,4 +308,159 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo verify(transformWithStateInPandasDeserializer).readArrowBatches(any) verify(listState).appendList(any) } + + test("map state exists") { + val message = MapStateCall.newBuilder().setStateName(stateName) + .setExists(Exists.newBuilder().build()).build() + stateServer.handleMapStateRequest(message) + verify(mapState).exists() + } + + test("map state get") { + val byteString: ByteString = ByteString.copyFrom(byteArray) + val message = MapStateCall.newBuilder().setStateName(stateName) + .setGetValue(GetValue.newBuilder().setUserKey(byteString).build()).build() + val schema = new StructType().add("value", "int") + when(mapState.getValue(any[Row])).thenReturn(getIntegerRow(1)) + stateServer.handleMapStateRequest(message) + verify(mapState).getValue(any[Row]) + verify(outputStream).writeInt(argThat((x: Int) => x > 0)) + } + + test("map state contains key") { + val byteString: ByteString = ByteString.copyFrom(byteArray) + val message = MapStateCall.newBuilder().setStateName(stateName) + .setContainsKey(ContainsKey.newBuilder().setUserKey(byteString).build()).build() + stateServer.handleMapStateRequest(message) + verify(mapState).containsKey(any[Row]) + } + + test("map state update value") { + val byteString: ByteString = ByteString.copyFrom(byteArray) + val message = MapStateCall.newBuilder().setStateName(stateName) + .setUpdateValue(UpdateValue.newBuilder().setUserKey(byteString).setValue(byteString).build()) + .build() + stateServer.handleMapStateRequest(message) + verify(mapState).updateValue(any[Row], any[Row]) + } + + test("map state iterator - iterator in map") { + val message = MapStateCall.newBuilder().setStateName(stateName) + .setIterator(StateMessage.Iterator.newBuilder().setIteratorId(iteratorId).build()).build() + stateServer.handleMapStateRequest(message) + verify(mapState, times(0)).iterator() + verify(arrowStreamWriter).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("map state iterator - iterator in map with multiple batches") { + val maxRecordsPerBatch = 2 + val message = MapStateCall.newBuilder().setStateName(stateName) + .setIterator(StateMessage.Iterator.newBuilder().setIteratorId(iteratorId).build()).build() + val keyValueIteratorMap = mutable.HashMap[String, Iterator[(Row, Row)]](iteratorId -> + Iterator((getIntegerRow(1), getIntegerRow(1)), (getIntegerRow(2), getIntegerRow(2)), + (getIntegerRow(3), getIntegerRow(3)), (getIntegerRow(4), getIntegerRow(4)))) + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, "", false, false, + maxRecordsPerBatch, outputStream, valueStateMap, 
transformWithStateInPandasDeserializer, + arrowStreamWriter, listStateMap, null, mapStateMap, keyValueIteratorMap) + // First call should send 2 records. + stateServer.handleMapStateRequest(message) + verify(mapState, times(0)).iterator() + verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + // Second call should send the remaining 2 records. + stateServer.handleMapStateRequest(message) + verify(mapState, times(0)).iterator() + // Since Mockito's verify counts the total number of calls, the expected number of writeRow call + // should be 2 * maxRecordsPerBatch. + verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch() + } + + test("map state iterator - iterator not in map") { + val maxRecordsPerBatch = 2 + val message = MapStateCall.newBuilder().setStateName(stateName) + .setIterator(StateMessage.Iterator.newBuilder().setIteratorId(iteratorId).build()).build() + val keyValueIteratorMap: mutable.HashMap[String, Iterator[(Row, Row)]] = mutable.HashMap() + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, "", false, false, + maxRecordsPerBatch, outputStream, valueStateMap, transformWithStateInPandasDeserializer, + arrowStreamWriter, listStateMap, null, mapStateMap, keyValueIteratorMap) + when(mapState.iterator()).thenReturn(Iterator((getIntegerRow(1), getIntegerRow(1)), + (getIntegerRow(2), getIntegerRow(2)), (getIntegerRow(3), getIntegerRow(3)))) + stateServer.handleMapStateRequest(message) + verify(mapState).iterator() + // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still + // having 1 row left in the iterator. + verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("map state keys - iterator in map") { + val message = MapStateCall.newBuilder().setStateName(stateName) + .setKeys(Keys.newBuilder().setIteratorId(iteratorId).build()).build() + stateServer.handleMapStateRequest(message) + verify(mapState, times(0)).keys() + verify(arrowStreamWriter).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("map state keys - iterator not in map") { + val maxRecordsPerBatch = 2 + val message = MapStateCall.newBuilder().setStateName(stateName) + .setKeys(Keys.newBuilder().setIteratorId(iteratorId).build()).build() + val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap() + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, "", false, false, + maxRecordsPerBatch, outputStream, valueStateMap, transformWithStateInPandasDeserializer, + arrowStreamWriter, listStateMap, iteratorMap, mapStateMap) + when(mapState.keys()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3))) + stateServer.handleMapStateRequest(message) + verify(mapState).keys() + // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still + // having 1 row left in the iterator. 
+ verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("map state values - iterator in map") { + val message = MapStateCall.newBuilder().setStateName(stateName) + .setValues(Values.newBuilder().setIteratorId(iteratorId).build()).build() + stateServer.handleMapStateRequest(message) + verify(mapState, times(0)).values() + verify(arrowStreamWriter).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("map state values - iterator not in map") { + val maxRecordsPerBatch = 2 + val message = MapStateCall.newBuilder().setStateName(stateName) + .setValues(Values.newBuilder().setIteratorId(iteratorId).build()).build() + val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap() + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, "", false, false, + maxRecordsPerBatch, outputStream, valueStateMap, transformWithStateInPandasDeserializer, + arrowStreamWriter, listStateMap, iteratorMap, mapStateMap) + when(mapState.values()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), + getIntegerRow(3))) + stateServer.handleMapStateRequest(message) + verify(mapState).values() + // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still + // having 1 row left in the iterator. + verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("remove key") { + val byteString: ByteString = ByteString.copyFrom(byteArray) + val message = MapStateCall.newBuilder().setStateName(stateName) + .setRemoveKey(RemoveKey.newBuilder().setUserKey(byteString).build()).build() + stateServer.handleMapStateRequest(message) + verify(mapState).removeKey(any[Row]) + } + + private def getIntegerRow(value: Int): Row = { + new GenericRowWithSchema(Array(value), stateSchema) + } } From 4ed908e96af578361879515cd71e8e03f3c8e938 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 21 Oct 2024 09:13:44 +0200 Subject: [PATCH 077/108] [SPARK-50048][SQL] Assign appropriate error condition for `_LEGACY_ERROR_TEMP_2114`: `UNRECOGNIZED_STATISTIC` ### What changes were proposed in this pull request? This PR proposes to assign proper error condition & sqlstate for `_LEGACY_ERROR_TEMP_2114`: `UNRECOGNIZED_STATISTIC` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48572 from itholic/LEGACY_2114. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 11 ++++++----- .../spark/sql/errors/QueryExecutionErrors.scala | 4 ++-- .../org/apache/spark/sql/DataFrameStatSuite.scala | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 28b2bc9a857db..00f288bfc65c3 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4715,6 +4715,12 @@ ], "sqlState" : "42704" }, + "UNRECOGNIZED_STATISTIC" : { + "message" : [ + "The statistic is not recognized. 
Valid statistics include `count`, `count_distinct`, `approx_count_distinct`, `mean`, `stddev`, `min`, `max`, and percentile values." + ], + "sqlState" : "42704" + }, "UNRESOLVABLE_TABLE_VALUED_FUNCTION" : { "message" : [ "Could not resolve to a table-valued function.", @@ -7150,11 +7156,6 @@ "Unable to parse as a percentile." ] }, - "_LEGACY_ERROR_TEMP_2114" : { - "message" : [ - " is not a recognised statistic." - ] - }, "_LEGACY_ERROR_TEMP_2115" : { "message" : [ "Unknown column: ." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 26ed25ba90167..946d129c47c4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1141,8 +1141,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def statisticNotRecognizedError(stats: String): SparkIllegalArgumentException = { new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_2114", - messageParameters = Map("stats" -> stats)) + errorClass = "UNRECOGNIZED_STATISTIC", + messageParameters = Map("stats" -> toSQLId(stats))) } def unknownColumnError(unknownColumn: String): SparkIllegalArgumentException = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 2f7b072fb7ece..37319de0b6624 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -588,8 +588,8 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkIllegalArgumentException] { person2.summary("foo") }, - condition = "_LEGACY_ERROR_TEMP_2114", - parameters = Map("stats" -> "foo") + condition = "UNRECOGNIZED_STATISTIC", + parameters = Map("stats" -> "`foo`") ) checkError( From 26325440a2b7094cb91e709ef2cd2b7bae7246e1 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 21 Oct 2024 19:42:05 +0900 Subject: [PATCH 078/108] [SPARK-49851][CONNECT][PYTHON] API compatibility check for Protobuf ### What changes were proposed in this pull request? This PR proposes to add API compatibility check for Spark SQL Protobuf functions ### Why are the changes needed? To guarantee of the same behavior between Spark Classic and Spark Connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48576 from itholic/compat_protobuf. 
Authored-by: Haejoon Lee Signed-off-by: Haejoon Lee --- .../sql/tests/test_connect_compatibility.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index 193de8e1b6b6a..c5e66fde018c5 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -33,6 +33,7 @@ import pyspark.sql.functions as ClassicFunctions from pyspark.sql.group import GroupedData as ClassicGroupedData import pyspark.sql.avro.functions as ClassicAvro +import pyspark.sql.protobuf.functions as ClassicProtobuf if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame @@ -47,6 +48,7 @@ import pyspark.sql.connect.functions as ConnectFunctions from pyspark.sql.connect.group import GroupedData as ConnectGroupedData import pyspark.sql.connect.avro.functions as ConnectAvro + import pyspark.sql.connect.protobuf.functions as ConnectProtobuf class ConnectCompatibilityTestsMixin: @@ -399,6 +401,28 @@ def test_avro_compatibility(self): expected_missing_classic_methods, ) + def test_protobuf_compatibility(self): + """Test Protobuf compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + # The current supported Avro functions are only `from_protobuf` and `to_protobuf`. + # The missing methods belows are just util functions that imported to implement them. + expected_missing_connect_methods = { + "cast", + "try_remote_protobuf_functions", + "get_active_spark_context", + } + expected_missing_classic_methods = {"lit", "check_dependencies"} + self.check_compatibility( + ClassicProtobuf, + ConnectProtobuf, + "Protobuf", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): From 738dfa346bde59d10fa18bc662b858b0c3354254 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Mon, 21 Oct 2024 20:11:04 +0900 Subject: [PATCH 079/108] [SPARK-50035][SS] Add support for explicit handleExpiredTimer function part of the stateful processor ### What changes were proposed in this pull request? Add support for explicit handleExpiredTimer function part of the stateful processor ### Why are the changes needed? Separate function will provide for cleaner UX and eliminates need for special handling of empty input rows and expired timer validity ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? Existing unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48553 from anishshri-db/task/SPARK-50035. 
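For readers of this patch, here is a minimal sketch of a processor written against the updated trait (the class name, state variable, and timeout value are illustrative only, not taken from this change). Timer expiry now gets its own `handleExpiredTimer` callback, which defaults to emitting no rows, instead of being routed through `handleInputRows` with an empty iterator and an `ExpiredTimerInfo.isValid()` check:

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, StatefulProcessor, TimeMode, TimerValues, ValueState}

// Hypothetical processor: counts rows per key and clears the count when a
// processing-time timer registered 5 seconds earlier fires.
class CountWithTimeoutProcessor extends StatefulProcessor[String, String, (String, String)] {
  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues): Iterator[(String, String)] = {
    val count = countState.getOption().getOrElse(0L) + inputRows.size
    countState.update(count)
    // Register a single processing-time timer for this key if none exists yet.
    if (getHandle.listTimers().isEmpty) {
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)
    }
    Iterator((key, count.toString))
  }

  // New callback: invoked only when a registered timer for this key expires.
  override def handleExpiredTimer(
      key: String,
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
    countState.clear()
    Iterator((key, "-1"))
  }
}
```

Processors that never register timers can simply skip the override and keep the default empty-iterator behavior.
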
Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../sql/streaming/ExpiredTimerInfo.scala | 8 +- .../sql/streaming/StatefulProcessor.scala | 21 +- .../streaming/ExpiredTimerInfoImpl.scala | 10 +- .../streaming/TransformWithStateExec.scala | 8 +- .../spark/sql/TestStatefulProcessor.java | 61 +++--- ...TestStatefulProcessorWithInitialState.java | 33 ++-- ...ateDataSourceTransformWithStateSuite.scala | 14 +- .../TransformWithListStateSuite.scala | 6 +- .../TransformWithListStateTTLSuite.scala | 5 +- .../TransformWithMapStateSuite.scala | 5 +- .../TransformWithMapStateTTLSuite.scala | 10 +- .../TransformWithStateChainingSuite.scala | 9 +- .../TransformWithStateInitialStateSuite.scala | 54 +++-- .../streaming/TransformWithStateSuite.scala | 185 ++++++++---------- .../TransformWithValueStateTTLSuite.scala | 6 +- 15 files changed, 195 insertions(+), 240 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala index a0958aceb3b3a..c5d3adda8b87e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala @@ -22,18 +22,12 @@ import java.io.Serializable import org.apache.spark.annotation.{Evolving, Experimental} /** - * Class used to provide access to expired timer's expiry time. These values are only relevant if - * the ExpiredTimerInfo is valid. + * Class used to provide access to expired timer's expiry time. */ @Experimental @Evolving private[sql] trait ExpiredTimerInfo extends Serializable { - /** - * Check if provided ExpiredTimerInfo is valid. - */ - def isValid(): Boolean - /** * Get the expired timer's expiry time as milliseconds in epoch time. */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index d2c6010454c55..719d1e572c20d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -56,16 +56,27 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { * @param timerValues * \- instance of TimerValues that provides access to current processing/event time if * available - * @param expiredTimerInfo - * \- instance of ExpiredTimerInfo that provides access to expired timer if applicable * @return * \- Zero or more output rows */ - def handleInputRows( + def handleInputRows(key: K, inputRows: Iterator[I], timerValues: TimerValues): Iterator[O] + + /** + * Function that will be invoked when a timer is fired for a given key. Users can choose to + * evict state, register new timers and optionally provide output rows. 
+ * @param key + * \- grouping key + * @param timerValues + * \- instance of TimerValues that provides access to current processing/event + * @param expiredTimerInfo + * \- instance of ExpiredTimerInfo that provides access to expired timer + * @return + * Zero or more output rows + */ + def handleExpiredTimer( key: K, - inputRows: Iterator[I], timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[O] + expiredTimerInfo: ExpiredTimerInfo): Iterator[O] = Iterator.empty /** * Function called as the last method that allows for users to perform any cleanup or teardown diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ExpiredTimerInfoImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ExpiredTimerInfoImpl.scala index e0bfc684585df..984d650a27ccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ExpiredTimerInfoImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ExpiredTimerInfoImpl.scala @@ -16,21 +16,15 @@ */ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.streaming.{ExpiredTimerInfo, TimeMode} +import org.apache.spark.sql.streaming.ExpiredTimerInfo /** * Class that provides a concrete implementation that can be used to provide access to expired * timer's expiry time. These values are only relevant if the ExpiredTimerInfo * is valid. - * @param isValid - boolean to check if the provided ExpiredTimerInfo is valid * @param expiryTimeInMsOpt - option to expired timer's expiry time as milliseconds in epoch time */ -class ExpiredTimerInfoImpl( - isValid: Boolean, - expiryTimeInMsOpt: Option[Long] = None, - timeMode: TimeMode = TimeMode.None()) extends ExpiredTimerInfo { - - override def isValid(): Boolean = isValid +class ExpiredTimerInfoImpl(expiryTimeInMsOpt: Option[Long] = None) extends ExpiredTimerInfo { override def getExpiryTimeInMs(): Long = expiryTimeInMsOpt.getOrElse(-1L) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index cd567f4c74d70..42cd429587f3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -250,8 +250,7 @@ case class TransformWithStateExec( val mappedIterator = statefulProcessor.handleInputRows( keyObj, valueObjIter, - new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForEviction), - new ExpiredTimerInfoImpl(isValid = false)).map { obj => + new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForEviction)).map { obj => getOutputRow(obj) } ImplicitGroupingKeyTracker.removeImplicitKey() @@ -301,11 +300,10 @@ case class TransformWithStateExec( processorHandle: StatefulProcessorHandleImpl): Iterator[InternalRow] = { val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) ImplicitGroupingKeyTracker.setImplicitKey(keyObj) - val mappedIterator = statefulProcessor.handleInputRows( + val mappedIterator = statefulProcessor.handleExpiredTimer( keyObj, - Iterator.empty, new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForEviction), - new ExpiredTimerInfoImpl(isValid = true, Some(expiryTimestampMs))).map { obj => + new ExpiredTimerInfoImpl(Some(expiryTimestampMs))).map { obj => getOutputRow(obj) } ImplicitGroupingKeyTracker.removeImplicitKey() diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java index b9841ee0f9735..3f0efe2e14aff 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java @@ -55,45 +55,42 @@ public void init( public scala.collection.Iterator handleInputRows( Integer key, scala.collection.Iterator rows, - TimerValues timerValues, - ExpiredTimerInfo expiredTimerInfo) { + TimerValues timerValues) { java.util.List result = new ArrayList<>(); - if (!expiredTimerInfo.isValid()) { - long count = 0; - // Perform various operations on composite types to verify compatibility for the Java API - if (countState.exists()) { - count = countState.get(); - } + long count = 0; + // Perform various operations on composite types to verify compatibility for the Java API + if (countState.exists()) { + count = countState.get(); + } - long numRows = 0; - StringBuilder sb = new StringBuilder(key.toString()); - while (rows.hasNext()) { - numRows++; - String value = rows.next(); - if (keyCountMap.containsKey(value)) { - keyCountMap.updateValue(value, keyCountMap.getValue(value) + 1); - } else { - keyCountMap.updateValue(value, 1L); - } - assertTrue(keyCountMap.containsKey(value)); - keysList.appendValue(value); - sb.append(value); + long numRows = 0; + StringBuilder sb = new StringBuilder(key.toString()); + while (rows.hasNext()) { + numRows++; + String value = rows.next(); + if (keyCountMap.containsKey(value)) { + keyCountMap.updateValue(value, keyCountMap.getValue(value) + 1); + } else { + keyCountMap.updateValue(value, 1L); } + assertTrue(keyCountMap.containsKey(value)); + keysList.appendValue(value); + sb.append(value); + } - scala.collection.Iterator keys = keysList.get(); - while (keys.hasNext()) { - String keyVal = keys.next(); - assertTrue(keyCountMap.containsKey(keyVal)); - assertTrue(keyCountMap.getValue(keyVal) > 0); - } + scala.collection.Iterator keys = keysList.get(); + while (keys.hasNext()) { + String keyVal = keys.next(); + assertTrue(keyCountMap.containsKey(keyVal)); + assertTrue(keyCountMap.getValue(keyVal) > 0); + } - count += numRows; - countState.update(count); - assertEquals(count, (long) countState.get()); + count += numRows; + countState.update(count); + assertEquals(count, (long) countState.get()); - result.add(sb.toString()); - } + result.add(sb.toString()); return CollectionConverters.asScala(result).iterator(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java index 55046a7c0d3df..7e356abf2d05b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java @@ -53,30 +53,27 @@ public void handleInitialState(Integer key, String initialState, TimerValues tim public scala.collection.Iterator handleInputRows( Integer key, scala.collection.Iterator rows, - TimerValues timerValues, - ExpiredTimerInfo expiredTimerInfo) { + TimerValues timerValues) { java.util.List result = new ArrayList<>(); - if (!expiredTimerInfo.isValid()) { - String existingValue = ""; - if (testState.exists()) { - existingValue = testState.get(); - } + String existingValue = ""; + if (testState.exists()) { + existingValue = testState.get(); + } - 
StringBuilder sb = new StringBuilder(key.toString()); - if (!existingValue.isEmpty()) { - sb.append(existingValue); - } + StringBuilder sb = new StringBuilder(key.toString()); + if (!existingValue.isEmpty()) { + sb.append(existingValue); + } - while (rows.hasNext()) { - sb.append(rows.next()); - } + while (rows.hasNext()) { + sb.append(rows.next()); + } - testState.clear(); - assertFalse(testState.exists()); + testState.clear(); + assertFalse(testState.exists()); - result.add(sb.toString()); - } + result.add(sb.toString()); return CollectionConverters.asScala(result).iterator(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index 0aa748f7af93d..293ec52cc8717 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass} import org.apache.spark.sql.functions.{explode, timestamp_seconds} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} +import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} import org.apache.spark.sql.streaming.util.StreamManualClock /** Stateful processor of single value state var with non-primitive type */ @@ -40,8 +40,7 @@ class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor override def handleInputRows( key: String, inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { val count = _valueState.getOption().getOrElse(TestClass(0L, "dummyKey")).id + 1 _valueState.update(TestClass(count, "dummyKey")) Iterator((key, count.toString)) @@ -62,8 +61,7 @@ class StatefulProcessorWithTTL override def handleInputRows( key: String, inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { val count = _countState.getOption().getOrElse(0L) + 1 if (count == 3) { _countState.clear() @@ -89,8 +87,7 @@ class SessionGroupsStatefulProcessor extends override def handleInputRows( key: String, inputRows: Iterator[(String, String)], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = { + timerValues: TimerValues): 
Iterator[String] = { inputRows.foreach { inputRow => _groupsList.appendValue(inputRow._2) } @@ -112,8 +109,7 @@ class SessionGroupsStatefulProcessorWithTTL extends override def handleInputRows( key: String, inputRows: Iterator[(String, String)], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = { + timerValues: TimerValues): Iterator[String] = { inputRows.foreach { inputRow => _groupsListWithTTL.appendValue(inputRow._2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index 71b8c8ac923d4..20f04cc66c0aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -39,8 +39,7 @@ class TestListStateProcessor override def handleInputRows( key: String, rows: Iterator[InputRow], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { var output = List[(String, String)]() @@ -97,8 +96,7 @@ class ToggleSaveAndEmitProcessor override def handleInputRows( key: String, rows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = { + timerValues: TimerValues): Iterator[String] = { val valueStateOption = _valueState.getOption() if (valueStateOption.isEmpty || !valueStateOption.get) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala index d11d8ef9a9b36..409a255ae3e64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala @@ -41,8 +41,7 @@ class ListStateTTLProcessor(ttlConfig: TTLConfig) override def handleInputRows( key: String, inputRows: Iterator[InputEvent], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + timerValues: TimerValues): Iterator[OutputEvent] = { var results = List[OutputEvent]() inputRows.foreach { row => @@ -55,7 +54,7 @@ class ListStateTTLProcessor(ttlConfig: TTLConfig) results.iterator } - def processRow( + private def processRow( row: InputEvent, listState: ListStateImplWithTTL[Int]): Iterator[OutputEvent] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index e4e6862f7f937..da4218949a110 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -39,8 +39,7 @@ class TestMapStateProcessor override def handleInputRows( key: String, inputRows: Iterator[InputMapRow], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = { + timerValues: TimerValues): Iterator[(String, String, String)] = { var output = List[(String, String, String)]() @@ -74,8 +73,6 @@ class TestMapStateProcessor } output.iterator } - - override def close(): Unit = {} } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala index 3794bcc9ea271..022280eb3bcef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala @@ -41,8 +41,7 @@ class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig) override def handleInputRows( key: String, inputRows: Iterator[InputEvent], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + timerValues: TimerValues): Iterator[OutputEvent] = { var results = List[OutputEvent]() for (row <- inputRows) { @@ -55,7 +54,7 @@ class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig) results.iterator } - def processRow( + private def processRow( row: InputEvent, mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = { var results = List[OutputEvent]() @@ -119,8 +118,7 @@ class MapStateTTLProcessor(ttlConfig: TTLConfig) override def handleInputRows( key: String, inputRows: Iterator[MapInputEvent], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[MapOutputEvent] = { + timerValues: TimerValues): Iterator[MapOutputEvent] = { var results = List[MapOutputEvent]() for (row <- inputRows) { @@ -133,7 +131,7 @@ class MapStateTTLProcessor(ttlConfig: TTLConfig) results.iterator } - def processRow( + private def processRow( row: MapInputEvent, mapState: MapStateImplWithTTL[String, Int]): Iterator[MapOutputEvent] = { var results = List[MapOutputEvent]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala index b1025d9d89494..6888fcba45f3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala @@ -45,8 +45,7 @@ class TestStatefulProcessor override def handleInputRows( key: String, inputRows: Iterator[InputEventRow], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputRow] = { + timerValues: TimerValues): Iterator[OutputRow] = { if (inputRows.isEmpty) { Iterator.empty } else { @@ -70,8 +69,7 @@ class InputCountStatefulProcessor[T] override def handleInputRows( key: String, inputRows: Iterator[T], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[Int] = { + timerValues: TimerValues): Iterator[Int] = { Iterator.single(inputRows.size) } } @@ -86,8 +84,7 @@ class StatefulProcessorEmittingRowsOlderThanWatermark override def handleInputRows( key: String, inputRows: Iterator[InputEventRow], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputRow] = { + timerValues: TimerValues): Iterator[OutputRow] = { Iterator.single( OutputRow( key, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index d141407b4fcd0..300785611fd05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -48,8 +48,7 @@ abstract class StatefulProcessorWithInitialStateTestClass[V] override def handleInputRows( key: String, inputRows: 
Iterator[InitInputRow], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Double)] = { + timerValues: TimerValues): Iterator[(String, String, Double)] = { var output = List[(String, String, Double)]() for (row <- inputRows) { if (row.action == "getOption") { @@ -96,8 +95,7 @@ class AccumulateStatefulProcessorWithInitState override def handleInputRows( key: String, inputRows: Iterator[InitInputRow], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Double)] = { + timerValues: TimerValues): Iterator[(String, String, Double)] = { var output = List[(String, String, Double)]() for (row <- inputRows) { if (row.action == "getOption") { @@ -140,13 +138,6 @@ class StatefulProcessorWithInitialStateProcTimerClass @transient private var _timerState: ValueState[Long] = _ @transient protected var _countState: ValueState[Long] = _ - private def handleProcessingTimeBasedTimers( - key: String, - expiryTimestampMs: Long): Iterator[(String, String)] = { - _timerState.clear() - Iterator((key, "-1")) - } - private def processUnexpiredRows( key: String, currCount: Long, @@ -187,16 +178,19 @@ class StatefulProcessorWithInitialStateProcTimerClass override def handleInputRows( key: String, inputRows: Iterator[String], + timerValues: TimerValues): Iterator[(String, String)] = { + val currCount = _countState.getOption().getOrElse(0L) + val count = currCount + inputRows.size + processUnexpiredRows(key, currCount, count, timerValues) + Iterator((key, count.toString)) + } + + override def handleExpiredTimer( + key: String, timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { - if (expiredTimerInfo.isValid()) { - handleProcessingTimeBasedTimers(key, expiredTimerInfo.getExpiryTimeInMs()) - } else { - val currCount = _countState.getOption().getOrElse(0L) - val count = currCount + inputRows.size - processUnexpiredRows(key, currCount, count, timerValues) - Iterator((key, count.toString)) - } + _timerState.clear() + Iterator((key, "-1")) } } @@ -246,18 +240,20 @@ class StatefulProcessorWithInitialStateEventTimerClass override def handleInputRows( key: String, inputRows: Iterator[(String, Long)], + timerValues: TimerValues): Iterator[(String, Int)] = { + val valuesSeq = inputRows.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, + _maxEventTimeState.getOption().getOrElse(0L)) + processUnexpiredRows(maxEventTimeSec) + Iterator((key, maxEventTimeSec.toInt)) + } + + override def handleExpiredTimer( + key: String, timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Int)] = { - if (expiredTimerInfo.isValid()) { - _maxEventTimeState.clear() - Iterator((key, -1)) - } else { - val valuesSeq = inputRows.toSeq - val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, - _maxEventTimeState.getOption().getOrElse(0L)) - processUnexpiredRows(maxEventTimeSec) - Iterator((key, maxEventTimeSec.toInt)) - } + _maxEventTimeState.clear() + Iterator((key, -1)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 257578ee65447..1a7970302e5bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -56,8 +56,7 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S override def 
handleInputRows( key: String, inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { val count = _countState.getOption().getOrElse(0L) + 1 if (count == 3) { _countState.clear() @@ -84,8 +83,7 @@ class RunningCountStatefulProcessorWithTTL override def handleInputRows( key: String, inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { val count = _countState.getOption().getOrElse(0L) + 1 if (count == 3) { _countState.clear() @@ -114,8 +112,7 @@ class RunningCountListStatefulProcessor override def handleInputRows( key: String, inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { Iterator.empty } } @@ -133,8 +130,7 @@ class RunningCountStatefulProcessorInt override def handleInputRows( key: String, inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { val count = _countState.getOption().getOrElse(0) + 1 if (count == 3) { _countState.clear() @@ -148,38 +144,34 @@ class RunningCountStatefulProcessorInt // Class to verify stateful processor usage with adding processing time timers class RunningCountStatefulProcessorWithProcTimeTimer extends RunningCountStatefulProcessor { - private def handleProcessingTimeBasedTimers( - key: String, - expiryTimestampMs: Long): Iterator[(String, String)] = { - _countState.clear() - Iterator((key, "-1")) - } override def handleInputRows( key: String, inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { + val currCount = _countState.getOption().getOrElse(0L) + if (currCount == 0 && (key == "a" || key == "c")) { + getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + + 5000) + } - if (expiredTimerInfo.isValid()) { - handleProcessingTimeBasedTimers(key, expiredTimerInfo.getExpiryTimeInMs()) + val count = currCount + 1 + if (count == 3) { + _countState.clear() + Iterator.empty } else { - val currCount = _countState.getOption().getOrElse(0L) - if (currCount == 0 && (key == "a" || key == "c")) { - getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() - + 5000) - } - - val count = currCount + 1 - if (count == 3) { - _countState.clear() - Iterator.empty - } else { - _countState.update(count) - Iterator((key, count.toString)) - } + _countState.update(count) + Iterator((key, count.toString)) } } + + override def handleExpiredTimer( + key: String, + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + _countState.clear() + Iterator((key, "-1")) + } } // Class to verify stateful processor usage with updating processing time timers @@ -194,13 +186,6 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) } - private def handleProcessingTimeBasedTimers( - key: String, - expiryTimestampMs: Long): Iterator[(String, String)] = { - _timerState.clear() - Iterator((key, "-1")) - } - protected def processUnexpiredRows( key: String, currCount: Long, @@ -225,50 +210,50 @@ class 
RunningCountStatefulProcessorWithProcTimeTimerUpdates override def handleInputRows( key: String, inputRows: Iterator[String], + timerValues: TimerValues): Iterator[(String, String)] = { + val currCount = _countState.getOption().getOrElse(0L) + val count = currCount + inputRows.size + processUnexpiredRows(key, currCount, count, timerValues) + Iterator((key, count.toString)) + } + + override def handleExpiredTimer( + key: String, timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { - if (expiredTimerInfo.isValid()) { - handleProcessingTimeBasedTimers(key, expiredTimerInfo.getExpiryTimeInMs()) - } else { - val currCount = _countState.getOption().getOrElse(0L) - val count = currCount + inputRows.size - processUnexpiredRows(key, currCount, count, timerValues) - Iterator((key, count.toString)) - } + _timerState.clear() + Iterator((key, "-1")) } } class RunningCountStatefulProcessorWithMultipleTimers extends RunningCountStatefulProcessor { - private def handleProcessingTimeBasedTimers( + + override def handleInputRows( key: String, - expiryTimestampMs: Long): Iterator[(String, String)] = { + inputRows: Iterator[String], + timerValues: TimerValues): Iterator[(String, String)] = { val currCount = _countState.getOption().getOrElse(0L) - if (getHandle.listTimers().size == 1) { - _countState.clear() + val count = currCount + inputRows.size + _countState.update(count) + if (getHandle.listTimers().isEmpty) { + getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000) + getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000) + getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 15000) + assert(getHandle.listTimers().size == 3) } - Iterator((key, currCount.toString)) + Iterator.empty } - override def handleInputRows( + override def handleExpiredTimer( key: String, - inputRows: Iterator[String], timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { - if (expiredTimerInfo.isValid()) { - handleProcessingTimeBasedTimers(key, expiredTimerInfo.getExpiryTimeInMs()) - } else { - val currCount = _countState.getOption().getOrElse(0L) - val count = currCount + inputRows.size - _countState.update(count) - if (getHandle.listTimers().isEmpty) { - getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000) - getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000) - getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 15000) - assert(getHandle.listTimers().size == 3) - } - Iterator.empty + val currCount = _countState.getOption().getOrElse(0L) + if (getHandle.listTimers().size == 1) { + _countState.clear() } + Iterator((key, currCount.toString)) } } @@ -302,18 +287,20 @@ class MaxEventTimeStatefulProcessor override def handleInputRows( key: String, inputRows: Iterator[(String, Long)], + timerValues: TimerValues): Iterator[(String, Int)] = { + val valuesSeq = inputRows.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, + _maxEventTimeState.getOption().getOrElse(0L)) + processUnexpiredRows(maxEventTimeSec) + Iterator((key, maxEventTimeSec.toInt)) + } + + override def handleExpiredTimer( + key: String, timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Int)] = { - if (expiredTimerInfo.isValid()) { - _maxEventTimeState.clear() - Iterator((key, -1)) - } else { - val valuesSeq = inputRows.toSeq - val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, - _maxEventTimeState.getOption().getOrElse(0L)) - 
processUnexpiredRows(maxEventTimeSec) - Iterator((key, maxEventTimeSec.toInt)) - } + _maxEventTimeState.clear() + Iterator((key, -1)) } } @@ -333,8 +320,7 @@ class RunningCountMostRecentStatefulProcessor override def handleInputRows( key: String, inputRows: Iterator[(String, String)], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, String)] = { + timerValues: TimerValues): Iterator[(String, String, String)] = { val count = _countState.getOption().getOrElse(0L) + 1 val mostRecent = _mostRecent.getOption().getOrElse("") @@ -363,8 +349,7 @@ class MostRecentStatefulProcessorWithDeletion override def handleInputRows( key: String, inputRows: Iterator[(String, String)], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { val mostRecent = _mostRecent.getOption().getOrElse("") var output = List[(String, String)]() @@ -383,8 +368,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess override def handleInputRows( key: String, inputRows: Iterator[String], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + timerValues: TimerValues): Iterator[(String, String)] = { // Trying to create value state here should fail _tempState = getHandle.getValueState[Long]("tempState", Encoders.scalaLong) Iterator.empty @@ -452,34 +436,35 @@ class TransformWithStateSuite extends StateStoreMetricsTest override def handleInputRows( key: Long, inputRows: Iterator[Long], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[Long] = { + timerValues: TimerValues): Iterator[Long] = { // Eagerly get/set a state variable _myValueState.get() _myValueState.update(1) // Create a timer (but only once) so that we can test timers have their implicit key set if (!hasSetTimer) { - getHandle.registerTimer(0) - hasSetTimer = true + getHandle.registerTimer(0) + hasSetTimer = true } // In both of these cases, we return a lazy iterator that gets/sets state variables. // This is to test that the stateful processor can handle lazy iterators. - // + inputRows.map { r => + _myValueState.get() + _myValueState.update(r) + r + } + } + + override def handleExpiredTimer( + key: Long, + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[Long] = { // The timer uses a Seq(42L) since when the timer fires, inputRows is empty. 
- if (expiredTimerInfo.isValid()) { - Seq(42L).iterator.map { r => - _myValueState.get() - _myValueState.update(r) - r - } - } else { - inputRows.map { r => - _myValueState.get() - _myValueState.update(r) - r - } + Seq(42L).iterator.map { r => + _myValueState.get() + _myValueState.update(r) + r } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index e2b31de1f66b3..e7b394db0c3c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -99,8 +99,7 @@ class ValueStateTTLProcessor(ttlConfig: TTLConfig) override def handleInputRows( key: String, inputRows: Iterator[InputEvent], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + timerValues: TimerValues): Iterator[OutputEvent] = { var results = List[OutputEvent]() inputRows.foreach { row => @@ -138,8 +137,7 @@ class MultipleValueStatesTTLProcessor( override def handleInputRows( key: String, inputRows: Iterator[InputEvent], - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = { + timerValues: TimerValues): Iterator[OutputEvent] = { var results = List[OutputEvent]() if (key == ttlKey) { From f9a5de475caa6c576cb8b240e72d80559fef3ec6 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 21 Oct 2024 19:56:23 +0800 Subject: [PATCH 080/108] [SPARK-50029][SQL] Make `StaticInvoke` compatible with the method that return `Any` ### What changes were proposed in this pull request? The pr aims to make `StaticInvoke` compatible with the method that return `Any`. ### Why are the changes needed? Currently, our `StaticInvoke` does not support calling the method with a return type signature of `Any`(actually, the type of return value may be `different data type`), while `Invoke` supports it, let's align it. ### Does this PR introduce _any_ user-facing change? No, only for spark developer. ### How was this patch tested? - Add new UT. - Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48542 from panbingkun/SPARK-50029. 
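For illustration only (this mirrors the `TestStaticInvokeReturnAny` helper added in the test below, but uses a hypothetical top-level object and values): with this change, `StaticInvoke` can target a static method whose declared return type is `Any`, because the generated code now casts the call result to the declared data type.

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.types.{IntegerType, ObjectType}

// Hypothetical helper: the static signature is Any, the runtime value is a boxed Integer.
object AnyReturningFun {
  def func(input: Int): Any = java.lang.Integer.valueOf(input + 1)
}

// Sketch: before this patch the generated Java could fail to compile, since the
// Object-typed call result was assigned to a typed variable without a cast.
val expr = StaticInvoke(
  AnyReturningFun.getClass,
  ObjectType(classOf[java.lang.Integer]),
  "func",
  Seq(Literal(0)),
  Seq(IntegerType))
expr.eval() // expected result: java.lang.Integer.valueOf(1)
```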
Authored-by: panbingkun Signed-off-by: Wenchen Fan --- .../expressions/objects/objects.scala | 7 ++-- .../expressions/ObjectExpressionsSuite.scala | 32 +++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2e149af099467..9af63a754124c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -322,13 +322,14 @@ case class StaticInvoke( val evaluate = if (returnNullable && !method.getReturnType.isPrimitive) { if (CodeGenerator.defaultValue(dataType) == "null") { s""" - ${ev.value} = $callFunc; + ${ev.value} = ($javaType) $callFunc; ${ev.isNull} = ${ev.value} == null; """ } else { val boxedResult = ctx.freshName("boxedResult") + val boxedJavaType = CodeGenerator.boxedType(dataType) s""" - ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; + $boxedJavaType $boxedResult = ($boxedJavaType) $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; @@ -336,7 +337,7 @@ case class StaticInvoke( """ } } else { - s"${ev.value} = $callFunc;" + s"${ev.value} = ($javaType) $callFunc;" } val code = code""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 762a4e9166d51..d31e76469f533 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -755,6 +755,31 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val genCode = StaticInvoke(TestFun.getClass, IntegerType, "foo", arguments).genCode(ctx) assert(!genCode.code.toString.contains("boxedResult")) } + + test("StaticInvoke call return `any` method") { + val cls = TestStaticInvokeReturnAny.getClass + Seq((0, IntegerType, true), (1, IntegerType, true), (2, IntegerType, false)).foreach { + case (arg, argDataType, returnNullable) => + val dataType = arg match { + case 0 => ObjectType(classOf[java.lang.Integer]) + case 1 => ShortType + case 2 => ObjectType(classOf[java.lang.Long]) + } + val arguments = Seq(Literal(arg, argDataType)) + val inputTypes = Seq(IntegerType) + val expected = arg match { + case 0 => java.lang.Integer.valueOf(1) + case 1 => 0.toShort + case 2 => java.lang.Long.valueOf(2) + } + val inputRow = InternalRow.fromSeq(Seq(arg)) + checkObjectExprEvaluation( + StaticInvoke(cls, dataType, "func", arguments, inputTypes, + returnNullable = returnNullable), + expected, + inputRow) + } + } } class TestBean extends Serializable { @@ -790,3 +815,10 @@ case object TestFun { def foo(left: Int, right: Int): Int = left + right } +object TestStaticInvokeReturnAny { + def func(input: Int): Any = input match { + case 0 => java.lang.Integer.valueOf(1) + case 1 => 0.toShort + case 2 => java.lang.Long.valueOf(2) + } +} From 3e597e2ec7f61d6944b154afbb5259f6ff7885d5 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 21 Oct 2024 16:45:00 +0200 Subject: [PATCH 081/108] [SPARK-50053][SQL] Turn `_LEGACY_ERROR_TEMP_2104` into `INTERNAL_ERROR` ### What changes were proposed in this pull request? 
This PR proposes to assign proper error condition & sqlstate for `_LEGACY_ERROR_TEMP_2104`: `INTERNAL_ERROR` ### Why are the changes needed? To improve the error message by assigning proper error condition. ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48577 from itholic/LEGACY_2104. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- common/utils/src/main/resources/error/error-conditions.json | 5 ----- .../org/apache/spark/sql/errors/QueryExecutionErrors.scala | 4 ++-- .../spark/sql/execution/joins/HashedRelationSuite.scala | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 00f288bfc65c3..9b50312539bca 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -7106,11 +7106,6 @@ "Dictionary encoding should not be used because of dictionary overflow." ] }, - "_LEGACY_ERROR_TEMP_2104" : { - "message" : [ - "End of the iterator." - ] - }, "_LEGACY_ERROR_TEMP_2105" : { "message" : [ "Could not allocate memory to grow BytesToBytesMap." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 946d129c47c4f..367d705d9f7dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1081,8 +1081,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def endOfIteratorError(): Throwable = { new SparkException( - errorClass = "_LEGACY_ERROR_TEMP_2104", - messageParameters = Map.empty, + errorClass = "INTERNAL_ERROR", + messageParameters = Map("message" -> "End of the iterator."), cause = null) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index e555033b53055..6590deaa47e01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -548,8 +548,8 @@ class HashedRelationSuite extends SharedSparkSession { exception = intercept[SparkException] { keyIterator.next() }, - condition = "_LEGACY_ERROR_TEMP_2104", - parameters = Map.empty + condition = "INTERNAL_ERROR", + parameters = Map("message" -> "End of the iterator.") ) assert(buffer.sortWith(_ < _) === randomArray) buffer.clear() From f86df1e37b4770bffddafc0216ae5eb4b3ac25b7 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 21 Oct 2024 10:42:13 -0700 Subject: [PATCH 082/108] [SPARK-50036][CORE][PYTHON] Include SPARK_LOG_SCHEMA in the context of REPL shell ### What changes were proposed in this pull request? Before the Change: Users needed to import LOG_SCHEMA to read structured logs as a JSON data source: ``` import org.apache.spark.util.LogUtils.LOG_SCHEMA val logDf = spark.read.schema(LOG_SCHEMA).json("path/to/logs") ``` After the Change: - Renamed for Clarity:`LOG_SCHEMA` has been renamed to `SPARK_LOG_SCHEMA` to make its purpose more clear. 
- No Import Needed in REPL Shells: You can now use `SPARK_LOG_SCHEMA` directly in REPL environments like spark-shell and pyspark without importing it. Now, you can read structured logs without the import: ``` val logDf = spark.read.schema(SPARK_LOG_SCHEMA).json("path/to/logs") ``` ### Why are the changes needed? Simply the way to query structured logs ### Does this PR introduce _any_ user-facing change? No, this a new feature in Spark 4.0 ### How was this patch tested? Existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48552 from gengliangwang/import_log_schema. Lead-authored-by: Gengliang Wang Co-authored-by: Hyukjin Kwon Signed-off-by: Gengliang Wang --- .../scala/org/apache/spark/util/LogUtils.scala | 4 ++-- docs/configuration.md | 15 ++++++++------- python/pyspark/logger/__init__.py | 8 ++------ python/pyspark/logger/logger.py | 10 ++++++++++ python/pyspark/shell.py | 1 + python/pyspark/util.py | 16 ---------------- .../scala/org/apache/spark/repl/SparkILoop.scala | 3 ++- .../org/apache/spark/sql/LogQuerySuite.scala | 8 ++++++-- 8 files changed, 31 insertions(+), 34 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/LogUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/LogUtils.scala index 5a798ffad3a92..8b41f10339271 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/LogUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/LogUtils.scala @@ -29,9 +29,9 @@ object LogUtils { /** * Schema for structured Spark logs. * Example usage: - * val logDf = spark.read.schema(LOG_SCHEMA).json("path/to/logs") + * val logDf = spark.read.schema(SPARK_LOG_SCHEMA).json("path/to/logs") */ - val LOG_SCHEMA: String = """ + val SPARK_LOG_SCHEMA: String = """ |ts TIMESTAMP, |level STRING, |msg STRING, diff --git a/docs/configuration.md b/docs/configuration.md index 3c83ed92c1280..e095ae7a61b22 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -3761,21 +3761,22 @@ Starting from version 4.0.0, `spark-submit` has adopted the [JSON Template Layou To configure the layout of structured logging, start with the `log4j2.properties.template` file. -To query Spark logs using Spark SQL, you can use the following Python code snippet: +To query Spark logs using Spark SQL, you can use the following code snippets: +**Python:** ```python -from pyspark.util import LogUtils +from pyspark.logger import SPARK_LOG_SCHEMA -logDf = spark.read.schema(LogUtils.LOG_SCHEMA).json("path/to/logs") +logDf = spark.read.schema(SPARK_LOG_SCHEMA).json("path/to/logs") ``` -Or using the following Scala code snippet: +**Scala:** ```scala -import org.apache.spark.util.LogUtils.LOG_SCHEMA +import org.apache.spark.util.LogUtils.SPARK_LOG_SCHEMA -val logDf = spark.read.schema(LOG_SCHEMA).json("path/to/logs") +val logDf = spark.read.schema(SPARK_LOG_SCHEMA).json("path/to/logs") ``` - +**Note**: If you're using the interactive shell (pyspark shell or spark-shell), you can omit the import statement in the code because SPARK_LOG_SCHEMA is already available in the shell's context. ## Plain Text Logging If you prefer plain text logging, you have two options: - Disable structured JSON logging by setting the Spark configuration `spark.log.structuredLogging.enabled` to `false`. 
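As an aside, a minimal end-to-end sketch of the workflow the documentation above describes; the log path and the `ERROR` filter are illustrative only:

```scala
import org.apache.spark.util.LogUtils.SPARK_LOG_SCHEMA

// Hypothetical location of the structured (JSON) Spark logs.
val logDf = spark.read.schema(SPARK_LOG_SCHEMA).json("path/to/logs")
logDf.createOrReplaceTempView("logs")

// The schema exposes ts, level, msg, context, exception and logger columns.
spark.sql("SELECT ts, level, msg FROM logs WHERE level = 'ERROR'").show(truncate = false)
```

In spark-shell the import line can be dropped, since this patch adds `SPARK_LOG_SCHEMA` to the shell's default imports.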
diff --git a/python/pyspark/logger/__init__.py b/python/pyspark/logger/__init__.py index d8fab8beca8d8..9e629971f0cbe 100644 --- a/python/pyspark/logger/__init__.py +++ b/python/pyspark/logger/__init__.py @@ -18,10 +18,6 @@ """ PySpark logging """ -from pyspark.logger.logger import ( # noqa: F401 - PySparkLogger, -) +from pyspark.logger.logger import PySparkLogger, SPARK_LOG_SCHEMA # noqa: F401 -__all__ = [ - "PySparkLogger", -] +__all__ = ["PySparkLogger", "SPARK_LOG_SCHEMA"] diff --git a/python/pyspark/logger/logger.py b/python/pyspark/logger/logger.py index 975441a9cb572..a2226fd717e0a 100644 --- a/python/pyspark/logger/logger.py +++ b/python/pyspark/logger/logger.py @@ -20,6 +20,16 @@ import json from typing import cast, Optional +SPARK_LOG_SCHEMA = ( + "ts TIMESTAMP, " + "level STRING, " + "msg STRING, " + "context map, " + "exception STRUCT>>," + "logger STRING" +) + class JSONFormatter(logging.Formatter): """ diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 12ff86ecc9ff9..91951b644f6bf 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -30,6 +30,7 @@ import pyspark from pyspark.core.context import SparkContext +from pyspark.logger import SPARK_LOG_SCHEMA # noqa: F401 from pyspark.sql import SparkSession from pyspark.sql.context import SQLContext from pyspark.sql.utils import is_remote diff --git a/python/pyspark/util.py b/python/pyspark/util.py index cca44435efe67..10d68e9fd8bfe 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -118,22 +118,6 @@ def majorMinorVersion(sparkVersion: str) -> Tuple[int, int]: ) -class LogUtils: - """ - Utils for querying structured Spark logs with Spark SQL. - """ - - LOG_SCHEMA = ( - "ts TIMESTAMP, " - "level STRING, " - "msg STRING, " - "context map, " - "exception STRUCT>>," - "logger STRING" - ) - - def fail_on_stopiteration(f: Callable) -> Callable: """ Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 861cf5a740ce1..f49e8adcc74af 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -66,7 +66,8 @@ class SparkILoop(in0: BufferedReader, out: PrintWriter) "import org.apache.spark.SparkContext._", "import spark.implicits._", "import spark.sql", - "import org.apache.spark.sql.functions._" + "import org.apache.spark.sql.functions._", + "import org.apache.spark.util.LogUtils.SPARK_LOG_SCHEMA" ) override protected def internalReplAutorunCode(): Seq[String] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LogQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LogQuerySuite.scala index df0fbf15a98ee..873337e7a4242 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LogQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LogQuerySuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.util.LogUtils.LOG_SCHEMA +import org.apache.spark.util.LogUtils.SPARK_LOG_SCHEMA /** * Test suite for querying Spark logs using SQL. 
@@ -42,7 +42,11 @@ class LogQuerySuite extends QueryTest with SharedSparkSession with Logging { } private def createTempView(viewName: String): Unit = { - spark.read.schema(LOG_SCHEMA).json(logFile.getCanonicalPath).createOrReplaceTempView(viewName) + spark + .read + .schema(SPARK_LOG_SCHEMA) + .json(logFile.getCanonicalPath) + .createOrReplaceTempView(viewName) } test("Query Spark logs using SQL") { From b4eb03469a90ec5bf65fd51d2bcb41602acad51b Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 21 Oct 2024 20:10:41 +0200 Subject: [PATCH 083/108] [SPARK-49967][SQL] Codegen Support for `StructsToJson`(`to_json`) ### What changes were proposed in this pull request? The pr aims to add `Codegen` Support for `StructsToJson`(`to_json`). ### Why are the changes needed? - improve codegen coverage. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA & Existed UT (eg: JsonFunctionsSuite#`*to_json*`) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48467 from panbingkun/SPARK-49967. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../json/JsonExpressionEvalUtils.scala | 56 +++++++++++++++- .../expressions/jsonExpressions.scala | 64 ++++++------------- .../expressions/ExpressionEvalHelper.scala | 2 +- .../expressions/JsonExpressionsSuite.scala | 5 +- .../explain-results/function_to_json.explain | 2 +- .../explain-results/toJSON.explain | 2 +- .../BaseScriptTransformationExec.scala | 4 +- 7 files changed, 78 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala index efa5c930b73da..dd7d318f430b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala @@ -16,17 +16,19 @@ */ package org.apache.spark.sql.catalyst.expressions.json +import java.io.CharArrayWriter + import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} -import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, PermissiveMode} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonGenerator, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.util.{ArrayData, FailFastMode, FailureSafeParser, MapData, PermissiveMode} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, VariantType} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.Utils object JsonExpressionEvalUtils { @@ -111,3 +113,51 @@ case class JsonToStructsEvaluator( } } } + +case class StructsToJsonEvaluator( + options: Map[String, String], + inputSchema: DataType, + timeZoneId: Option[String]) extends Serializable { + + @transient + private lazy val writer = new CharArrayWriter() + + @transient + private lazy val gen = new JacksonGenerator( + 
inputSchema, writer, new JSONOptions(options, timeZoneId.get)) + + // This converts rows to the JSON output according to the given schema. + @transient + private lazy val converter: Any => UTF8String = { + def getAndReset(): UTF8String = { + gen.flush() + val json = writer.toString + writer.reset() + UTF8String.fromString(json) + } + + inputSchema match { + case _: StructType => + (row: Any) => + gen.write(row.asInstanceOf[InternalRow]) + getAndReset() + case _: ArrayType => + (arr: Any) => + gen.write(arr.asInstanceOf[ArrayData]) + getAndReset() + case _: MapType => + (map: Any) => + gen.write(map.asInstanceOf[MapData]) + getAndReset() + case _: VariantType => + (v: Any) => + gen.write(v.asInstanceOf[VariantVal]) + getAndReset() + } + } + + final def evaluate(value: Any): Any = { + if (value == null) return null + converter(value) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index d884e76f5256d..ac6c233f7d2ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,16 +30,15 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper -import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator} +import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator, StructsToJsonEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, RUNTIME_REPLACEABLE, TreePattern} -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{UTF8String, VariantVal} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils private[this] sealed trait PathInstruction @@ -748,14 +747,15 @@ case class StructsToJson( child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression - with TimeZoneAwareExpression - with CodegenFallback + with RuntimeReplaceable with ExpectsInputTypes - with NullIntolerant + with TimeZoneAwareExpression with QueryErrorsBase { override def nullable: Boolean = true + override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE) + def this(options: Map[String, String], child: Expression) = this(options, child, None) // Used in `FunctionRegistry` @@ -767,44 +767,7 @@ case class StructsToJson( timeZoneId = None) @transient - lazy val writer = new CharArrayWriter() - - @transient - lazy val gen = new JacksonGenerator( - inputSchema, writer, new JSONOptions(options, timeZoneId.get)) - - @transient - lazy val inputSchema = child.dataType - - // This converts rows to the JSON output according to the given schema. 
- @transient - lazy val converter: Any => UTF8String = { - def getAndReset(): UTF8String = { - gen.flush() - val json = writer.toString - writer.reset() - UTF8String.fromString(json) - } - - inputSchema match { - case _: StructType => - (row: Any) => - gen.write(row.asInstanceOf[InternalRow]) - getAndReset() - case _: ArrayType => - (arr: Any) => - gen.write(arr.asInstanceOf[ArrayData]) - getAndReset() - case _: MapType => - (map: Any) => - gen.write(map.asInstanceOf[MapData]) - getAndReset() - case _: VariantType => - (v: Any) => - gen.write(v.asInstanceOf[VariantVal]) - getAndReset() - } - } + private lazy val inputSchema = child.dataType override def dataType: DataType = SQLConf.get.defaultStringType @@ -820,14 +783,23 @@ case class StructsToJson( override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(value: Any): Any = converter(value) - override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil override def prettyName: String = "to_json" override protected def withNewChildInternal(newChild: Expression): StructsToJson = copy(child = newChild) + + @transient + private lazy val evaluator = StructsToJsonEvaluator(options, inputSchema, timeZoneId) + + override def replacement: Expression = Invoke( + Literal.create(evaluator, ObjectType(classOf[StructsToJsonEvaluator])), + "evaluate", + dataType, + Seq(child), + Seq(child.dataType) + ) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 02c7ed727a530..184f5a2a9485d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -79,7 +79,7 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance() val resolver = ResolveTimeZone - val expr = resolver.resolveTimeZones(replace(expression)) + val expr = replace(resolver.resolveTimeZones(expression)) assert(expr.resolved) serializer.deserialize(serializer.serialize(expr)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index edb7b93ecdf68..3a58cb92cecf2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -582,7 +582,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val schema = StructType(StructField("\"quote", IntegerType) :: Nil) val struct = Literal.create(create_row(1), schema) GenerateUnsafeProjection.generate( - StructsToJson(Map.empty, struct, UTC_OPT) :: Nil) + StructsToJson(Map.empty, struct, UTC_OPT).replacement :: Nil) } test("to_json - struct") { @@ -729,8 +729,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from/to json - interval support") { val schema = StructType(StructField("i", CalendarIntervalType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 
day"}""", StringType), - UTC_OPT), + JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType)), InternalRow(new CalendarInterval(12, 1, 0))) Seq(MapType(CalendarIntervalType, IntegerType), MapType(IntegerType, CalendarIntervalType)) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_json.explain index cd72b12ee19b6..da90d9c4c6e16 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_json.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_json.explain @@ -1,2 +1,2 @@ -Project [to_json((timestampFormat,dd/MM/yyyy), d#0, Some(America/Los_Angeles)) AS to_json(d)#0] +Project [invoke(StructsToJsonEvaluator(Map(timestampFormat -> dd/MM/yyyy),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),Some(America/Los_Angeles)).evaluate(d#0)) AS to_json(d)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/toJSON.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/toJSON.explain index 1698c562732e8..fcb3e173ecaad 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/toJSON.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/toJSON.explain @@ -1,2 +1,2 @@ -Project [to_json(struct(id, id#0L, a, a#0, b, b#0, d, d#0, e, e#0, f, f#0, g, g#0), Some(America/Los_Angeles)) AS to_json(struct(id, a, b, d, e, f, g))#0] +Project [invoke(StructsToJsonEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true),StructField(d,StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),true),StructField(e,ArrayType(IntegerType,true),true),StructField(f,MapType(StringType,StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),true),true),StructField(g,StringType,true)),Some(America/Los_Angeles)).evaluate(struct(id, id#0L, a, a#0, b, b#0, d, d#0, e, e#0, f, f#0, g, g#0))) AS to_json(struct(id, a, b, d, e, f, g))#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 2a1554d287a8a..64d2633c31079 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -51,8 +51,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { child.output.map { in => in.dataType match { case _: ArrayType | _: MapType | _: StructType => - new StructsToJson(ioschema.inputSerdeProps.toMap, in) - .withTimeZone(conf.sessionLocalTimeZone) + StructsToJson(ioschema.inputSerdeProps.toMap, in, + Some(conf.sessionLocalTimeZone)).replacement case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone) } } From 91ae1022198ba4f5ed3204c7dd959c60f76ad3a4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 22 Oct 2024 08:21:42 +0900 Subject: [PATCH 084/108] [SPARK-50052][PYTHON] Make NumpyArrayConverter support empty str ndarray ### What changes were proposed in this pull request? 
Make `NumpyArrayConverter` support empty str ndarray ### Why are the changes needed? existing implementation needs the str ndarray to be non-empty, so fails: ``` In [6]: spark.range(1).select(sf.lit(np.array(["a", "b"], np.str_))) Out[6]: DataFrame[ARRAY('a', 'b'): array] In [7]: spark.range(1).select(sf.lit(np.array([], np.int32))) Out[7]: DataFrame[ARRAY(): array] In [8]: spark.range(1).select(sf.lit(np.array([], np.str_))) ... PySparkTypeError: [UNSUPPORTED_NUMPY_ARRAY_SCALAR] The type of array scalar '] ``` ### How was this patch tested? added test ### Was this patch authored or co-authored using generative AI tooling? no Closes #48575 from zhengruifeng/classic_np_array_conv. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../tests/connect/test_parity_functions.py | 8 +++++++ python/pyspark/sql/tests/test_functions.py | 23 +++++++++++++++++++ python/pyspark/sql/types.py | 17 +++++++------- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py index 0a77c5531082a..2004c065659b3 100644 --- a/python/pyspark/sql/tests/connect/test_parity_functions.py +++ b/python/pyspark/sql/tests/connect/test_parity_functions.py @@ -34,6 +34,14 @@ def test_function_parity(self): def test_input_file_name_reset_for_rdd(self): super().test_input_file_name_reset_for_rdd() + @unittest.skip("SPARK-50050: Spark Connect should support str ndarray.") + def test_str_ndarray(self): + super().test_str_ndarray() + + @unittest.skip("SPARK-50051: Spark Connect should empty ndarray.") + def test_empty_ndarray(self): + super().test_empty_ndarray() + if __name__ == "__main__": from pyspark.sql.tests.connect.test_parity_functions import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index cec6e2ababbdc..657f0b3371568 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1263,6 +1263,29 @@ def test_ndarray_input(self): }, ) + @unittest.skipIf(not have_numpy, "NumPy not installed") + def test_str_ndarray(self): + import numpy as np + + for arr in [ + np.array([], np.str_), + np.array(["a"], np.str_), + np.array([1, 2, 3], np.str_), + ]: + self.assertEqual( + [("a", "array")], + self.spark.range(1).select(F.lit(arr).alias("a")).dtypes, + ) + + @unittest.skipIf(not have_numpy, "NumPy not installed") + def test_empty_ndarray(self): + import numpy as np + + self.assertEqual( + [("a", "array")], + self.spark.range(1).select(F.lit(np.array([], np.int32)).alias("a")).dtypes, + ) + def test_binary_math_function(self): funcs, expected = zip( *[(F.atan2, 0.13664), (F.hypot, 8.07527), (F.pow, 2.14359), (F.pmod, 1.1)] diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index e6e2267af0ce1..0f77e95cced1c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -3273,6 +3273,8 @@ def _from_numpy_type_to_java_type( return gateway.jvm.double elif nt == np.dtype("bool"): return gateway.jvm.boolean + elif np.isdtype(nt, np.str_): + return gateway.jvm.String return None @@ -3286,15 +3288,12 @@ def convert(self, obj: "np.ndarray", gateway_client: "GatewayClient") -> "JavaGa assert gateway is not None plist = obj.tolist() - if len(obj) > 0 and isinstance(plist[0], str): - jtpe = gateway.jvm.String - else: - jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway) - if jtpe is None: - raise PySparkTypeError( - 
errorClass="UNSUPPORTED_NUMPY_ARRAY_SCALAR", - messageParameters={"dtype": str(obj.dtype)}, - ) + jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway) + if jtpe is None: + raise PySparkTypeError( + errorClass="UNSUPPORTED_NUMPY_ARRAY_SCALAR", + messageParameters={"dtype": str(obj.dtype)}, + ) jarr = gateway.new_array(jtpe, len(obj)) for i in range(len(plist)): jarr[i] = plist[i] From 70c9b1f978a2c97e69efa4862a22c48ee00a8ef4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 22 Oct 2024 09:03:26 +0900 Subject: [PATCH 085/108] [SPARK-49552][CONNECT][FOLLOW-UP] Make 'randstr' and 'uniform' deterministic in Scala Client ### What changes were proposed in this pull request? Make 'randstr' and 'uniform' deterministic in Scala Client ### Why are the changes needed? We need to explicitly set the seed in connect clients, to avoid making the output dataframe non-deterministic (see https://github.com/apache/spark/commit/14ba4fc479155611ca39bda6f879c34cc78af2ee) When reviewing https://github.com/apache/spark/pull/48143, I requested the author to set the seed in python client. But at that time, I was not aware of the fact that Spark Connect Scala Client was reusing the same `functions.scala` under `org.apache.spark.sql`. (There were two different files before) So the two functions may cause non-deterministic issues like: ``` scala> val df = spark.range(10).select(randstr(lit(10)).as("r")) Using Spark's default log4j profile: org/apache/spark/log4j2-pattern-layout-defaults.properties df: org.apache.spark.sql.package.DataFrame = [r: string] scala> df.show() +----------+ | r| +----------+ |5bhIk72PJa| |tuhC50Di38| |PxwfWzdT3X| |sWkmSyWboh| |uZMS4htmM0| |YMxMwY5wdQ| |JDaWSiBwDD| |C7KQ20WE7t| |IwSSqWOObg| |jDF2Ndfy8q| +----------+ scala> df.show() +----------+ | r| +----------+ |fpnnoLJbOA| |qerIKpYPif| |PvliXYIALD| |xK3fosAvOp| |WK12kfkPXq| |2UcdyAEbNm| |HEkl4rMtV1| |PCaH4YJuYo| |JuuXEHSp5i| |jSLjl8ug8S| +----------+ ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? after this fix: ``` scala> val df = spark.range(10).select(randstr(lit(10)).as("r")) df: org.apache.spark.sql.package.DataFrame = [r: string] scala> df.show() +----------+ | r| +----------+ |Gri9B9X8zI| |gfhpGD8PcV| |FDaXofTzlN| |p7ciOScWpu| |QZiEbF5q7c| |9IhRoXmTUM| |TeSEG1EKSN| |B7nLw5iedL| |uFZo1WPLPT| |46E2LVCxxl| +----------+ scala> df.show() +----------+ | r| +----------+ |Gri9B9X8zI| |gfhpGD8PcV| |FDaXofTzlN| |p7ciOScWpu| |QZiEbF5q7c| |9IhRoXmTUM| |TeSEG1EKSN| |B7nLw5iedL| |uFZo1WPLPT| |46E2LVCxxl| +----------+ ``` ### Was this patch authored or co-authored using generative AI tooling? no Closes #48558 from zhengruifeng/sql_rand_str_seed. 
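As a usage-level sketch of the behavior after this fix (column names and arguments are arbitrary): because the client now fixes the seed when the Column is constructed, re-executing the same DataFrame produces the same values:

```scala
import org.apache.spark.sql.functions._

val df = spark.range(3).select(randstr(lit(10)).as("r"), uniform(lit(0), lit(10)).as("u"))
df.show() // some random strings and numbers
df.show() // identical output, since the seed is already baked into the plan
```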
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 ++++-- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 4a9a20efd3a56..ece9d638e7c61 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1901,7 +1901,8 @@ object functions { * @group string_funcs * @since 4.0.0 */ - def randstr(length: Column): Column = Column.fn("randstr", length) + def randstr(length: Column): Column = + randstr(length, lit(SparkClassUtils.random.nextLong)) /** * Returns a string of the specified length whose characters are chosen uniformly at random from @@ -3767,7 +3768,8 @@ object functions { * @group math_funcs * @since 4.0.0 */ - def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max) + def uniform(min: Column, max: Column): Column = + uniform(min, max, lit(SparkClassUtils.random.nextLong)) /** * Returns a random value with independent and identically distributed (i.i.d.) values with the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 39c839ae5a518..975a82e26f4eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -507,12 +507,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } // Here we exercise some error cases. val df = Seq((0)).toDF("a") - var expr = uniform(lit(10), lit("a")) + var expr = uniform(lit(10), lit("a"), lit(1)) checkError( intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "sqlExpr" -> "\"uniform(10, a)\"", + "sqlExpr" -> "\"uniform(10, a, 1)\"", "paramIndex" -> "second", "inputSql" -> "\"a\"", "inputType" -> "\"STRING\"", @@ -525,7 +525,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { callSitePattern = "", startIndex = 0, stopIndex = 0)) - expr = uniform(col("a"), lit(10)) + expr = uniform(col("a"), lit(10), lit(1)) checkError( intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", @@ -533,7 +533,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputName" -> "`min`", "inputType" -> "integer or floating-point", "inputExpr" -> "\"a\"", - "sqlExpr" -> "\"uniform(a, 10)\""), + "sqlExpr" -> "\"uniform(a, 10, 1)\""), context = ExpectedContext( contextType = QueryContextType.DataFrame, fragment = "uniform", From 09b7bdff48a6657850c29d576e339571487e7a90 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 22 Oct 2024 09:49:27 +0900 Subject: [PATCH 086/108] [SPARK-50064][PYTHON][TESTS] Make pysaprk-ml-connect tests passing without optional dependencies ### What changes were proposed in this pull request? This PR proposes to make pysaprk-ml-connect tests passing without optional dependencies ### Why are the changes needed? To make the tests passing without optional dependencies. See https://github.com/apache/spark/actions/runs/11447673972/job/31849621508 ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? 
Manually ran it locally ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48587 from HyukjinKwon/SPARK-50064. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../tests/connect/test_connect_evaluation.py | 25 ++++++------- .../ml/tests/connect/test_connect_feature.py | 25 ++++++------- .../ml/tests/connect/test_connect_pipeline.py | 37 ++++++++++--------- .../tests/connect/test_connect_summarizer.py | 19 +++++----- .../ml/tests/connect/test_connect_tuning.py | 31 ++++++++-------- 5 files changed, 68 insertions(+), 69 deletions(-) diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py index 359a77bbcb20f..9acf5ae0ac44d 100644 --- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py +++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py @@ -30,19 +30,18 @@ if should_test_connect: from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin - -@unittest.skipIf( - not should_test_connect or not have_torcheval, - connect_requirement_message or "torcheval is required", -) -class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.remote( - os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") - ).getOrCreate() - - def tearDown(self) -> None: - self.spark.stop() + @unittest.skipIf( + not should_test_connect or not have_torcheval, + connect_requirement_message or "torcheval is required", + ) + class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase): + def setUp(self) -> None: + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() + + def tearDown(self) -> None: + self.spark.stop() if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py index c786ce2f87d0f..c1d02050097b2 100644 --- a/python/pyspark/ml/tests/connect/test_connect_feature.py +++ b/python/pyspark/ml/tests/connect/test_connect_feature.py @@ -32,19 +32,18 @@ if should_test_connect: from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin - -@unittest.skipIf( - not should_test_connect or not have_sklearn, - connect_requirement_message or sklearn_requirement_message, -) -class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.remote( - os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") - ).getOrCreate() - - def tearDown(self) -> None: - self.spark.stop() + @unittest.skipIf( + not should_test_connect or not have_sklearn, + connect_requirement_message or sklearn_requirement_message, + ) + class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase): + def setUp(self) -> None: + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() + + def tearDown(self) -> None: + self.spark.stop() if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py index 4105f593f170e..7733af7631e92 100644 --- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py +++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py @@ -22,9 +22,6 @@ from pyspark.sql import SparkSession from pyspark.testing.connectutils import should_test_connect, connect_requirement_message -if 
should_test_connect: - from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin - torch_requirement_message = None have_torch = True try: @@ -33,23 +30,27 @@ have_torch = False torch_requirement_message = "torch is required" +if should_test_connect: + from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin -@unittest.skipIf( - not should_test_connect or not have_torch or is_remote_only(), - connect_requirement_message - or torch_requirement_message - or "Requires PySpark core library in Spark Connect server", -) -class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = ( - SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) - .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true") - .getOrCreate() - ) + @unittest.skipIf( + not should_test_connect or not have_torch or is_remote_only(), + connect_requirement_message + or torch_requirement_message + or "Requires PySpark core library in Spark Connect server", + ) + class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase): + def setUp(self) -> None: + self.spark = ( + SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ) + .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true") + .getOrCreate() + ) - def tearDown(self) -> None: - self.spark.stop() + def tearDown(self) -> None: + self.spark.stop() if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py b/python/pyspark/ml/tests/connect/test_connect_summarizer.py index 1cfd2ed229e5b..9c737c96ee87a 100644 --- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py +++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py @@ -24,16 +24,15 @@ if should_test_connect: from pyspark.ml.tests.connect.test_legacy_mode_summarizer import SummarizerTestsMixin - -@unittest.skipIf(not should_test_connect, connect_requirement_message) -class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.remote( - os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") - ).getOrCreate() - - def tearDown(self) -> None: - self.spark.stop() + @unittest.skipIf(not should_test_connect, connect_requirement_message) + class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase): + def setUp(self) -> None: + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() + + def tearDown(self) -> None: + self.spark.stop() if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py index d5fcb93099b6e..be3ca067bab1b 100644 --- a/python/pyspark/ml/tests/connect/test_connect_tuning.py +++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py @@ -26,21 +26,22 @@ if should_test_connect: from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin - -@unittest.skipIf( - not should_test_connect or is_remote_only(), - connect_requirement_message or "Requires PySpark core library in Spark Connect server", -) -class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = ( - SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) - .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true") - 
.getOrCreate() - ) - - def tearDown(self) -> None: - self.spark.stop() + @unittest.skipIf( + not should_test_connect or is_remote_only(), + connect_requirement_message or "Requires PySpark core library in Spark Connect server", + ) + class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase): + def setUp(self) -> None: + self.spark = ( + SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ) + .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true") + .getOrCreate() + ) + + def tearDown(self) -> None: + self.spark.stop() if __name__ == "__main__": From d5419ab0960e1a32e5505c1045b1563a829f6eeb Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Tue, 22 Oct 2024 09:31:49 +0800 Subject: [PATCH 087/108] [SPARK-50058][SQL] Factor out plan normalization functions to later use those in single-pass Analyzer testing ### What changes were proposed in this pull request? Factor out plan normalization functions to later use those in single-pass Analyzer testing. ### Why are the changes needed? In the single-pass Analyzer [proposal](https://docs.google.com/document/d/1dWxvrJV-0joGdLtWbvJ0uNyTocDMJ90rPRNWa4T56Og/edit?tab=t.0#bookmark=id.yjvozwxpjc5q) we discussed one of the ways to test the single-pass Analyzer - to run tests in dual mode (both single-pass and fixed-point Analyzers) and compare the logical plans. To do that we need to perform plan normalization. Context: https://issues.apache.org/jira/browse/SPARK-49834 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48584 from vladimirg-db/vladimirg-db/factor-out-normalize-plan. Authored-by: Vladimir Golubev Signed-off-by: Wenchen Fan --- .../sql/catalyst/plans/NormalizePlan.scala | 127 ++++++++++++++++++ .../spark/sql/catalyst/plans/PlanTest.scala | 95 +------------ 2 files changed, 132 insertions(+), 90 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala new file mode 100644 index 0000000000000..3b691f4f87778 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.sql.catalyst.analysis.GetViewColumnByNameAndOrdinal +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical._ + +object NormalizePlan extends PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = + normalizePlan(normalizeExprIds(plan)) + + /** + * Since attribute references are given globally unique ids during analysis, + * we must normalize them to check if two different queries are identical. + */ + def normalizeExprIds(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case s: ScalarSubquery => + s.copy(plan = normalizeExprIds(s.plan), exprId = ExprId(0)) + case s: LateralSubquery => + s.copy(plan = normalizeExprIds(s.plan), exprId = ExprId(0)) + case e: Exists => + e.copy(plan = normalizeExprIds(e.plan), exprId = ExprId(0)) + case l: ListQuery => + l.copy(plan = normalizeExprIds(l.plan), exprId = ExprId(0)) + case a: AttributeReference => + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + case OuterReference(a: AttributeReference) => + OuterReference(AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))) + case a: Alias => + Alias(a.child, a.name)(exprId = ExprId(0)) + case OuterReference(a: Alias) => + OuterReference(Alias(a.child, a.name)(exprId = ExprId(0))) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) + case lv: NamedLambdaVariable => + lv.copy(exprId = ExprId(0), value = null) + case udf: PythonUDF => + udf.copy(resultId = ExprId(0)) + case udaf: PythonUDAF => + udaf.copy(resultId = ExprId(0)) + case a: FunctionTableSubqueryArgumentExpression => + a.copy(plan = normalizeExprIds(a.plan), exprId = ExprId(0)) + } + } + + /** + * Normalizes plans: + * - Filter the filter conditions that appear in a plan. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) + * etc., will all now be equivalent. + * - Sample the seed will replaced by 0L. + * - Join conditions will be resorted by hashCode. + */ + def normalizePlan(plan: LogicalPlan): LogicalPlan = { + plan transform { + case Filter(condition: Expression, child: LogicalPlan) => + Filter( + splitConjunctivePredicates(condition) + .map(rewriteBinaryComparison) + .sortBy(_.hashCode()) + .reduce(And), + child + ) + case sample: Sample => + sample.copy(seed = 0L) + case Join(left, right, joinType, condition, hint) if condition.isDefined => + val newJoinType = joinType match { + case ExistenceJoin(a: Attribute) => + val newAttr = AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + ExistenceJoin(newAttr) + case other => other + } + + val newCondition = + splitConjunctivePredicates(condition.get) + .map(rewriteBinaryComparison) + .sortBy(_.hashCode()) + .reduce(And) + Join(left, right, newJoinType, Some(newCondition), hint) + case Project(projectList, child) => + val projList = projectList + .map { e => + e.transformUp { + case g: GetViewColumnByNameAndOrdinal => g.copy(viewDDL = None) + } + } + .asInstanceOf[Seq[NamedExpression]] + Project(projList, child) + case c: KeepAnalyzedQuery => c.storeAnalyzedQuery() + } + } + + /** + * Rewrite [[BinaryComparison]] operator to keep order. The following cases will be + * equivalent: + * 1. (a = b), (b = a); + * 2. (a <=> b), (b <=> a). + * 3. 
(a > b), (b < a) + */ + private def rewriteBinaryComparison(condition: Expression): Expression = condition match { + case EqualTo(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + case EqualNullSafe(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) + case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) + case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) + case _ => condition // Don't reorder. + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e90a956ab4fde..37baa66049de3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -23,10 +23,9 @@ import org.scalatest.Tag import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.analysis.{GetViewColumnByNameAndOrdinal, SimpleAnalyzer} +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf @@ -65,40 +64,8 @@ trait CodegenInterpretedPlanTest extends PlanTest { */ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { self: Suite => - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. 
- */ - protected def normalizeExprIds(plan: LogicalPlan): LogicalPlan = { - plan transformAllExpressions { - case s: ScalarSubquery => - s.copy(plan = normalizeExprIds(s.plan), exprId = ExprId(0)) - case s: LateralSubquery => - s.copy(plan = normalizeExprIds(s.plan), exprId = ExprId(0)) - case e: Exists => - e.copy(plan = normalizeExprIds(e.plan), exprId = ExprId(0)) - case l: ListQuery => - l.copy(plan = normalizeExprIds(l.plan), exprId = ExprId(0)) - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - case OuterReference(a: AttributeReference) => - OuterReference(AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))) - case a: Alias => - Alias(a.child, a.name)(exprId = ExprId(0)) - case OuterReference(a: Alias) => - OuterReference(Alias(a.child, a.name)(exprId = ExprId(0))) - case ae: AggregateExpression => - ae.copy(resultId = ExprId(0)) - case lv: NamedLambdaVariable => - lv.copy(exprId = ExprId(0), value = null) - case udf: PythonUDF => - udf.copy(resultId = ExprId(0)) - case udaf: PythonUDAF => - udaf.copy(resultId = ExprId(0)) - case a: FunctionTableSubqueryArgumentExpression => - a.copy(plan = normalizeExprIds(a.plan), exprId = ExprId(0)) - } - } + protected def normalizeExprIds(plan: LogicalPlan): LogicalPlan = + NormalizePlan.normalizeExprIds(plan) protected def rewriteNameFromAttrNullability(plan: LogicalPlan): LogicalPlan = { plan.transformAllExpressions { @@ -107,60 +74,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s } } - /** - * Normalizes plans: - * - Filter the filter conditions that appear in a plan. For instance, - * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) - * etc., will all now be equivalent. - * - Sample the seed will replaced by 0L. - * - Join conditions will be resorted by hashCode. - */ - protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { - plan transform { - case Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteBinaryComparison) - .sortBy(_.hashCode()).reduce(And), child) - case sample: Sample => - sample.copy(seed = 0L) - case Join(left, right, joinType, condition, hint) if condition.isDefined => - val newJoinType = joinType match { - case ExistenceJoin(a: Attribute) => - val newAttr = AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - ExistenceJoin(newAttr) - case other => other - } - - val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteBinaryComparison) - .sortBy(_.hashCode()).reduce(And) - Join(left, right, newJoinType, Some(newCondition), hint) - case Project(projectList, child) => - val projList = projectList.map { e => - e.transformUp { - case g: GetViewColumnByNameAndOrdinal => g.copy(viewDDL = None) - } - }.asInstanceOf[Seq[NamedExpression]] - Project(projList, child) - case c: KeepAnalyzedQuery => c.storeAnalyzedQuery() - } - } - - /** - * Rewrite [[BinaryComparison]] operator to keep order. The following cases will be - * equivalent: - * 1. (a = b), (b = a); - * 2. (a <=> b), (b <=> a). - * 3. 
(a > b), (b < a) - */ - private def rewriteBinaryComparison(condition: Expression): Expression = condition match { - case EqualTo(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) - case EqualNullSafe(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) - case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) - case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) - case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) - case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) - case _ => condition // Don't reorder. - } + protected def normalizePlan(plan: LogicalPlan): LogicalPlan = + NormalizePlan.normalizePlan(plan) /** Fails the test if the two plans do not match */ protected def comparePlans( From 2c904e4b682e6f67049272ed38ca94894afd9242 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Tue, 22 Oct 2024 11:46:02 +0900 Subject: [PATCH 088/108] [SPARK-49849][CONNECT][PYTHON] API compatibility check for Structured Streaming Query Management ### What changes were proposed in this pull request? This PR proposes to add an API compatibility check for Spark SQL Structured Streaming Query Management functions. ### Why are the changes needed? To guarantee the same behavior between Spark Classic and Spark Connect. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48581 from itholic/compat_streaming_query. Authored-by: Haejoon Lee Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/streaming/query.py | 2 +- python/pyspark/sql/streaming/query.py | 2 +- .../sql/tests/test_connect_compatibility.py | 38 +++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 204d16106482d..c4e9c512ec58e 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -203,7 +203,7 @@ def active(self) -> List[StreamingQuery]: active.__doc__ = PySparkStreamingQueryManager.active.__doc__ - def get(self, id: str) -> Optional[StreamingQuery]: + def get(self, id: str) -> Optional["StreamingQuery"]: cmd = pb2.StreamingQueryManagerCommand() cmd.get_query = id response = self._execute_streaming_query_manager_cmd(cmd) diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 28274e9fadc2a..d2f9f0957e0ae 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -491,7 +491,7 @@ def active(self) -> List[StreamingQuery]: """ return [StreamingQuery(jsq) for jsq in self._jsqm.active()] - def get(self, id: str) -> Optional[StreamingQuery]: + def get(self, id: str) -> Optional["StreamingQuery"]: """ Returns an active query from this :class:`SparkSession`.
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index c5e66fde018c5..cf5c610213407 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -34,6 +34,8 @@ from pyspark.sql.group import GroupedData as ClassicGroupedData import pyspark.sql.avro.functions as ClassicAvro import pyspark.sql.protobuf.functions as ClassicProtobuf +from pyspark.sql.streaming.query import StreamingQuery as ClassicStreamingQuery +from pyspark.sql.streaming.query import StreamingQueryManager as ClassicStreamingQueryManager if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame @@ -49,6 +51,10 @@ from pyspark.sql.connect.group import GroupedData as ConnectGroupedData import pyspark.sql.connect.avro.functions as ConnectAvro import pyspark.sql.connect.protobuf.functions as ConnectProtobuf + from pyspark.sql.connect.streaming.query import StreamingQuery as ConnectStreamingQuery + from pyspark.sql.connect.streaming.query import ( + StreamingQueryManager as ConnectStreamingQueryManager, + ) class ConnectCompatibilityTestsMixin: @@ -401,6 +407,22 @@ def test_avro_compatibility(self): expected_missing_classic_methods, ) + def test_streaming_query_compatibility(self): + """Test Streaming Query compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = set() + self.check_compatibility( + ClassicStreamingQuery, + ConnectStreamingQuery, + "StreamingQuery", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + def test_protobuf_compatibility(self): """Test Protobuf compatibility between classic and connect.""" expected_missing_connect_properties = set() @@ -423,6 +445,22 @@ def test_protobuf_compatibility(self): expected_missing_classic_methods, ) + def test_streaming_query_manager_compatibility(self): + """Test Streaming Query Manager compatibility between classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = {"close"} + self.check_compatibility( + ClassicStreamingQueryManager, + ConnectStreamingQueryManager, + "StreamingQueryManager", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): From e7cdb5a6714520fdd871b068b76231d8b0de75ff Mon Sep 17 00:00:00 2001 From: Neil Ramaswamy Date: Tue, 22 Oct 2024 11:07:24 +0800 Subject: [PATCH 089/108] [SPARK-49944][DOCS] Fix broken main.js import and fix image links for streaming documentation ### What changes were proposed in this pull request? We use the `rel_path_to_root` Jekyll variable in front of all paths that require it. ### Why are the changes needed? Currently, our import to `main.js` and AnchorJS are broken in the Spark 4.0.0-2 preview. Also, images aren't appearing for the Structured Streaming doc pages. See the [ASF issue](https://issues.apache.org/jira/browse/SPARK-49944) for more detail. 
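As a rough sketch of the pattern (illustrative only — `main.js` is the asset called out above, but the actual markup and attributes changed are in the `docs/_layouts/global.html` diff below):

```html
<!-- Illustrative sketch, not the literal lines from global.html. -->
<!-- Before: a relative path like this resolves against the current page's directory,
     so a nested page such as streaming/getting-started.html requests
     streaming/js/main.js, which does not exist. -->
<script src="js/main.js"></script>

<!-- After: prefixing with the rel_path_to_root Jekyll variable walks back up to the
     docs root from whatever depth the current page is rendered at. -->
<script src="{{ rel_path_to_root }}js/main.js"></script>
```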
You can see how the pages are broken [here](https://spark.apache.org/docs/4.0.0-preview2/streaming/getting-started.html); here's a screenshot, for example: [screenshot] ### Does this PR introduce _any_ user-facing change? The preview documentation will now have correctly rendered code blocks, and images will appear. ### How was this patch tested? Local testing. Please build the docs site if you would like to verify. It now looks like: [screenshot] ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48438 from neilramaswamy/nr/fix-broken-streaming-links-images. Authored-by: Neil Ramaswamy Signed-off-by: Kent Yao --- docs/_layouts/global.html | 6 +++--- docs/streaming/apis-on-dataframes-and-datasets.md | 10 +++++----- docs/streaming/getting-started.md | 7 ++++--- docs/streaming/performance-tips.md | 2 +- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index a85fd16451469..f5a20dd441b0e 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -28,7 +28,7 @@ - + @@ -198,8 +198,8 @@

    {{ page.title }}

    crossorigin="anonymous"> - - + +