Skip to content

Commit

Permalink
Merge pull request #3 from yhuai/evalauteLiteralsInExpressions
Browse files Browse the repository at this point in the history
Evalaute literals in expressions
  • Loading branch information
marmbrus committed Jan 6, 2014
2 parents 01c00c2 + 5c14857 commit b1acb36
Show file tree
Hide file tree
Showing 32 changed files with 436 additions and 77 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/catalyst/analysis/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
* from a logical plan node's children.
*/
object ResolveReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case q: LogicalPlan if childIsFullyResolved(q) =>
logger.trace(s"Attempting to resolve ${q.simpleString}")
q transformExpressions {
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/catalyst/analysis/unresolved.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
def exprId = throw new UnresolvedException(this, "exprId")
def dataType = throw new UnresolvedException(this, "dataType")
override def foldable = throw new UnresolvedException(this, "foldable")
def nullable = throw new UnresolvedException(this, "nullable")
def qualifiers = throw new UnresolvedException(this, "qualifiers")
def references = children.flatMap(_.references).toSet
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/catalyst/errors/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ package object errors {

override def getMessage: String = {
val treeString = tree.toString
s"${super.getMessage}, tree:${if(treeString contains "\n") "\n" else " "}$tree"
s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/catalyst/execution/TestShark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ object TestShark extends SharkInstance {
* hive test cases assume the system is set up.
*/
private def rewritePaths(cmd: String): String =
if(cmd.toUpperCase startsWith "LOAD")
if (cmd.toUpperCase startsWith "LOAD")
cmd.replaceAll("\\.\\.", hiveDevHome.getCanonicalPath)
else
cmd
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/catalyst/execution/Transform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ case class Transform(
val readerThread = new Thread("Transform OutoutReader") {
override def run() {
var curLine = reader.readLine()
while(curLine != null) {
while (curLine != null) {
outputLines += buildRow(curLine.split("\t"))
curLine = reader.readLine()
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/catalyst/execution/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ case class Aggregate(

def apply(input: Seq[Row]): Unit = {
val evaluatedExpr = expr.map(Evaluate(_, input))
if(evaluatedExpr.map(_ != null).reduceLeft(_ && _))
if (evaluatedExpr.map(_ != null).reduceLeft(_ && _))
seen += evaluatedExpr
}

Expand All @@ -78,7 +78,7 @@ case class Aggregate(
var result: Any = null

def apply(input: Seq[Row]): Unit = {
if(result == null)
if (result == null)
result = Evaluate(expr, input)
}
}
Expand Down Expand Up @@ -163,7 +163,7 @@ case class SparkAggregate(aggregateExprs: Seq[NamedExpression], child: SharkPlan
val count = sc.accumulable(0)

def apply(input: Seq[Row]): Unit =
if(Evaluate(expr, input) != null)
if (Evaluate(expr, input) != null)
count += 1

def result: Any = count.value.toLong
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/catalyst/execution/basicOperators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ case class Sort(sortExprs: Seq[SortOrder], child: SharkPlan) extends UnaryNode {
sys.error(s"Comparison not yet implemented for: $curDataType")
}

if(comparison != 0) return comparison
if (comparison != 0) return comparison
i += 1
}
return 0
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/catalyst/execution/joins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ case class BroadcastNestedLoopJoin(
}
val outputRows = if (matchedRows.size > 0) {
matchedRows
} else if(joinType == LeftOuter || joinType == FullOuter) {
} else if (joinType == LeftOuter || joinType == FullOuter) {
Vector(buildRow(streamedRow ++ Array.fill(right.output.size)(null)))
} else {
Vector()
Expand All @@ -105,7 +105,7 @@ case class BroadcastNestedLoopJoin(

val includedBroadcastTuples = streamedPlusMatches.map(_._2)
val allIncludedBroadcastTuples =
if(includedBroadcastTuples.count == 0)
if (includedBroadcastTuples.count == 0)
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
else
streamedPlusMatches.map(_._2).reduce(_ ++ _)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/catalyst/expressions/BoundAttribute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ object BindReferences extends Rule[SharkPlan] {
nonLeaf.transformExpressions {
case a: AttributeReference => attachTree(a, "Binding attribute") {
val inputTuple = nonLeaf.children.indexWhere(_.output contains a)
val ordinal = if(inputTuple == -1) -1 else nonLeaf.children(inputTuple).output.indexWhere(_ == a)
if(ordinal == -1) {
val ordinal = if (inputTuple == -1) -1 else nonLeaf.children(inputTuple).output.indexWhere(_ == a)
if (ordinal == -1) {
logger.debug(s"No binding found for $a given input ${nonLeaf.children.map(_.output.mkString("{", ",", "}")).mkString(",")}")
a
} else {
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/catalyst/expressions/Cast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import types.DataType

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
override def foldable = child.foldable
def nullable = child.nullable
override def toString = s"CAST($child, $dataType)"
}
22 changes: 11 additions & 11 deletions src/main/scala/catalyst/expressions/Evaluate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ object Evaluate extends Logging {
@inline
def n1(e: Expression, f: ((Numeric[Any], Any) => Any)): Any = {
val evalE = eval(e)
if(evalE == null)
if (evalE == null)
null
else
e.dataType match {
Expand Down Expand Up @@ -54,7 +54,7 @@ object Evaluate extends Logging {

val evalE1 = eval(e1)
val evalE2 = eval(e2)
if(evalE1 == null || evalE2 == null)
if (evalE1 == null || evalE2 == null)
null
else
e1.dataType match {
Expand Down Expand Up @@ -87,7 +87,7 @@ object Evaluate extends Logging {

val evalE1 = eval(e1)
val evalE2 = eval(e2)
if(evalE1 == null || evalE2 == null)
if (evalE1 == null || evalE2 == null)
null
else
e1.dataType match {
Expand All @@ -106,7 +106,7 @@ object Evaluate extends Logging {
if (e1.dataType != e2.dataType) throw new OptimizationException(e, s"Data types do not match ${e1.dataType} != ${e2.dataType}")
val evalE1 = eval(e1)
val evalE2 = eval(e2)
if(evalE1 == null || evalE2 == null)
if (evalE1 == null || evalE2 == null)
null
else
e1.dataType match {
Expand Down Expand Up @@ -143,7 +143,7 @@ object Evaluate extends Logging {
case Subtract(l, r) => n2(l,r, _.minus(_, _))
case Multiply(l, r) => n2(l,r, _.times(_, _))
// Divide & remainder implementation are different for fractional and integral dataTypes.
case Divide(l, r) if(l.dataType == DoubleType || l.dataType == FloatType) => f2(l,r, _.div(_, _))
case Divide(l, r) if (l.dataType == DoubleType || l.dataType == FloatType) => f2(l,r, _.div(_, _))
case Divide(l, r) => i2(l,r, _.quot(_, _))
// Remainder is only allowed on Integral types.
case Remainder(l, r) => i2(l,r, _.rem(_, _))
Expand All @@ -153,7 +153,7 @@ object Evaluate extends Logging {
case Equals(l, r) =>
val left = eval(l)
val right = eval(r)
if(left == null || right == null)
if (left == null || right == null)
null
else
left == right
Expand Down Expand Up @@ -229,18 +229,18 @@ object Evaluate extends Logging {
case And(l,r) =>
val left = eval(l)
val right = eval(r)
if(left == false || right == false)
if (left == false || right == false)
false
else if(left == null || right == null )
else if (left == null || right == null )
null
else
true
case Or(l,r) =>
val left = eval(l)
val right = eval(r)
if(left == true || right == true)
if (left == true || right == true)
true
else if(left == null || right == null)
else if (left == null || right == null)
null
else
false
Expand All @@ -261,7 +261,7 @@ object Evaluate extends Logging {
case other => throw new OptimizationException(other, "evaluation not implemented")
}

logger.debug(s"Evaluated $e => $result of type ${if(result == null) "null" else result.getClass.getName}, expected: ${e.dataType}")
logger.debug(s"Evaluated $e => $result of type ${if (result == null) "null" else result.getClass.getName}, expected: ${e.dataType}")
result
}
}
13 changes: 13 additions & 0 deletions src/main/scala/catalyst/expressions/Expression.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ abstract class Expression extends TreeNode[Expression] {
self: Product =>

def dataType: DataType
/**
* foldable is used to indicate if an expression can be folded.
* Right now, we consider expressions listed below as foldable expressions.
* - A Coalesce is foldable if all of its children are foldable
* - A BinaryExpression is foldable if its both left and right child are foldable.
* - A Not, isNull, or isNotNull is foldable if its child is foldable.
* - A Literal is foldable.
* - A Cast or UnaryMinus is foldable if its child is foldable.
*/
// TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
def foldable: Boolean = false
def nullable: Boolean
def references: Set[Attribute]

Expand All @@ -30,6 +41,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

def symbol: String

override def foldable = left.foldable && right.foldable

def references = left.references ++ right.references

override def toString = s"($left $symbol $right)"
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/catalyst/expressions/SortOrder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ case object Descending extends SortDirection
case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
def dataType = child.dataType
def nullable = child.nullable
override def toString = s"$child ${if(direction == Ascending) "ASC" else "DESC"}"
override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
}
1 change: 1 addition & 0 deletions src/main/scala/catalyst/expressions/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import catalyst.types._

abstract class AggregateExpression extends Expression {
self: Product =>

}

/**
Expand Down
9 changes: 3 additions & 6 deletions src/main/scala/catalyst/expressions/arithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@ import catalyst.analysis.UnresolvedException

case class UnaryMinus(child: Expression) extends UnaryExpression {
def dataType = child.dataType
override def foldable = child.foldable
def nullable = child.nullable
override def toString = s"-$child"
}

abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>
def nullable = left.nullable || right.nullable

override lazy val resolved =
left.resolved && right.resolved && left.dataType == right.dataType

def dataType = {
if(!resolved)
if (!resolved)
throw new UnresolvedException(
this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
left.dataType
Expand All @@ -26,25 +28,20 @@ abstract class BinaryArithmetic extends BinaryExpression {

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "+"
def nullable = left.nullable || right.nullable
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "-"
def nullable = left.nullable || right.nullable
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "*"
def nullable = left.nullable || right.nullable
}

case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "/"
def nullable = left.nullable || right.nullable
}

case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "%"
def nullable = left.nullable || right.nullable
}
5 changes: 4 additions & 1 deletion src/main/scala/catalyst/expressions/literals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ object Literal {
case b: Byte => Literal(b, ByteType)
case s: Short => Literal(s, ShortType)
case s: String => Literal(s, StringType)
case b: Boolean => Literal(b, BooleanType)
case null => Literal(null, NullType)
}
}

Expand All @@ -26,8 +28,9 @@ object IntegerLiteral {
}

case class Literal(value: Any, dataType: DataType) extends LeafExpression {
override def foldable = true
def nullable = false
def references = Set.empty

override def toString = if(value != null) value.toString else "null"
override def toString = if (value != null) value.toString else "null"
}
4 changes: 2 additions & 2 deletions src/main/scala/catalyst/expressions/namedExpressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
* Returns a copy of this [[AttributeReference]] with changed nullability.
*/
def withNullability(newNullability: Boolean) =
if(nullable == newNullability)
if (nullable == newNullability)
this
else
AttributeReference(name, dataType, newNullability)(exprId, qualifiers)
Expand All @@ -104,7 +104,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
* Returns a copy of this [[AttributeReference]] with new qualifiers.
*/
def withQualifiers(newQualifiers: Seq[String]) =
if(newQualifiers == qualifiers)
if (newQualifiers == qualifiers)
this
else
AttributeReference(name, dataType, nullable)(exprId, newQualifiers)
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/catalyst/expressions/nullFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
def nullable = !children.exists(!_.nullable)

def references = children.flatMap(_.references).toSet
// Coalesce is foldable if all children are foldable.
override def foldable = !children.exists(!_.foldable)

// Only resolved if all the children are of the same type.
override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)
Expand Down
4 changes: 3 additions & 1 deletion src/main/scala/catalyst/expressions/predicates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ trait Predicate extends Expression {

abstract class BinaryPredicate extends BinaryExpression with Predicate {
self: Product =>

def nullable = left.nullable || right.nullable
}

case class Not(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
def references = child.references
override def foldable = child.foldable
def nullable = child.nullable
override def toString = s"NOT $child"
}
Expand Down Expand Up @@ -55,10 +55,12 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar

case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
def references = child.references
override def foldable = child.foldable
def nullable = false
}

case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
def references = child.references
override def foldable = child.foldable
def nullable = false
}
Loading

0 comments on commit b1acb36

Please sign in to comment.