Skip to content

Commit

Permalink
Removed the default eval implementation from Expression, and added a …
Browse files Browse the repository at this point in the history
…bunch of override's in classes I touched.
  • Loading branch information
rxin committed Apr 7, 2014
1 parent 0307db0 commit 0a83b8f
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.{errors, trees}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.BaseRelation
import org.apache.spark.sql.catalyst.trees.TreeNode

Expand All @@ -36,34 +37,41 @@ case class UnresolvedRelation(
databaseName: Option[String],
tableName: String,
alias: Option[String] = None) extends BaseRelation {
def output = Nil
override def output = Nil
override lazy val resolved = false
}

/**
* Holds the name of an attribute that has yet to be resolved.
*/
case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
def exprId = throw new UnresolvedException(this, "exprId")
def dataType = throw new UnresolvedException(this, "dataType")
def nullable = throw new UnresolvedException(this, "nullable")
def qualifiers = throw new UnresolvedException(this, "qualifiers")
override def exprId = throw new UnresolvedException(this, "exprId")
override def dataType = throw new UnresolvedException(this, "dataType")
override def nullable = throw new UnresolvedException(this, "nullable")
override def qualifiers = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false

def newInstance = this
def withQualifiers(newQualifiers: Seq[String]) = this
override def newInstance = this
override def withQualifiers(newQualifiers: Seq[String]) = this

// Unresolved attributes are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString: String = s"'$name"
}

case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
def exprId = throw new UnresolvedException(this, "exprId")
def dataType = throw new UnresolvedException(this, "dataType")
override def dataType = throw new UnresolvedException(this, "dataType")
override def foldable = throw new UnresolvedException(this, "foldable")
def nullable = throw new UnresolvedException(this, "nullable")
def qualifiers = throw new UnresolvedException(this, "qualifiers")
def references = children.flatMap(_.references).toSet
override def nullable = throw new UnresolvedException(this, "nullable")
override def references = children.flatMap(_.references).toSet
override lazy val resolved = false

// Unresolved functions are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString = s"'$name(${children.mkString(",")})"
}

Expand All @@ -79,15 +87,15 @@ case class Star(
mapFunction: Attribute => Expression = identity[Attribute])
extends Attribute with trees.LeafNode[Expression] {

def name = throw new UnresolvedException(this, "exprId")
def exprId = throw new UnresolvedException(this, "exprId")
def dataType = throw new UnresolvedException(this, "dataType")
def nullable = throw new UnresolvedException(this, "nullable")
def qualifiers = throw new UnresolvedException(this, "qualifiers")
override def name = throw new UnresolvedException(this, "exprId")
override def exprId = throw new UnresolvedException(this, "exprId")
override def dataType = throw new UnresolvedException(this, "dataType")
override def nullable = throw new UnresolvedException(this, "nullable")
override def qualifiers = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false

def newInstance = this
def withQualifiers(newQualifiers: Seq[String]) = this
override def newInstance = this
override def withQualifiers(newQualifiers: Seq[String]) = this

def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
val expandedAttributes: Seq[Attribute] = table match {
Expand All @@ -104,5 +112,9 @@ case class Star(
mappedAttributes
}

// Star gets expanded at runtime so we never evaluate a Star.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString = table.map(_ + ".").getOrElse("") + "*"
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ abstract class Expression extends TreeNode[Expression] {
def references: Set[Attribute]

/** Returns the result of evaluating this expression on a given input Row */
def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
def eval(input: Row = null): EvaluatedType

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.errors.TreeNodeException

abstract sealed class SortDirection
case object Ascending extends SortDirection
case object Descending extends SortDirection
Expand All @@ -26,7 +28,12 @@ case object Descending extends SortDirection
* transformations over expression will descend into its child.
*/
case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
def dataType = child.dataType
def nullable = child.nullable
override def dataType = child.dataType
override def nullable = child.nullable

// SortOrder itself is never evaluated.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException

abstract class AggregateExpression extends Expression {
self: Product =>
Expand All @@ -28,6 +29,13 @@ abstract class AggregateExpression extends Expression {
* of input rows/
*/
def newInstance(): AggregateFunction

/**
* [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are
* replaced with a physical aggregate operator at runtime.
*/
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.types._

object NamedExpression {
Expand Down Expand Up @@ -58,9 +59,9 @@ abstract class Attribute extends NamedExpression {

def withQualifiers(newQualifiers: Seq[String]): Attribute

def references = Set(this)
def toAttribute = this
def newInstance: Attribute
override def references = Set(this)
}

/**
Expand All @@ -77,15 +78,15 @@ case class Alias(child: Expression, name: String)
(val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
extends NamedExpression with trees.UnaryNode[Expression] {

type EvaluatedType = Any
override type EvaluatedType = Any

override def eval(input: Row) = child.eval(input)

def dataType = child.dataType
def nullable = child.nullable
def references = child.references
override def dataType = child.dataType
override def nullable = child.nullable
override def references = child.references

def toAttribute = {
override def toAttribute = {
if (resolved) {
AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
} else {
Expand Down Expand Up @@ -127,7 +128,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
h
}

def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)

/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
Expand All @@ -143,13 +144,17 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
/**
* Returns a copy of this [[AttributeReference]] with new qualifiers.
*/
def withQualifiers(newQualifiers: Seq[String]) = {
override def withQualifiers(newQualifiers: Seq[String]) = {
if (newQualifiers == qualifiers) {
this
} else {
AttributeReference(name, dataType, nullable)(exprId, newQualifiers)
}
}

// Unresolved attributes are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString: String = s"$name#${exprId.id}$typeSuffix"
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

package org.apache.spark.sql.catalyst.plans.physical

import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder}
import org.apache.spark.sql.catalyst.types.IntegerType

/**
Expand Down Expand Up @@ -139,12 +140,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
extends Expression
with Partitioning {

def children = expressions
def references = expressions.flatMap(_.references).toSet
def nullable = false
def dataType = IntegerType
override def children = expressions
override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = IntegerType

lazy val clusteringSet = expressions.toSet
private[this] lazy val clusteringSet = expressions.toSet

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
Expand All @@ -158,6 +159,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case h: HashPartitioning if h == this => true
case _ => false
}

override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}

/**
Expand All @@ -168,17 +172,20 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
* partition.
* - Each partition will have a `min` and `max` row, relative to the given ordering. All rows
* that are in between `min` and `max` in this `ordering` will reside in this partition.
*
* This class extends expression primarily so that transformations over expression will descend
* into its child.
*/
case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
extends Expression
with Partitioning {

def children = ordering
def references = ordering.flatMap(_.references).toSet
def nullable = false
def dataType = IntegerType
override def children = ordering
override def references = ordering.flatMap(_.references).toSet
override def nullable = false
override def dataType = IntegerType

lazy val clusteringSet = ordering.map(_.child).toSet
private[this] lazy val clusteringSet = ordering.map(_.child).toSet

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
Expand All @@ -195,4 +202,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case r: RangePartitioning if r == this => true
case _ => false
}

override def eval(input: Row): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ class ExpressionEvaluationSuite extends FunSuite {
(null, false, null) ::
(null, null, null) :: Nil)

def booleanLogicTest(name: String, op: (Expression, Expression) => Expression, truthTable: Seq[(Any, Any, Any)]) {
def booleanLogicTest(
name: String,
op: (Expression, Expression) => Expression,
truthTable: Seq[(Any, Any, Any)]) {
test(s"3VL $name") {
truthTable.foreach {
case (l,r,answer) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.types.IntegerType
import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType}

// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
Expand Down

0 comments on commit 0a83b8f

Please sign in to comment.