Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9192][SQL] add initialization phase for nondeterministic expression #7535

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.types._
import scala.collection.mutable.ArrayBuffer

Expand Down Expand Up @@ -78,7 +79,9 @@ class Analyzer(
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*)
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
PullOutNondeterministic)
)

/**
Expand Down Expand Up @@ -910,6 +913,34 @@ class Analyzer(
Project(finalProjectList, withWindow)
}
}

/**
* Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter,
* put them into an inner Project and finally project them away at the outer Project.
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Project => p
case f: Filter => f

// todo: It's hard to write a general rule to pull out nondeterministic expressions
// from LogicalPlan, currently we only do it for UnaryNode which has same output
// schema with its child.
case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e =>
val ne = e match {
case n: NamedExpression => n
case _ => Alias(e, "_nondeterministic")()
}
new TreeNodeRef(e) -> ne
}.toMap
val newPlan = p.transformExpressions { case e =>
nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
}
val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

Expand All @@ -38,10 +37,10 @@ trait CheckAnalysis {
throw new AnalysisException(msg)
}

def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
exprs.flatMap(_.collect {
case e: Generator => true
}).nonEmpty
case e: Generator => e
}).length > 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a un-related but small fix: check multiple should use length > 1 instead of nonEmpty

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops

}

def checkAnalysis(plan: LogicalPlan): Unit = {
Expand Down Expand Up @@ -137,13 +136,21 @@ trait CheckAnalysis {
s"""
|Failure when resolving conflicting references in Join:
|$plan
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)

case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")

case o if o.expressions.exists(!_.deterministic) &&
!o.isInstanceOf[Project] && !o.isInstanceOf[Filter] =>
failAnalysis(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you make it seem like a user error here, but if we get here, we have a bug in the system.

maybe we should just add throw illegalstateexception

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to do it in another PR, as here in not the only place that need to update in CheckAnalysis.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @cloud-fan .Can it support for join operator? Sometimes we can use some nondeterministic(i.e. RAND) expression to eval some pointless join keys(with respect to business logic) avoiding data skew.
For example

SELECT src.key, src.value, src1.value 
FROM src 
JOIN src1
ON UPPER((CASE WHEN (src.key IS NULL OR src.key = '' ) THEN CAST( (-RAND() * 10000000 ) AS string ) ELSE src.key END )) = UPPER(src1.key)

What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have Repartition operator to do this job, maybe you can try that instead of doing it manually?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, we need to change this sql manually after this behavior.

s"""nondeterministic expressions are only allowed in Project or Filter, found:
| ${o.expressions.map(_.prettyString).mkString(",")}
|in operator ${operator.simpleString}
""".stripMargin)

case _ => // Analysis successful!
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,26 @@ trait Unevaluable extends Expression {
* An expression that is nondeterministic.
*/
trait Nondeterministic extends Expression {
override def deterministic: Boolean = false
final override def deterministic: Boolean = false
final override def foldable: Boolean = false

private[this] var initialized = false

final def initialize(): Unit = {
if (!initialized) {
initInternal()
initialized = true
}
}

protected def initInternal(): Unit

final override def eval(input: InternalRow = null): Any = {
require(initialized, "nondeterministic expression should be initialized before evaluate")
evalInternal(input)
}

protected def evalInternal(input: InternalRow): Any
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

expressions.foreach(_.foreach {
case n: Nondeterministic => n.initialize()
case _ =>
})

// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) expressions.toArray else null

Expand All @@ -57,6 +62,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

expressions.foreach(_.foreach {
case n: Nondeterministic => n.initialize()
case _ =>
})

private[this] val exprArray = expressions.toArray
private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length)
def currentValue: InternalRow = mutableRow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ object InterpretedPredicate {
create(BindReferences.bindReference(expression, inputSchema))

def create(expression: Expression): (InternalRow => Boolean) = {
expression.foreach {
case n: Nondeterministic => n.initialize()
case _ =>
}
(r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic {

/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize it.
* reset every time we serialize and deserialize and initialize it.
*/
@transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
@transient protected var rng: XORShiftRandom = _

override protected def initInternal(): Unit = {
rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
}

override def nullable: Boolean = false

Expand All @@ -49,7 +53,7 @@ abstract class RDG extends LeafExpression with Nondeterministic {

/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
case class Rand(seed: Long) extends RDG {
override def eval(input: InternalRow): Double = rng.nextDouble()
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()

def this() = this(Utils.random.nextLong())

Expand All @@ -72,7 +76,7 @@ case class Rand(seed: Long) extends RDG {

/** Generate a random column with i.i.d. gaussian random distribution. */
case class Randn(seed: Long) extends RDG {
override def eval(input: InternalRow): Double = rng.nextGaussian()
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

def this() = this(Utils.random.nextLong())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
Expand Down Expand Up @@ -379,7 +378,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

override lazy val statistics: Statistics = {
val limit = limitExpr.eval(null).asInstanceOf[Int]
val limit = limitExpr.eval().asInstanceOf[Int]
val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
Statistics(sizeInBytes = sizeInBytes)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@

package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

// todo: remove this and use AnalysisTest instead.
object AnalysisSuite {
val caseSensitiveConf = new SimpleCatalystConf(true)
val caseInsensitiveConf = new SimpleCatalystConf(false)
Expand Down Expand Up @@ -55,7 +52,7 @@ object AnalysisSuite {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())

val nestedRelation = LocalRelation(
Expand All @@ -81,8 +78,7 @@ object AnalysisSuite {
}


class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
import AnalysisSuite._
class AnalysisSuite extends AnalysisTest {

test("union project *") {
val plan = (1 to 100)
Expand All @@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
}

assert(caseInsensitiveAnalyzer.execute(plan).resolved)
assertAnalysisSuccess(plan)
}

test("check project's resolved") {
Expand All @@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
}

test("analyze project") {
assert(
caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
Project(testRelation.output, testRelation))

assert(
caseSensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))

val e = intercept[AnalysisException] {
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
}
assert(e.getMessage().toLowerCase.contains("cannot resolve"))

assert(
caseInsensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))

assert(
caseInsensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
checkAnalysis(
Project(Seq(UnresolvedAttribute("a")), testRelation),
Project(testRelation.output, testRelation))

checkAnalysis(
Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
Project(testRelation.output, testRelation))

assertAnalysisError(
Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
Seq("cannot resolve"))

checkAnalysis(
Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
Project(testRelation.output, testRelation),
caseSensitive = false)

checkAnalysis(
Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
Project(testRelation.output, testRelation),
caseSensitive = false)
}

test("resolve relations") {
val e = intercept[RuntimeException] {
caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
}
assert(e.getMessage == "Table Not Found: tAbLe")
assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe"))

assert(
caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation)

assert(
caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false)

assert(
caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false)
}


test("divide should be casted into fractional types") {
val testRelation2 = LocalRelation(
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())

val plan = caseInsensitiveAnalyzer.execute(
testRelation2.select(
'a / Literal(2) as 'div1,
Expand All @@ -170,10 +145,21 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
'e / 'e as 'div5))
val pl = plan.asInstanceOf[Project].projectList

// StringType will be promoted into Double
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double
assert(pl(3).dataType == DoubleType)
assert(pl(4).dataType == DoubleType)
}

test("pull out nondeterministic expressions from unary LogicalPlan") {
val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
val projected = Alias(Rand(33), "_nondeterministic")()
val expected =
Project(testRelation.output,
RepartitionByExpression(Seq(projected.toAttribute),
Project(testRelation.output :+ projected, testRelation)))
checkAnalysis(plan, expected)
}
}
Loading