[SPARK-9192][SQL] add initialization phase for nondeterministic expression

Currently, nondeterministic expressions are broken without an explicit initialization phase.

Let me take `MonotonicallyIncreasingID` as an example. This expression needs mutable state to remember how many times it has been evaluated, so we use `transient var count: Long` there. Being transient, `count` is reset to 0, and **only** to 0, whenever we serialize and deserialize the expression, because a deserialized transient variable takes its type's default value. There is *no way* to give `count` a different initial value until we add an explicit initialization phase.
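
To make the failure mode concrete, here is a minimal, self-contained sketch (the `Counter` class and `TransientResetDemo` names are illustrative, not part of this patch) showing that a `@transient` field always comes back as its type's default after a Java-serialization round trip:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// A transient field is skipped during serialization, so deserialization
// restores it to the type's default (0L here); there is no hook to
// choose a different starting value.
class Counter extends Serializable {
  @transient var count: Long = 0L
  def next(): Long = { count += 1; count }
}

object TransientResetDemo extends App {
  val c = new Counter
  c.next(); c.next(); c.next() // count is now 3

  val buf = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(buf)
  oos.writeObject(c)
  oos.close()

  val ois = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
  val c2 = ois.readObject().asInstanceOf[Counter]
  println(c2.count) // 0: reset to the default, never to anything else
}
```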

Another use case is local execution of a `LocalRelation`: there is no serialize/deserialize phase at all there, so we cannot reset mutable state that way.
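
The fix, reduced to a hedged sketch outside of Catalyst (the real `Nondeterministic` trait in `org.apache.spark.sql.catalyst.expressions` also threads an `InternalRow` through `eval`; `SimpleNondeterministic`, `MonotonicID`, and `InitDemo` below are illustrative names): evaluation is gated on an explicit `initialize()` call, so mutable state gets a well-defined starting value both after deserialization and for purely local execution.

```scala
// Gate evaluation on an explicit initialization phase. The real trait
// additionally passes an InternalRow through eval/evalInternal.
trait SimpleNondeterministic {
  private[this] var initialized = false

  // Idempotent: safe to call once per partition / per projection.
  final def initialize(): Unit = {
    if (!initialized) {
      initInternal()
      initialized = true
    }
  }

  protected def initInternal(): Unit

  final def eval(): Any = {
    require(initialized, "nondeterministic expression should be initialized before evaluate")
    evalInternal()
  }

  protected def evalInternal(): Any
}

// A counter-style expression now resets its state in initInternal(),
// which also covers local execution where no serialization happens.
class MonotonicID extends SimpleNondeterministic {
  private[this] var count: Long = _
  override protected def initInternal(): Unit = { count = 0L }
  override protected def evalInternal(): Any = { count += 1; count }
}

object InitDemo extends App {
  val id = new MonotonicID
  id.initialize() // explicit phase, instead of relying on deserialization
  id.eval(); id.eval()
  println(id.eval()) // 3
}
```

The `InterpretedProjection`, `InterpretedMutableProjection`, and `InterpretedPredicate` hunks below follow this contract by calling `initialize()` on every nondeterministic subexpression before the first evaluation.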

Author: Wenchen Fan <[email protected]>

Closes #7535 from cloud-fan/init and squashes the following commits:

6c6f332 [Wenchen Fan] add test
ef68ff4 [Wenchen Fan] fix comments
9eac85e [Wenchen Fan] move init code to interpreted class
bb7d838 [Wenchen Fan] pulls out nondeterministic expressions into a project
b4a4fc7 [Wenchen Fan] revert a refactor
86fee36 [Wenchen Fan] add initialization phase for nondeterministic expression
cloud-fan authored and rxin committed Jul 25, 2015
1 parent e2ec018 commit 2c94d0f
Showing 12 changed files with 254 additions and 76 deletions.
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.types._
import scala.collection.mutable.ArrayBuffer

@@ -78,7 +79,9 @@ class Analyzer(
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*)
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
PullOutNondeterministic)
)

/**
@@ -910,6 +913,34 @@ class Analyzer(
Project(finalProjectList, withWindow)
}
}

/**
* Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter,
* put them into an inner Project and finally project them away at the outer Project.
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Project => p
case f: Filter => f

// todo: It's hard to write a general rule to pull out nondeterministic expressions
// from LogicalPlan, currently we only do it for UnaryNode which has same output
// schema with its child.
case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e =>
val ne = e match {
case n: NamedExpression => n
case _ => Alias(e, "_nondeterministic")()
}
new TreeNodeRef(e) -> ne
}.toMap
val newPlan = p.transformExpressions { case e =>
nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
}
val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}
}
}

/**
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

@@ -38,10 +37,10 @@ trait CheckAnalysis {
throw new AnalysisException(msg)
}

def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
exprs.flatMap(_.collect {
case e: Generator => true
}).nonEmpty
case e: Generator => e
}).length > 1
}

def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -137,13 +136,21 @@
s"""
|Failure when resolving conflicting references in Join:
|$plan
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)

case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")

case o if o.expressions.exists(!_.deterministic) &&
!o.isInstanceOf[Project] && !o.isInstanceOf[Filter] =>
failAnalysis(
s"""nondeterministic expressions are only allowed in Project or Filter, found:
| ${o.expressions.map(_.prettyString).mkString(",")}
|in operator ${operator.simpleString}
""".stripMargin)

case _ => // Analysis successful!
}
}
@@ -196,7 +196,26 @@ trait Unevaluable extends Expression {
* An expression that is nondeterministic.
*/
trait Nondeterministic extends Expression {
override def deterministic: Boolean = false
final override def deterministic: Boolean = false
final override def foldable: Boolean = false

private[this] var initialized = false

final def initialize(): Unit = {
if (!initialized) {
initInternal()
initialized = true
}
}

protected def initInternal(): Unit

final override def eval(input: InternalRow = null): Any = {
require(initialized, "nondeterministic expression should be initialized before evaluate")
evalInternal(input)
}

protected def evalInternal(input: InternalRow): Any
}


@@ -31,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

expressions.foreach(_.foreach {
case n: Nondeterministic => n.initialize()
case _ =>
})

// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) expressions.toArray else null

@@ -57,6 +62,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

expressions.foreach(_.foreach {
case n: Nondeterministic => n.initialize()
case _ =>
})

private[this] val exprArray = expressions.toArray
private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length)
def currentValue: InternalRow = mutableRow
@@ -30,6 +30,10 @@ object InterpretedPredicate {
create(BindReferences.bindReference(expression, inputSchema))

def create(expression: Expression): (InternalRow => Boolean) = {
expression.foreach {
case n: Nondeterministic => n.initialize()
case _ =>
}
(r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
}
}
@@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic {

/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize it.
* reset every time we serialize and deserialize and initialize it.
*/
@transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
@transient protected var rng: XORShiftRandom = _

override protected def initInternal(): Unit = {
rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
}

override def nullable: Boolean = false

@@ -49,7 +53,7 @@

/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
case class Rand(seed: Long) extends RDG {
override def eval(input: InternalRow): Double = rng.nextDouble()
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()

def this() = this(Utils.random.nextLong())

@@ -72,7 +76,7 @@

/** Generate a random column with i.i.d. gaussian random distribution. */
case class Randn(seed: Long) extends RDG {
override def eval(input: InternalRow): Double = rng.nextGaussian()
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

def this() = this(Utils.random.nextLong())

@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -379,7 +378,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

override lazy val statistics: Statistics = {
val limit = limitExpr.eval(null).asInstanceOf[Int]
val limit = limitExpr.eval().asInstanceOf[Int]
val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
Statistics(sizeInBytes = sizeInBytes)
}
@@ -17,17 +17,14 @@

package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

// todo: remove this and use AnalysisTest instead.
object AnalysisSuite {
val caseSensitiveConf = new SimpleCatalystConf(true)
val caseInsensitiveConf = new SimpleCatalystConf(false)
@@ -55,7 +52,7 @@ object AnalysisSuite {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())

val nestedRelation = LocalRelation(
@@ -81,8 +78,7 @@ object AnalysisSuite {
}


class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
import AnalysisSuite._
class AnalysisSuite extends AnalysisTest {

test("union project *") {
val plan = (1 to 100)
Expand All @@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
}

assert(caseInsensitiveAnalyzer.execute(plan).resolved)
assertAnalysisSuccess(plan)
}

test("check project's resolved") {
Expand All @@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
}

test("analyze project") {
assert(
caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
Project(testRelation.output, testRelation))

assert(
caseSensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))

val e = intercept[AnalysisException] {
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
}
assert(e.getMessage().toLowerCase.contains("cannot resolve"))

assert(
caseInsensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))

assert(
caseInsensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
checkAnalysis(
Project(Seq(UnresolvedAttribute("a")), testRelation),
Project(testRelation.output, testRelation))

checkAnalysis(
Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
Project(testRelation.output, testRelation))

assertAnalysisError(
Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
Seq("cannot resolve"))

checkAnalysis(
Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
Project(testRelation.output, testRelation),
caseSensitive = false)

checkAnalysis(
Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
Project(testRelation.output, testRelation),
caseSensitive = false)
}

test("resolve relations") {
val e = intercept[RuntimeException] {
caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
}
assert(e.getMessage == "Table Not Found: tAbLe")
assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe"))

assert(
caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation)

assert(
caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false)

assert(
caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false)
}


test("divide should be casted into fractional types") {
val testRelation2 = LocalRelation(
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())

val plan = caseInsensitiveAnalyzer.execute(
testRelation2.select(
'a / Literal(2) as 'div1,
Expand All @@ -170,10 +145,21 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
'e / 'e as 'div5))
val pl = plan.asInstanceOf[Project].projectList

// StringType will be promoted into Double
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double
assert(pl(3).dataType == DoubleType)
assert(pl(4).dataType == DoubleType)
}

test("pull out nondeterministic expressions from unary LogicalPlan") {
val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
val projected = Alias(Rand(33), "_nondeterministic")()
val expected =
Project(testRelation.output,
RepartitionByExpression(Seq(projected.toAttribute),
Project(testRelation.output :+ projected, testRelation)))
checkAnalysis(plan, expected)
}
}