From 14c9238aa7173ba663a999ef320d8cffb73306c4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 7 Apr 2014 18:38:44 -0700 Subject: [PATCH 1/4] [sql] Rename execution/aggregates.scala Aggregate.scala, and added a bunch of private[this] to variables. Author: Reynold Xin Closes #348 from rxin/aggregate and squashes the following commits: f4bc36f [Reynold Xin] Rename execution/aggregates.scala Aggregate.scala, and added a bunch of private[this] to variables. --- .../{aggregates.scala => Aggregate.scala} | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{aggregates.scala => Aggregate.scala} (92%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 0890faa33b507..3a4f071eebedf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -56,9 +56,9 @@ case class Aggregate( // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. - val childOutput = child.output + private[this] val childOutput = child.output - def output = aggregateExpressions.map(_.toAttribute) + override def output = aggregateExpressions.map(_.toAttribute) /** * An aggregate that needs to be computed for each row in a group. @@ -75,7 +75,7 @@ case class Aggregate( /** A list of aggregates that need to be computed for each group. */ @transient - lazy val computedAggregates = aggregateExpressions.flatMap { agg => + private[this] lazy val computedAggregates = aggregateExpressions.flatMap { agg => agg.collect { case a: AggregateExpression => ComputedAggregate( @@ -87,10 +87,10 @@ case class Aggregate( /** The schema of the result of all aggregate evaluations */ @transient - lazy val computedSchema = computedAggregates.map(_.resultAttribute) + private[this] lazy val computedSchema = computedAggregates.map(_.resultAttribute) /** Creates a new aggregate buffer for a group. */ - def newAggregateBuffer(): Array[AggregateFunction] = { + private[this] def newAggregateBuffer(): Array[AggregateFunction] = { val buffer = new Array[AggregateFunction](computedAggregates.length) var i = 0 while (i < computedAggregates.length) { @@ -102,7 +102,7 @@ case class Aggregate( /** Named attributes used to substitute grouping attributes into the final result. */ @transient - lazy val namedGroups = groupingExpressions.map { + private[this] lazy val namedGroups = groupingExpressions.map { case ne: NamedExpression => ne -> ne.toAttribute case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute } @@ -112,7 +112,7 @@ case class Aggregate( * expression into the final result expression. */ @transient - lazy val resultMap = + private[this] lazy val resultMap = (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap /** @@ -120,13 +120,13 @@ case class Aggregate( * output rows given a group and the result of all aggregate computations. */ @transient - lazy val resultExpressions = aggregateExpressions.map { agg => + private[this] lazy val resultExpressions = aggregateExpressions.map { agg => agg.transform { case e: Expression if resultMap.contains(e) => resultMap(e) } } - def execute() = attachTree(this, "execute") { + override def execute() = attachTree(this, "execute") { if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() From 55dfd5dcdbf3a9bfddb2108c8325bda3100eb33d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 7 Apr 2014 18:39:18 -0700 Subject: [PATCH 2/4] Removed the default eval implementation from Expression, and added a bunch of override's in classes I touched. It is more robust to not provide a default implementation for Expression's. Author: Reynold Xin Closes #350 from rxin/eval-default and squashes the following commits: 0a83b8f [Reynold Xin] Removed the default eval implementation from Expression, and added a bunch of override's in classes I touched. --- .../sql/catalyst/analysis/unresolved.scala | 52 ++++++++++++------- .../sql/catalyst/expressions/Expression.scala | 3 +- .../sql/catalyst/expressions/SortOrder.scala | 11 +++- .../sql/catalyst/expressions/aggregates.scala | 8 +++ .../expressions/namedExpressions.scala | 21 +++++--- .../plans/physical/partitioning.scala | 32 ++++++++---- .../ExpressionEvaluationSuite.scala | 5 +- .../optimizer/ConstantFoldingSuite.scala | 2 +- 8 files changed, 89 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 41e9bcef3cd7f..d629172a7426e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.{errors, trees} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.BaseRelation import org.apache.spark.sql.catalyst.trees.TreeNode @@ -36,7 +37,7 @@ case class UnresolvedRelation( databaseName: Option[String], tableName: String, alias: Option[String] = None) extends BaseRelation { - def output = Nil + override def output = Nil override lazy val resolved = false } @@ -44,26 +45,33 @@ case class UnresolvedRelation( * Holds the name of an attribute that has yet to be resolved. */ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - def exprId = throw new UnresolvedException(this, "exprId") - def dataType = throw new UnresolvedException(this, "dataType") - def nullable = throw new UnresolvedException(this, "nullable") - def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def exprId = throw new UnresolvedException(this, "exprId") + override def dataType = throw new UnresolvedException(this, "dataType") + override def nullable = throw new UnresolvedException(this, "nullable") + override def qualifiers = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - def newInstance = this - def withQualifiers(newQualifiers: Seq[String]) = this + override def newInstance = this + override def withQualifiers(newQualifiers: Seq[String]) = this + + // Unresolved attributes are transient at compile time and don't get evaluated during execution. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name" } case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { - def exprId = throw new UnresolvedException(this, "exprId") - def dataType = throw new UnresolvedException(this, "dataType") + override def dataType = throw new UnresolvedException(this, "dataType") override def foldable = throw new UnresolvedException(this, "foldable") - def nullable = throw new UnresolvedException(this, "nullable") - def qualifiers = throw new UnresolvedException(this, "qualifiers") - def references = children.flatMap(_.references).toSet + override def nullable = throw new UnresolvedException(this, "nullable") + override def references = children.flatMap(_.references).toSet override lazy val resolved = false + + // Unresolved functions are transient at compile time and don't get evaluated during execution. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def toString = s"'$name(${children.mkString(",")})" } @@ -79,15 +87,15 @@ case class Star( mapFunction: Attribute => Expression = identity[Attribute]) extends Attribute with trees.LeafNode[Expression] { - def name = throw new UnresolvedException(this, "exprId") - def exprId = throw new UnresolvedException(this, "exprId") - def dataType = throw new UnresolvedException(this, "dataType") - def nullable = throw new UnresolvedException(this, "nullable") - def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def name = throw new UnresolvedException(this, "exprId") + override def exprId = throw new UnresolvedException(this, "exprId") + override def dataType = throw new UnresolvedException(this, "dataType") + override def nullable = throw new UnresolvedException(this, "nullable") + override def qualifiers = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - def newInstance = this - def withQualifiers(newQualifiers: Seq[String]) = this + override def newInstance = this + override def withQualifiers(newQualifiers: Seq[String]) = this def expand(input: Seq[Attribute]): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { @@ -104,5 +112,9 @@ case class Star( mappedAttributes } + // Star gets expanded at runtime so we never evaluate a Star. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def toString = table.map(_ + ".").getOrElse("") + "*" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f190bd0cca375..8a1db8e796816 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -50,8 +50,7 @@ abstract class Expression extends TreeNode[Expression] { def references: Set[Attribute] /** Returns the result of evaluating this expression on a given input Row */ - def eval(input: Row = null): EvaluatedType = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + def eval(input: Row = null): EvaluatedType /** * Returns `true` if this expression and all its children have been resolved to a specific schema diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d5d93778f4b8d..08b2f11d20f5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.errors.TreeNodeException + abstract sealed class SortDirection case object Ascending extends SortDirection case object Descending extends SortDirection @@ -26,7 +28,12 @@ case object Descending extends SortDirection * transformations over expression will descend into its child. */ case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression { - def dataType = child.dataType - def nullable = child.nullable + override def dataType = child.dataType + override def nullable = child.nullable + + // SortOrder itself is never evaluated. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5edcea14278c7..b152f95f96c70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.errors.TreeNodeException abstract class AggregateExpression extends Expression { self: Product => @@ -28,6 +29,13 @@ abstract class AggregateExpression extends Expression { * of input rows/ */ def newInstance(): AggregateFunction + + /** + * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are + * replaced with a physical aggregate operator at runtime. + */ + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index eb4bc8e755284..a8145c37c20fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ object NamedExpression { @@ -58,9 +59,9 @@ abstract class Attribute extends NamedExpression { def withQualifiers(newQualifiers: Seq[String]): Attribute - def references = Set(this) def toAttribute = this def newInstance: Attribute + override def references = Set(this) } /** @@ -77,15 +78,15 @@ case class Alias(child: Expression, name: String) (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends NamedExpression with trees.UnaryNode[Expression] { - type EvaluatedType = Any + override type EvaluatedType = Any override def eval(input: Row) = child.eval(input) - def dataType = child.dataType - def nullable = child.nullable - def references = child.references + override def dataType = child.dataType + override def nullable = child.nullable + override def references = child.references - def toAttribute = { + override def toAttribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers) } else { @@ -127,7 +128,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea h } - def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) + override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -143,7 +144,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea /** * Returns a copy of this [[AttributeReference]] with new qualifiers. */ - def withQualifiers(newQualifiers: Seq[String]) = { + override def withQualifiers(newQualifiers: Seq[String]) = { if (newQualifiers == qualifiers) { this } else { @@ -151,5 +152,9 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea } } + // Unresolved attributes are transient at compile time and don't get evaluated during execution. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def toString: String = s"$name#${exprId.id}$typeSuffix" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 8893744eb2e7a..ffb3a92f8f340 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} import org.apache.spark.sql.catalyst.types.IntegerType /** @@ -139,12 +140,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning { - def children = expressions - def references = expressions.flatMap(_.references).toSet - def nullable = false - def dataType = IntegerType + override def children = expressions + override def references = expressions.flatMap(_.references).toSet + override def nullable = false + override def dataType = IntegerType - lazy val clusteringSet = expressions.toSet + private[this] lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -158,6 +159,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case h: HashPartitioning if h == this => true case _ => false } + + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** @@ -168,17 +172,20 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * partition. * - Each partition will have a `min` and `max` row, relative to the given ordering. All rows * that are in between `min` and `max` in this `ordering` will reside in this partition. + * + * This class extends expression primarily so that transformations over expression will descend + * into its child. */ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) extends Expression with Partitioning { - def children = ordering - def references = ordering.flatMap(_.references).toSet - def nullable = false - def dataType = IntegerType + override def children = ordering + override def references = ordering.flatMap(_.references).toSet + override def nullable = false + override def dataType = IntegerType - lazy val clusteringSet = ordering.map(_.child).toSet + private[this] lazy val clusteringSet = ordering.map(_.child).toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -195,4 +202,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case r: RangePartitioning if r == this => true case _ => false } + + override def eval(input: Row): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 92987405aa313..31be6c4ef1b0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -100,7 +100,10 @@ class ExpressionEvaluationSuite extends FunSuite { (null, false, null) :: (null, null, null) :: Nil) - def booleanLogicTest(name: String, op: (Expression, Expression) => Expression, truthTable: Seq[(Any, Any, Any)]) { + def booleanLogicTest( + name: String, + op: (Expression, Expression) => Expression, + truthTable: Seq[(Any, Any, Any)]) { test(s"3VL $name") { truthTable.foreach { case (l,r,answer) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 2ab14f48ccc8a..20dfba847790c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType} // For implicit conversions import org.apache.spark.sql.catalyst.dsl.plans._ From 31e6fff03730bb915a836d77dcd43d098afd1dbd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 7 Apr 2014 18:40:08 -0700 Subject: [PATCH 3/4] Added eval for Rand (without any support for user-defined seed). Author: Reynold Xin Closes #349 from rxin/rand and squashes the following commits: fd11322 [Reynold Xin] Added eval for Rand (without any support for user-defined seed). --- .../spark/sql/catalyst/expressions/Rand.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index 0bde621602944..38f836f0a1a0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -17,11 +17,18 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Random import org.apache.spark.sql.catalyst.types.DoubleType + case object Rand extends LeafExpression { - def dataType = DoubleType - def nullable = false - def references = Set.empty + override def dataType = DoubleType + override def nullable = false + override def references = Set.empty + + private[this] lazy val rand = new Random + + override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType] + override def toString = "RAND()" } From f27e56aa612538188a8550fe72ee20b8b13304d7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 7 Apr 2014 19:28:24 -0700 Subject: [PATCH 4/4] Change timestamp cast semantics. When cast to numeric types, return the unix time in seconds (instead of millis). @marmbrus @chenghao-intel Author: Reynold Xin Closes #352 from rxin/timestamp-cast and squashes the following commits: 18aacd3 [Reynold Xin] Fixed precision for double. 2adb235 [Reynold Xin] Change timestamp cast semantics. When cast to numeric types, return the unix time in seconds (instead of millis). --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 23 ++++++++++------ .../ExpressionEvaluationSuite.scala | 27 ++++++++++++++++--- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 2d62e4cbbce01..987befe8e22ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -104,7 +104,7 @@ package object dsl { implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { - def expr: Expression = Literal(s) + override def expr: Expression = Literal(s) def attr = analysis.UnresolvedAttribute(s) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 89226999ca005..17118499d0c87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -87,7 +87,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private def decimalToTimestamp(d: BigDecimal) = { val seconds = d.longValue() - val bd = (d - seconds) * (1000000000) + val bd = (d - seconds) * 1000000000 val nanos = bd.intValue() // Convert to millis @@ -96,18 +96,23 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // remaining fractional portion as nanos t.setNanos(nanos) - t } - private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000) + // Timestamp to long, converting milliseconds to seconds + private def timestampToLong(ts: Timestamp) = ts.getTime / 1000 + + private def timestampToDouble(ts: Timestamp) = { + // First part is the seconds since the beginning of time, followed by nanosecs. + ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 + } def castToLong: Any => Any = child.dataType match { case StringType => nullOrCast[String](_, s => try s.toLong catch { case _: NumberFormatException => null }) case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t)) case DecimalType => nullOrCast[BigDecimal](_, _.toLong) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } @@ -117,7 +122,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case _: NumberFormatException => null }) case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt) case DecimalType => nullOrCast[BigDecimal](_, _.toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } @@ -127,7 +132,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case _: NumberFormatException => null }) case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort) case DecimalType => nullOrCast[BigDecimal](_, _.toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } @@ -137,7 +142,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case _: NumberFormatException => null }) case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte) case DecimalType => nullOrCast[BigDecimal](_, _.toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } @@ -147,7 +152,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case _: NumberFormatException => null }) case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0)) - case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) + case TimestampType => + // Note that we lose precision here. + nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 31be6c4ef1b0b..888a19d79f7e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -201,7 +201,7 @@ class ExpressionEvaluationSuite extends FunSuite { val sts = "1970-01-01 00:00:01.0" val ts = Timestamp.valueOf(sts) - + checkEvaluation("abdef" cast StringType, "abdef") checkEvaluation("abdef" cast DecimalType, null) checkEvaluation("abdef" cast TimestampType, null) @@ -209,7 +209,6 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(1) cast LongType, 1) checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1) - checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts) @@ -240,12 +239,34 @@ class ExpressionEvaluationSuite extends FunSuite { intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} } - + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) checkEvaluation(Literal("ab") < Literal("abc"), true) checkEvaluation(Literal(ts1) < Literal(ts2), true) } + + test("timestamp casting") { + val millis = 15 * 1000 + 2 + val ts = new Timestamp(millis) + val ts1 = new Timestamp(15 * 1000) // a timestamp without the milliseconds part + checkEvaluation(Cast(ts, ShortType), 15) + checkEvaluation(Cast(ts, IntegerType), 15) + checkEvaluation(Cast(ts, LongType), 15) + checkEvaluation(Cast(ts, FloatType), 15.002f) + checkEvaluation(Cast(ts, DoubleType), 15.002) + checkEvaluation(Cast(Cast(ts, ShortType), TimestampType), ts1) + checkEvaluation(Cast(Cast(ts, IntegerType), TimestampType), ts1) + checkEvaluation(Cast(Cast(ts, LongType), TimestampType), ts1) + checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType), + millis.toFloat / 1000) + checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), + millis.toDouble / 1000) + checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) + + // A test for higher precision than millis + checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) + } }