Skip to content

Commit

Permalink
[SPARK-19309][SQL] disable common subexpression elimination for condi…
Browse files Browse the repository at this point in the history
…tional expressions

## What changes were proposed in this pull request?

As I pointed out in #15807 (comment), the current subexpression elimination framework has a problem: it always evaluates all common subexpressions at the beginning, even if they are inside conditional expressions and may never be accessed.

Ideally we should implement it like a Scala lazy val, so that a subexpression is only evaluated when it is accessed at least once. #15837 tries this approach, but it seems too complicated and may introduce a performance regression.

This PR simply stops common subexpression elimination for conditional expressions, with some cleanup.

## How was this patch tested?

regression test

Author: Wenchen Fan <[email protected]>

Closes #16659 from cloud-fan/codegen.
  • Loading branch information
cloud-fan committed Jan 23, 2017
1 parent 772035e commit de6ad3d
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,34 @@ class EquivalentExpressions {
/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
* If ignoreLeaf is true, leaf nodes are ignored.
*/
def addExprTree(
root: Expression,
ignoreLeaf: Boolean = true,
skipReferenceToExpressions: Boolean = true): Unit = {
val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
def addExprTree(expr: Expression): Unit = {
val skip = expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
root.find(_.isInstanceOf[LambdaVariable]).isDefined
// There are some special expressions that we should not recurse into children.
expr.find(_.isInstanceOf[LambdaVariable]).isDefined

// There are some special expressions where we should not recurse into all of their children.
// 1. CodegenFallback: its children will not be used to generate code (call eval() instead)
// 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
val shouldRecurse = root match {
// TODO: some expressions implements `CodegenFallback` but can still do codegen,
// e.g. `CaseWhen`, we should support them.
case _: CodegenFallback => false
case _: ReferenceToExpressions if skipReferenceToExpressions => false
case _ => true
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
// false expressions in `If` may not get accessed, according to the predicate
// expression. We should only recurse into the predicate expression.
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// condition. We should only recurse into the first condition expression as it
// will always get accessed.
// 4. Coalesce: it's also a conditional expression; we should only recurse into the first
// child, because the others may not get accessed.
def childrenToRecurse: Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case i: If => i.predicate :: Nil
// `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
case c: CaseWhenCodegen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
case other => other.children
}
if (!skip && !addExpr(root) && shouldRecurse) {
root.children.foreach(addExprTree(_, ignoreLeaf))

if (!skip && !addExpr(expr)) {
childrenToRecurse.foreach(addExprTree)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ object UnsafeProjection {
* Returns an UnsafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): UnsafeProjection = {
create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -726,18 +726,18 @@ class CodegenContext {
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]

// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_, true, false))
expressions.foreach(equivalentExpressions.addExprTree)

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
val codes = commonExprs.map { e =>
val expr = e.head
// Generate the code for this expression tree.
val code = expr.genCode(this)
val state = SubExprEliminationState(code.isNull, code.value)
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(subExprEliminationExprs.put(_, state))
code.code.trim
eval.code.trim
}
SubExprCodes(codes, subExprEliminationExprs.toMap)
}
Expand All @@ -747,7 +747,7 @@ class CodegenContext {
* common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
private def subexpressionElimination(expressions: Seq[Expression]) = {
private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))

Expand All @@ -761,13 +761,13 @@ class CodegenContext {
val value = s"${fnName}Value"

// Generate the code for this expression tree and wrap it in a function.
val code = expr.genCode(this)
val eval = expr.genCode(this)
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${code.code.trim}
| $isNull = ${code.isNull};
| $value = ${code.value};
| ${eval.code.trim}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
""".stripMargin

Expand All @@ -780,9 +780,6 @@ class CodegenContext {
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
// inline or not.
// 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
// very often. The reason it is not loaded is because of a prior branch.
// 3. Extra store into isLoaded.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ThreadUtils
Expand Down Expand Up @@ -313,4 +313,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-17160: field names are properly escaped by AssertTrue") {
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
}

test("should not apply common subexpression elimination on conditional expressions") {
  val input = BoundReference(0, IntegerType, true)
  // The same `AssertNotNull` instance appears twice in the false branch, which makes it a
  // candidate for common subexpression elimination.
  val notNullInput = AssertNotNull(input, Nil)
  val conditional = If(IsNull(input), Literal(1), Add(notNullInput, notNullInput))
  val unsafeProjection = GenerateUnsafeProjection.generate(
    conditional :: Nil, subexpressionEliminationEnabled = true)
  // The single column is null, so the predicate selects the true branch; applying the
  // projection should not throw an exception.
  unsafeProjection(InternalRow(null))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val add2 = Add(add, add)

var equivalence = new EquivalentExpressions
equivalence.addExprTree(add, true)
equivalence.addExprTree(abs, true)
equivalence.addExprTree(add2, true)
equivalence.addExprTree(add)
equivalence.addExprTree(abs)
equivalence.addExprTree(add2)

// Should only have one equivalence for `one + two`
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1)
Expand All @@ -115,41 +115,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val mul2 = Multiply(mul, mul)
val sqrt = Sqrt(mul2)
val sum = Add(mul2, sqrt)
equivalence.addExprTree(mul, true)
equivalence.addExprTree(mul2, true)
equivalence.addExprTree(sqrt, true)
equivalence.addExprTree(sum, true)
equivalence.addExprTree(mul)
equivalence.addExprTree(mul2)
equivalence.addExprTree(sqrt)
equivalence.addExprTree(sum)

// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3)
assert(equivalence.getEquivalentExprs(mul).size == 3)
assert(equivalence.getEquivalentExprs(mul2).size == 3)
assert(equivalence.getEquivalentExprs(sqrt).size == 2)
assert(equivalence.getEquivalentExprs(sum).size == 1)

// Some expressions inspired by TPCH-Q1
// sum(l_quantity) as sum_qty,
// sum(l_extendedprice) as sum_base_price,
// sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
// sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
// avg(l_extendedprice) as avg_price,
// avg(l_discount) as avg_disc
equivalence = new EquivalentExpressions
val quantity = Literal(1)
val price = Literal(1.1)
val discount = Literal(.24)
val tax = Literal(0.1)
equivalence.addExprTree(quantity, false)
equivalence.addExprTree(price, false)
equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false)
equivalence.addExprTree(
Multiply(
Multiply(price, Subtract(Literal(1), discount)),
Add(Literal(1), tax)), false)
equivalence.addExprTree(price, false)
equivalence.addExprTree(discount, false)
// quantity, price, discount and (price * (1 - discount))
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4)
}

test("Expression equivalence - non deterministic") {
Expand All @@ -167,11 +143,24 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
val add = Add(two, fallback)

val equivalence = new EquivalentExpressions
equivalence.addExprTree(add, true)
equivalence.addExprTree(add)
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}

test("Children of conditional expressions") {
  val predicate = And(Literal(true), Literal(false))
  val branch = Add(Literal(1), Literal(2))

  val equivalence = new EquivalentExpressions
  equivalence.addExprTree(If(predicate, branch, branch))

  // the `add` used for both branches of `If` should not be added
  val duplicatedGroups = equivalence.getAllEquivalentExprs.count(_.size > 1)
  assert(duplicatedGroups == 0)
  // only the `If` expression and its predicate expression
  val singletonGroups = equivalence.getAllEquivalentExprs.count(_.size == 1)
  assert(singletonGroups == 2)
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,15 @@ case class SimpleTypedAggregateExpression(
override lazy val aggBufferAttributes: Seq[AttributeReference] =
bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])

/**
 * Rewrites every expression in `bufferSerializer`, substituting `expr` for each
 * `BoundReference` found in it, and returns the rewritten expressions.
 */
private def serializeToBuffer(expr: Expression): Seq[Expression] = {
  for (serializer <- bufferSerializer) yield {
    serializer.transform { case _: BoundReference => expr }
  }
}

override lazy val initialValues: Seq[Expression] = {
val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
serializeToBuffer(zero)
}

override lazy val updateExpressions: Seq[Expression] = {
Expand All @@ -154,8 +160,7 @@ case class SimpleTypedAggregateExpression(
"reduce",
bufferExternalType,
bufferDeserializer :: inputDeserializer.get :: Nil)

bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
serializeToBuffer(reduced)
}

override lazy val mergeExpressions: Seq[Expression] = {
Expand All @@ -170,8 +175,7 @@ case class SimpleTypedAggregateExpression(
"merge",
bufferExternalType,
leftBuffer :: rightBuffer :: Nil)

bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
serializeToBuffer(merged)
}

override lazy val evaluateExpression: Expression = {
Expand All @@ -181,19 +185,17 @@ case class SimpleTypedAggregateExpression(
outputExternalType,
bufferDeserializer :: Nil)

val outputSerializeExprs = outputSerializer.map(_.transform {
case _: BoundReference => resultObj
})

dataType match {
case s: StructType =>
case _: StructType =>
val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
val struct = If(
IsNull(objRef),
Literal.create(null, dataType),
CreateStruct(outputSerializer))
ReferenceToExpressions(struct, resultObj :: Nil)
If(IsNull(objRef), Literal.create(null, dataType), CreateStruct(outputSerializeExprs))
case _ =>
assert(outputSerializer.length == 1)
outputSerializer.head transform {
case b: BoundReference => resultObj
}
assert(outputSerializeExprs.length == 1)
outputSerializeExprs.head
}
}

Expand Down

0 comments on commit de6ad3d

Please sign in to comment.