diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e916887187dc8..a723e92114b32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.types._
 
 import scala.collection.mutable.ArrayBuffer
@@ -78,7 +79,9 @@ class Analyzer(
       GlobalAggregates ::
       UnresolvedHavingClauseAttributes ::
       HiveTypeCoercion.typeCoercionRules ++
-      extendedResolutionRules : _*)
+      extendedResolutionRules : _*),
+    Batch("Nondeterministic", Once,
+      PullOutNondeterministic)
   )
 
   /**
@@ -910,6 +913,34 @@ class Analyzer(
       Project(finalProjectList, withWindow)
     }
   }
+
+  /**
+   * Pulls nondeterministic expressions out of any LogicalPlan that is not a Project or a Filter,
+   * puts them into an inner Project, and finally projects them away in an outer Project.
+   */
+  object PullOutNondeterministic extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case p: Project => p
+      case f: Filter => f
+
+      // todo: it's hard to write a general rule to pull out nondeterministic expressions
+      // from an arbitrary LogicalPlan, so currently we only do it for a UnaryNode whose
+      // output schema is the same as its child's.
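+      // For example (see the test added to AnalysisSuite below), a plan such as
+      //   RepartitionByExpression(Seq(Rand(33)), r)
+      // becomes
+      //   Project(r.output,
+      //     RepartitionByExpression(Seq(_nondeterministic),
+      //       Project(r.output :+ Alias(Rand(33), "_nondeterministic")(), r)))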
+      case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
+        val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e =>
+          val ne = e match {
+            case n: NamedExpression => n
+            case _ => Alias(e, "_nondeterministic")()
+          }
+          new TreeNodeRef(e) -> ne
+        }.toMap
+        val newPlan = p.transformExpressions { case e =>
+          nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
+        }
+        val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
+        Project(p.output, newPlan.withNewChildren(newChild :: Nil))
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 81d473c1130f7..a373714832962 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -38,10 +37,10 @@ trait CheckAnalysis {
     throw new AnalysisException(msg)
   }
 
-  def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
+  protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
     exprs.flatMap(_.collect {
-      case e: Generator => true
-    }).nonEmpty
+      case e: Generator => e
+    }).length > 1
   }
 
   def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -137,13 +136,21 @@ trait CheckAnalysis {
              s"""
                 |Failure when resolving conflicting references in Join:
                 |$plan
-                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
-                 |""".stripMargin)
+                |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+                |""".stripMargin)
 
           case o if !o.resolved =>
             failAnalysis(
               s"unresolved operator ${operator.simpleString}")
 
+          case o if o.expressions.exists(!_.deterministic) &&
+            !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] =>
+            failAnalysis(
+              s"""nondeterministic expressions are only allowed in Project or Filter, found:
+                 | ${o.expressions.map(_.prettyString).mkString(",")}
+                 |in operator ${operator.simpleString}
+               """.stripMargin)
+
           case _ => // Analysis successful!
         }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 3f72e6e184db1..cb4c3f24b2721 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -196,7 +196,26 @@ trait Unevaluable extends Expression {
 
 /**
  * An expression that is nondeterministic.
 */
 trait Nondeterministic extends Expression {
-  override def deterministic: Boolean = false
+  final override def deterministic: Boolean = false
+  final override def foldable: Boolean = false
+
+  private[this] var initialized = false
+
+  final def initialize(): Unit = {
+    if (!initialized) {
+      initInternal()
+      initialized = true
+    }
+  }
+
+  protected def initInternal(): Unit
+
+  final override def eval(input: InternalRow = null): Any = {
+    require(initialized, "nondeterministic expressions should be initialized before evaluation")
+    evalInternal(input)
+  }
+
+  protected def evalInternal(input: InternalRow): Any
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index fb873e7e99547..c1ed9cf7ed6a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -31,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
+  expressions.foreach(_.foreach {
+    case n: Nondeterministic => n.initialize()
+    case _ =>
+  })
+
   // null check is required for when Kryo invokes the no-arg constructor.
   protected val exprArray = if (expressions != null) expressions.toArray else null
 
@@ -57,6 +62,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
+  expressions.foreach(_.foreach {
+    case n: Nondeterministic => n.initialize()
+    case _ =>
+  })
+
   private[this] val exprArray = expressions.toArray
   private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length)
   def currentValue: InternalRow = mutableRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3f1bd2a925fe7..5bfe1cad24a3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -30,6 +30,10 @@ object InterpretedPredicate {
     create(BindReferences.bindReference(expression, inputSchema))
 
   def create(expression: Expression): (InternalRow => Boolean) = {
+    expression.foreach {
+      case n: Nondeterministic => n.initialize()
+      case _ =>
+    }
     (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index aef24a5486466..8f30519697a37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic {
 
   /**
    * Record ID within each partition. By being transient, the Random Number Generator is
-   * reset every time we serialize and deserialize it.
+   * reset every time we serialize, deserialize, or initialize it.
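+   * initInternal() then re-creates it with a partition-specific seed (seed + partition id),
+   * so each partition produces a different random sequence.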
   */
-  @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+  @transient protected var rng: XORShiftRandom = _
+
+  override protected def initInternal(): Unit = {
+    rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+  }
 
   override def nullable: Boolean = false
 
@@ -49,7 +53,7 @@ abstract class RDG extends LeafExpression with Nondeterministic {
 
 /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
 case class Rand(seed: Long) extends RDG {
-  override def eval(input: InternalRow): Double = rng.nextDouble()
+  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
   def this() = this(Utils.random.nextLong())
 
@@ -72,7 +76,7 @@ case class Rand(seed: Long) extends RDG {
 
 /** Generate a random column with i.i.d. gaussian random distribution. */
 case class Randn(seed: Long) extends RDG {
-  override def eval(input: InternalRow): Double = rng.nextGaussian()
+  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
   def this() = this(Utils.random.nextLong())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 57a12820fa4c6..8e1a236e2988c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
@@ -379,7 +378,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
   override lazy val statistics: Statistics = {
-    val limit = limitExpr.eval(null).asInstanceOf[Int]
+    val limit = limitExpr.eval().asInstanceOf[Int]
     val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
     Statistics(sizeInBytes = sizeInBytes)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 7e67427237a65..ed645b618dc9b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,10 +17,6 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
@@ -28,6 +24,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 
+// todo: remove this and use AnalysisTest instead.
 object AnalysisSuite {
   val caseSensitiveConf = new SimpleCatalystConf(true)
   val caseInsensitiveConf = new SimpleCatalystConf(false)
@@ -55,7 +52,7 @@ object AnalysisSuite {
     AttributeReference("a", StringType)(),
     AttributeReference("b", StringType)(),
     AttributeReference("c", DoubleType)(),
-    AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(),
+    AttributeReference("d", DecimalType(10, 2))(),
     AttributeReference("e", ShortType)())
 
   val nestedRelation = LocalRelation(
@@ -81,8 +78,7 @@ object AnalysisSuite {
 }
 
-class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
-  import AnalysisSuite._
+class AnalysisSuite extends AnalysisTest {
 
   test("union project *") {
     val plan = (1 to 100)
@@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
       a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
     }
 
-    assert(caseInsensitiveAnalyzer.execute(plan).resolved)
+    assertAnalysisSuccess(plan)
   }
 
   test("check project's resolved") {
@@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
   }
 
   test("analyze project") {
-    assert(
-      caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
-        Project(testRelation.output, testRelation))
-
-    assert(
-      caseSensitiveAnalyzer.execute(
-        Project(Seq(UnresolvedAttribute("TbL.a")),
-          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
-        Project(testRelation.output, testRelation))
-
-    val e = intercept[AnalysisException] {
-      caseSensitiveAnalyze(
-        Project(Seq(UnresolvedAttribute("tBl.a")),
-          UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
-    }
-    assert(e.getMessage().toLowerCase.contains("cannot resolve"))
-
-    assert(
-      caseInsensitiveAnalyzer.execute(
-        Project(Seq(UnresolvedAttribute("TbL.a")),
-          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
-        Project(testRelation.output, testRelation))
-
-    assert(
-      caseInsensitiveAnalyzer.execute(
-        Project(Seq(UnresolvedAttribute("tBl.a")),
-          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
-        Project(testRelation.output, testRelation))
+    checkAnalysis(
+      Project(Seq(UnresolvedAttribute("a")), testRelation),
+      Project(testRelation.output, testRelation))
+
+    checkAnalysis(
+      Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+      Project(testRelation.output, testRelation))
+
+    assertAnalysisError(
+      Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+      Seq("cannot resolve"))
+
+    checkAnalysis(
+      Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+      Project(testRelation.output, testRelation),
+      caseSensitive = false)
+
+    checkAnalysis(
+      Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+      Project(testRelation.output, testRelation),
+      caseSensitive = false)
   }
 
   test("resolve relations") {
-    val e = intercept[RuntimeException] {
-      caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
-    }
-    assert(e.getMessage == "Table Not Found: tAbLe")
+    assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe"))
 
-    assert(
-      caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
+    checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation)
 
-    assert(
-      caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
+    checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false)
 
-    assert(
-      caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
+    checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false)
   }
 
-  test("divide should be casted into fractional types") {
-    val testRelation2 = LocalRelation(
-      AttributeReference("a", StringType)(),
-      AttributeReference("b", StringType)(),
-      AttributeReference("c", DoubleType)(),
-      AttributeReference("d", DecimalType(10, 2))(),
-      AttributeReference("e", ShortType)())
-
+  test("divide should be casted into fractional types") {
     val plan = caseInsensitiveAnalyzer.execute(
       testRelation2.select(
         'a / Literal(2) as 'div1,
@@ -170,10 +145,21 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
         'e / 'e as 'div5))
     val pl = plan.asInstanceOf[Project].projectList
+    // StringType will be promoted into Double
     assert(pl(0).dataType == DoubleType)
     assert(pl(1).dataType == DoubleType)
     assert(pl(2).dataType == DoubleType)
-    assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double
+    assert(pl(3).dataType == DoubleType)
     assert(pl(4).dataType == DoubleType)
   }
+
+  test("pull out nondeterministic expressions from unary LogicalPlan") {
+    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
+    val projected = Alias(Rand(33), "_nondeterministic")()
+    val expected =
+      Project(testRelation.output,
+        RepartitionByExpression(Seq(projected.toAttribute),
+          Project(testRelation.output :+ projected, testRelation)))
+    checkAnalysis(plan, expected)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
new file mode 100644
index 0000000000000..fdb4f28950daf
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.types._ + +trait AnalysisTest extends PlanTest { + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(10, 2))(), + AttributeReference("e", ShortType)()) + + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + + val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { + val caseSensitiveConf = new SimpleCatalystConf(true) + val caseInsensitiveConf = new SimpleCatalystConf(false) + + val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) + val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) + + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } -> + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + } + + protected def getAnalyzer(caseSensitive: Boolean) = { + if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer + } + + protected def checkAnalysis( + inputPlan: LogicalPlan, + expectedPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + val actualPlan = analyzer.execute(inputPlan) + analyzer.checkAnalysis(actualPlan) + comparePlans(actualPlan, expectedPlan) + } + + protected def assertAnalysisSuccess( + inputPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + + protected def assertAnalysisError( + inputPlan: LogicalPlan, + expectedErrors: Seq[String], + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + // todo: make sure we throw AnalysisException during analysis + val e = intercept[Exception] { + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + expectedErrors.forall(e.getMessage.contains) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 4930219aa63cb..852a8b235f127 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ 
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -64,6 +64,10 @@ trait ExpressionEvalHelper {
   }
 
   protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
+    expression.foreach {
+      case n: Nondeterministic => n.initialize()
+      case _ =>
+    }
     expression.eval(inputRow)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
index 2645eb1854bce..eca36b3274420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
@@ -37,17 +37,22 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with
 
   /**
    * Record ID within each partition. By being transient, count's value is reset to 0 every time
-   * we serialize and deserialize it.
+   * we serialize, deserialize, or initialize it.
    */
-  @transient private[this] var count: Long = 0L
+  @transient private[this] var count: Long = _
 
-  @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33
+  @transient private[this] var partitionMask: Long = _
+
+  override protected def initInternal(): Unit = {
+    count = 0L
+    partitionMask = TaskContext.getPartitionId().toLong << 33
+  }
 
   override def nullable: Boolean = false
 
   override def dataType: DataType = LongType
 
-  override def eval(input: InternalRow): Long = {
+  override protected def evalInternal(input: InternalRow): Long = {
     val currentCount = count
     count += 1
     partitionMask + currentCount
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
index 53ddd47e3e0c1..61ef079d89af5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -33,9 +33,13 @@ private[sql] case object SparkPartitionID extends LeafExpression with Nondetermi
 
   override def dataType: DataType = IntegerType
 
-  @transient private lazy val partitionId = TaskContext.getPartitionId()
+  @transient private[this] var partitionId: Int = _
 
-  override def eval(input: InternalRow): Int = partitionId
+  override protected def initInternal(): Unit = {
+    partitionId = TaskContext.getPartitionId()
+  }
+
+  override protected def evalInternal(input: InternalRow): Int = partitionId
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val idTerm = ctx.freshName("partitionId")
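
As a usage sketch of the new contract (not part of the patch; `CountingId` is a hypothetical expression modeled on MonotonicallyIncreasingID above): subclasses implement initInternal()/evalInternal() instead of overriding eval(), and callers such as InterpretedProjection invoke initialize() once per task before the first evaluation.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Nondeterministic}
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical example, not part of the patch: a per-task counter expression.
case class CountingId() extends LeafExpression with Nondeterministic {

  // Transient so the counter does not survive serialization to executors;
  // initInternal() gives it a fresh starting value once the task begins.
  @transient private[this] var count: Long = _

  // Runs at most once per task, via the final initialize() in Nondeterministic.
  override protected def initInternal(): Unit = {
    count = 0L
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  // Invoked by the trait's final eval(), which first checks that
  // initialize() has been called.
  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    currentCount
  }
}

Calling eval() on such an expression without initialize() trips the require(...) check added to Expression.scala, surfacing misuse early instead of silently reusing driver-side state.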