Skip to content

Commit

Permalink
[SPARK-2210] boolean cast on boolean value should be removed.
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin committed Jun 20, 2014
1 parent 278ec8a commit c4e543d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ trait HiveTypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

// Skip if the type is boolean type already. Note that this extra cast should be removed
// by optimizer.SimplifyCasts.
case Cast(e, BooleanType) if e.dataType == BooleanType => e
case Cast(e, BooleanType) => Not(Equals(e, Literal(0)))
case Cast(e, dataType) if e.dataType == BooleanType =>
Cast(If(e, Literal(1), Literal(0)), dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, Equals}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive

/**
* A set of tests that validate type promotion rules.
* A set of tests that validate type promotion and coercion rules.
*/
class HiveTypeCoercionSuite extends HiveComparisonTest {
val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")
Expand All @@ -28,4 +32,23 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
}
}

test("[SPARK-2210] boolean cast on boolean value should be removed") {
val q = "select cast(cast(key=0 as boolean) as boolean) from src"
val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head

// No cast expression introduced
project.transformAllExpressions { case c: Cast =>
assert(false, "unexpected cast " + c)
c
}

// Only one Equals
var numEquals = 0
project.transformAllExpressions { case e: Equals =>
numEquals += 1
e
}
assert(numEquals === 1)
}
}

0 comments on commit c4e543d

Please sign in to comment.