From ce4385efdfc7cb64068cae43dd41b8689bd5009b Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Mon, 31 Mar 2014 15:05:10 +0800
Subject: [PATCH] Add TimestampType Support

---
 .../spark/sql/catalyst/dsl/package.scala      |  37 ++++
 .../spark/sql/catalyst/expressions/Cast.scala | 194 ++++++++++++++----
 .../sql/catalyst/expressions/Expression.scala |  29 +++
 .../sql/catalyst/expressions/literals.scala   |   5 +
 .../sql/catalyst/expressions/predicates.scala |  59 +-----
 .../spark/sql/catalyst/types/dataTypes.scala  |  12 ++
 .../ExpressionEvaluationSuite.scala           |  45 ++++
 .../org/apache/spark/sql/hive/HiveQl.scala    |  11 +-
 8 files changed, 298 insertions(+), 94 deletions(-)
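Note for reviewers: the convention used throughout the patch is that integral values map to seconds since the epoch (Hive semantics), not milliseconds, and a timestamp maps back to fractional seconds with nanoseconds as the fraction. A self-contained sketch of just that convention, using only the JDK (names are illustrative, not from the patch):

    import java.sql.Timestamp

    object TimestampConvention {
      // An integral n is interpreted as n seconds since the epoch, so the
      // java.sql.Timestamp constructor (which takes millis) gets n * 1000.
      def secondsToTimestamp(n: Long): Timestamp = new Timestamp(n * 1000)

      // Back to fractional epoch seconds: integral seconds plus nanos / 1e9.
      def timestampToSeconds(t: Timestamp): Double =
        math.floor(t.getTime / 1000.0) + t.getNanos.toDouble / 1000000000

      def main(args: Array[String]): Unit = {
        val t = secondsToTimestamp(1L)
        println(timestampToSeconds(t)) // 1.0
      }
    }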
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 44abe671c07a4..98bf6d4b24962 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst
 
+import java.sql.Timestamp
+
 import scala.language.implicitConversions
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -72,6 +74,7 @@ package object dsl {
 
     def like(other: Expression) = Like(expr, other)
     def rlike(other: Expression) = RLike(expr, other)
+    def cast(to: DataType) = Cast(expr, to)
 
     def asc = SortOrder(expr, Ascending)
     def desc = SortOrder(expr, Descending)
@@ -84,15 +87,22 @@ package object dsl {
     def expr = e
   }
 
+  implicit def booleanToLiteral(b: Boolean) = Literal(b)
+  implicit def byteToLiteral(b: Byte) = Literal(b)
+  implicit def shortToLiteral(s: Short) = Literal(s)
   implicit def intToLiteral(i: Int) = Literal(i)
   implicit def longToLiteral(l: Long) = Literal(l)
   implicit def floatToLiteral(f: Float) = Literal(f)
   implicit def doubleToLiteral(d: Double) = Literal(d)
   implicit def stringToLiteral(s: String) = Literal(s)
+  implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
+  implicit def timestampToLiteral(t: Timestamp) = Literal(t)
+  implicit def bytesToLiteral(a: Array[Byte]) = Literal(a)
 
   implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)
 
   implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
+  // TODO: add more implicit classes for literals?
   implicit class DslString(val s: String) extends ImplicitOperators {
     def expr: Expression = Literal(s)
     def attr = analysis.UnresolvedAttribute(s)
@@ -103,11 +113,38 @@ package object dsl {
     def expr = attr
     def attr = analysis.UnresolvedAttribute(s)
 
+    /** Creates a new typed attribute of type boolean */
+    def boolean = AttributeReference(s, BooleanType, nullable = false)()
+
+    /** Creates a new typed attribute of type byte */
+    def byte = AttributeReference(s, ByteType, nullable = false)()
+
+    /** Creates a new typed attribute of type short */
+    def short = AttributeReference(s, ShortType, nullable = false)()
+
     /** Creates a new typed attribute of type int */
     def int = AttributeReference(s, IntegerType, nullable = false)()
 
+    /** Creates a new typed attribute of type long */
+    def long = AttributeReference(s, LongType, nullable = false)()
+
+    /** Creates a new typed attribute of type float */
+    def float = AttributeReference(s, FloatType, nullable = false)()
+
+    /** Creates a new typed attribute of type double */
+    def double = AttributeReference(s, DoubleType, nullable = false)()
+
     /** Creates a new typed attribute of type string */
     def string = AttributeReference(s, StringType, nullable = false)()
+
+    /** Creates a new typed attribute of type decimal */
+    def decimal = AttributeReference(s, DecimalType, nullable = false)()
+
+    /** Creates a new typed attribute of type timestamp */
+    def timestamp = AttributeReference(s, TimestampType, nullable = false)()
+
+    /** Creates a new typed attribute of type binary */
+    def binary = AttributeReference(s, BinaryType, nullable = false)()
   }
 
   implicit class DslAttribute(a: AttributeReference) {
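Taken together, the DSL additions let tests and plans mention timestamps without boilerplate. A minimal usage sketch, assuming this patch is applied and catalyst is on the classpath (the attribute name is made up):

    import java.sql.Timestamp

    import org.apache.spark.sql.catalyst.dsl._
    import org.apache.spark.sql.catalyst.types.LongType

    object DslExample extends App {
      // New typed-attribute builder: an AttributeReference of TimestampType.
      val eventTime = 'eventTime.timestamp

      // timestampToLiteral lifts the java.sql.Timestamp to a Literal, and the
      // new `cast` operator wraps any expression in a Cast.
      val pred = eventTime < Timestamp.valueOf("1970-01-01 00:00:01.0")
      val secs = eventTime.cast(LongType)

      println(pred)
      println(secs)
    }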
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c26fc3d0f305f..7f638433a3cfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
+import java.lang.{NumberFormatException => NFE}
+
 import org.apache.spark.sql.catalyst.types._
 
 /** Cast the child expression to the target data type.
  */
@@ -26,52 +29,169 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
   override def toString = s"CAST($child, $dataType)"
 
   type EvaluatedType = Any
+
+  def nullOrCast[T, B](a: Any, func: T => B): B = if (a == null) {
+    null.asInstanceOf[B]
+  } else {
+    func(a.asInstanceOf[T])
+  }
+
+  // UDFToString
+  def castToString: Any => String = child.dataType match {
+    case BinaryType => nullOrCast[Array[Byte], String](_, new String(_, "UTF-8"))
+    case _ => nullOrCast[Any, String](_, _.toString)
+  }
+
+  // BinaryConverter
+  def castToBinary: Any => Array[Byte] = child.dataType match {
+    case StringType => nullOrCast[String, Array[Byte]](_, _.getBytes("UTF-8"))
+  }
 
-  lazy val castingFunction: Any => Any = (child.dataType, dataType) match {
-    case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]])
-    case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes
-    case (_, StringType) => a: Any => a.toString
-    case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt)
-    case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble)
-    case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat)
-    case (StringType, LongType) => a: Any => castOrNull(a, _.toLong)
-    case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort)
-    case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte)
-    case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_))
-    case (BooleanType, ByteType) => {
-      case null => null
-      case true => 1.toByte
-      case false => 0.toByte
-    }
-    case (dt, IntegerType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a)
-    case (dt, DoubleType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)
-    case (dt, FloatType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a)
-    case (dt, LongType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a)
-    case (dt, ShortType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort
-    case (dt, ByteType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte
-    case (dt, DecimalType) =>
-      a: Any =>
-        BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a))
+  // UDFToBoolean
+  def castToBoolean: Any => Boolean = child.dataType match {
+    case StringType => nullOrCast[String, Boolean](_, _.length() != 0)
+    case TimestampType => nullOrCast[Timestamp, Boolean](_, t => t.getTime() != 0 || t.getNanos() != 0)
+    case LongType => nullOrCast[Long, Boolean](_, _ != 0)
+    case IntegerType => nullOrCast[Int, Boolean](_, _ != 0)
+    case ShortType => nullOrCast[Short, Boolean](_, _ != 0)
+    case ByteType => nullOrCast[Byte, Boolean](_, _ != 0)
+    case DecimalType => nullOrCast[BigDecimal, Boolean](_, _ != 0)
+    case DoubleType => nullOrCast[Double, Boolean](_, _ != 0)
+    case FloatType => nullOrCast[Float, Boolean](_, _ != 0)
+  }
+
+  // TimestampConverter
+  def castToTimestamp: Any => Timestamp = child.dataType match {
+    case StringType => nullOrCast[String, Timestamp](_, s => {
+      // Throw away anything beyond 9 decimal places (nanosecond precision)
+      val periodIdx = s.indexOf(".")
+      var n = s
+      if (periodIdx != -1 && n.length() - periodIdx > 9) {
+        n = n.substring(0, periodIdx + 10)
+      }
+      try Timestamp.valueOf(n) catch { case _: IllegalArgumentException => null }
+    })
+    case BooleanType => nullOrCast[Boolean, Timestamp](_, b => new Timestamp((if (b) 1 else 0) * 1000))
+    case LongType => nullOrCast[Long, Timestamp](_, l => new Timestamp(l * 1000))
+    case IntegerType => nullOrCast[Int, Timestamp](_, i => new Timestamp(i * 1000))
+    case ShortType => nullOrCast[Short, Timestamp](_, s => new Timestamp(s * 1000))
+    case ByteType => nullOrCast[Byte, Timestamp](_, b => new Timestamp(b * 1000))
+    // TimestampWritable.decimalToTimestamp
+    case DecimalType => nullOrCast[BigDecimal, Timestamp](_, d => decimalToTimestamp(d))
+    // TimestampWritable.doubleToTimestamp
+    case DoubleType => nullOrCast[Double, Timestamp](_, d => decimalToTimestamp(d))
+    // TimestampWritable.floatToTimestamp
+    case FloatType => nullOrCast[Float, Timestamp](_, f => decimalToTimestamp(f))
   }
 
-  @inline
-  protected def castOrNull[A](a: Any, f: String => A) =
-    try f(a.asInstanceOf[String]) catch {
-      case _: java.lang.NumberFormatException => null
-    }
+  private def decimalToTimestamp(d: BigDecimal) = {
+    val seconds = d.longValue()
+    val bd = (d - seconds) * 1000000000
+    val nanos = bd.intValue()
+
+    // Convert to millis
+    val millis = seconds * 1000
+    val t = new Timestamp(millis)
+
+    // Remaining fractional portion as nanos
+    t.setNanos(nanos)
+
+    t
+  }
+
+  // Fractional seconds since the epoch, following Hive's
+  // TimestampWritable.getDouble: integral seconds plus nanos / 1e9.
+  private def timestampToDouble(t: Timestamp) =
+    math.floor(t.getTime / 1000.0) + t.getNanos.toDouble / 1000000000
+
+  def castToLong: Any => Long = child.dataType match {
+    case StringType => nullOrCast[String, Long](_, s => try s.toLong catch {
+      case _: NFE => null.asInstanceOf[Long]
+    })
+    case BooleanType => nullOrCast[Boolean, Long](_, b => if (b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp, Long](_, t => timestampToDouble(t).toLong)
+    case DecimalType => nullOrCast[BigDecimal, Long](_, _.toLong)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
+  }
+
+  def castToInt: Any => Int = child.dataType match {
+    case StringType => nullOrCast[String, Int](_, s => try s.toInt catch {
+      case _: NFE => null.asInstanceOf[Int]
+    })
+    case BooleanType => nullOrCast[Boolean, Int](_, b => if (b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp, Int](_, t => timestampToDouble(t).toInt)
+    case DecimalType => nullOrCast[BigDecimal, Int](_, _.toInt)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
+  }
+
+  def castToShort: Any => Short = child.dataType match {
+    case StringType => nullOrCast[String, Short](_, s => try s.toShort catch {
+      case _: NFE => null.asInstanceOf[Short]
+    })
+    case BooleanType => nullOrCast[Boolean, Short](_, b => if (b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp, Short](_, t => timestampToDouble(t).toShort)
+    case DecimalType => nullOrCast[BigDecimal, Short](_, _.toShort)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
+  }
+
+  def castToByte: Any => Byte = child.dataType match {
+    case StringType => nullOrCast[String, Byte](_, s => try s.toByte catch {
+      case _: NFE => null.asInstanceOf[Byte]
+    })
+    case BooleanType => nullOrCast[Boolean, Byte](_, b => if (b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp, Byte](_, t => timestampToDouble(t).toByte)
+    case DecimalType => nullOrCast[BigDecimal, Byte](_, _.toByte)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
+  }
+
+  def castToDecimal: Any => BigDecimal = child.dataType match {
+    case StringType => nullOrCast[String, BigDecimal](_, s => try BigDecimal(s) catch {
+      case _: NFE => null
+    })
+    case BooleanType => nullOrCast[Boolean, BigDecimal](_, b => if (b) 1 else 0)
+    case TimestampType =>
+      // Note: this goes through Double, so nanosecond precision may be lost
+      nullOrCast[Timestamp, BigDecimal](_, t => timestampToDouble(t))
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
+  }
+
+  def castToDouble: Any => Double = child.dataType match {
+    case StringType => nullOrCast[String, Double](_, s => try s.toDouble catch {
+      case _: NFE => null.asInstanceOf[Double]
+    })
+    case BooleanType => nullOrCast[Boolean, Double](_, b => if (b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp, Double](_, t => timestampToDouble(t))
+    case DecimalType => nullOrCast[BigDecimal, Double](_, _.toDouble)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
+  }
+
+  def castToFloat: Any => Float = child.dataType match {
+    case StringType => nullOrCast[String, Float](_, s => try s.toFloat catch {
+      case _: NFE => null.asInstanceOf[Float]
+    })
+    case BooleanType => nullOrCast[Boolean, Float](_, b => if (b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp, Float](_, t => timestampToDouble(t).toFloat)
+    case DecimalType => nullOrCast[BigDecimal, Float](_, _.toFloat)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
+  }
+
+  def cast: Any => Any = (child.dataType, dataType) match {
+    case (_, StringType) => castToString
+    case (_, BinaryType) => castToBinary
+    case (_, DecimalType) => castToDecimal
+    case (_, TimestampType) => castToTimestamp
+    case (_, BooleanType) => castToBoolean
+    case (_, ByteType) => castToByte
+    case (_, ShortType) => castToShort
+    case (_, IntegerType) => castToInt
+    case (_, FloatType) => castToFloat
+    case (_, LongType) => castToLong
+    case (_, DoubleType) => castToDouble
+  }
 
   override def apply(input: Row): Any = {
     val evaluated = child.apply(input)
     if (evaluated == null) {
       null
     } else {
-      castingFunction(evaluated)
+      if (child.dataType == dataType) evaluated else cast(evaluated)
     }
   }
 }
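One caveat worth flagging in the String branches of castToLong/castToInt/castToShort/castToByte: when the type parameter is a primitive, asInstanceOf on null yields the type's default value rather than null, so an unparseable string becomes 0 instead of NULL. A two-line demonstration in plain Scala:

    object PrimitiveNullCaveat extends App {
      // Unboxing null to a primitive produces the default value: prints 0.
      println(null.asInstanceOf[Long])
    }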
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 81fd160e00ca1..f786701b82220 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType}
+import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampType, NativeType}
 
 abstract class Expression extends TreeNode[Expression] {
   self: Product =>
@@ -170,6 +171,34 @@ abstract class Expression extends TreeNode[Expression] {
       }
     }
   }
+
+  @inline
+  protected final def c2(
+      i: Row,
+      e1: Expression,
+      e2: Expression,
+      f: ((Ordering[Any], Any, Any) => Any)): Any = {
+    if (e1.dataType != e2.dataType) {
+      throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+    }
+
+    val evalE1 = e1.apply(i)
+    if (evalE1 == null) {
+      null
+    } else {
+      val evalE2 = e2.apply(i)
+      if (evalE2 == null) {
+        null
+      } else {
+        e1.dataType match {
+          case n: NativeType =>
+            f.asInstanceOf[(Ordering[n.JvmType], n.JvmType, n.JvmType) => Boolean](
+              n.ordering, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType])
+          case other => sys.error(s"Type $other does not support ordered operations")
+        }
+      }
+    }
+  }
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
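The trick behind c2 is the path-dependent JvmType member on NativeType: matching on the runtime data type recovers its JvmType, so one untyped comparison lambda can be applied with the type's own Ordering. A self-contained sketch of the same pattern, with no catalyst dependencies (all names illustrative):

    import java.sql.Timestamp

    object OrderingDispatch extends App {
      // Cut-down stand-in for NativeType: a tag carrying a JVM type and its Ordering.
      sealed abstract class TypeTag { type Jvm; val ordering: Ordering[Jvm] }
      object TimestampTag extends TypeTag {
        type Jvm = Timestamp
        val ordering: Ordering[Timestamp] = Ordering.fromLessThan(_.compareTo(_) < 0)
      }

      // The c2 pattern: cast the generic lambda down to the concrete Jvm type
      // and feed it the tag's own Ordering.
      def compare(tag: TypeTag, a: Any, b: Any)(f: (Ordering[Any], Any, Any) => Any): Any =
        f.asInstanceOf[(Ordering[tag.Jvm], tag.Jvm, tag.Jvm) => Any](
          tag.ordering, a.asInstanceOf[tag.Jvm], b.asInstanceOf[tag.Jvm])

      println(compare(TimestampTag, new Timestamp(12), new Timestamp(123))(_.lt(_, _))) // true
    }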
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index b82a12e0f754e..d879b2b5e8ba1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
+
 import org.apache.spark.sql.catalyst.types._
 
 object Literal {
@@ -29,6 +31,9 @@ object Literal {
     case s: Short => Literal(s, ShortType)
     case s: String => Literal(s, StringType)
     case b: Boolean => Literal(b, BooleanType)
+    case d: BigDecimal => Literal(d, DecimalType)
+    case t: Timestamp => Literal(t, TimestampType)
+    case a: Array[Byte] => Literal(a, BinaryType)
     case null => Literal(null, NullType)
   }
 }
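With the three new factory cases, a plain Scala value of any supported type can be lifted into a typed Literal. A quick sketch, assuming the patch is applied:

    import java.sql.Timestamp

    import org.apache.spark.sql.catalyst.expressions.Literal

    object LiteralExample extends App {
      println(Literal(BigDecimal("12.65")))    // dataType = DecimalType
      println(Literal(new Timestamp(1000)))    // dataType = TimestampType
      println(Literal("ab".getBytes("UTF-8"))) // dataType = BinaryType
    }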
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 02fedd16b8d4b..b74809e5ca67d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}
+import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampType}
 
 object InterpretedPredicate {
   def apply(expression: Expression): (Row => Boolean) = {
@@ -123,70 +124,22 @@ case class Equals(left: Expression, right: Expression) extends BinaryComparison
 
 case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
   def symbol = "<"
-  override def apply(input: Row): Any = {
-    if (left.dataType == StringType && right.dataType == StringType) {
-      val l = left.apply(input)
-      val r = right.apply(input)
-      if(l == null || r == null) {
-        null
-      } else {
-        l.asInstanceOf[String] < r.asInstanceOf[String]
-      }
-    } else {
-      n2(input, left, right, _.lt(_, _))
-    }
-  }
+  override def apply(input: Row): Any = c2(input, left, right, _.lt(_, _))
 }
 
 case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   def symbol = "<="
-  override def apply(input: Row): Any = {
-    if (left.dataType == StringType && right.dataType == StringType) {
-      val l = left.apply(input)
-      val r = right.apply(input)
-      if(l == null || r == null) {
-        null
-      } else {
-        l.asInstanceOf[String] <= r.asInstanceOf[String]
-      }
-    } else {
-      n2(input, left, right, _.lteq(_, _))
-    }
-  }
+  override def apply(input: Row): Any = c2(input, left, right, _.lteq(_, _))
 }
 
 case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
   def symbol = ">"
-  override def apply(input: Row): Any = {
-    if (left.dataType == StringType && right.dataType == StringType) {
-      val l = left.apply(input)
-      val r = right.apply(input)
-      if(l == null || r == null) {
-        null
-      } else {
-        l.asInstanceOf[String] > r.asInstanceOf[String]
-      }
-    } else {
-      n2(input, left, right, _.gt(_, _))
-    }
-  }
+  override def apply(input: Row): Any = c2(input, left, right, _.gt(_, _))
 }
 
 case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   def symbol = ">="
-  override def apply(input: Row): Any = {
-    if (left.dataType == StringType && right.dataType == StringType) {
-      val l = left.apply(input)
-      val r = right.apply(input)
-      if(l == null || r == null) {
-        null
-      } else {
-        l.asInstanceOf[String] >= r.asInstanceOf[String]
-      }
-    } else {
-      n2(input, left, right, _.gteq(_, _))
-    }
-  }
+  override def apply(input: Row): Any = c2(input, left, right, _.gteq(_, _))
 }
 
 case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
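After this change all four inequality operators share one evaluation path, so timestamps (and strings) compare through the same c2 helper as numerics. A sketch of evaluating such a predicate directly, assuming the patch is applied:

    import java.sql.Timestamp

    import org.apache.spark.sql.catalyst.expressions.{GenericRow, LessThan, Literal}
    import org.apache.spark.sql.catalyst.types.TimestampType

    object PredicateExample extends App {
      val lt = LessThan(
        Literal(new Timestamp(12), TimestampType),
        Literal(new Timestamp(123), TimestampType))

      // The literals reference no input attributes, so an empty row suffices.
      println(lt.apply(new GenericRow(Array.empty[Any]))) // true
    }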
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 7a45d1a1b8195..cdeb01a9656f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.types
 
+import java.sql.Timestamp
+
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -51,6 +53,16 @@
   val ordering = implicitly[Ordering[JvmType]]
 }
 
+case object TimestampType extends NativeType {
+  type JvmType = Timestamp
+
+  @transient lazy val tag = typeTag[JvmType]
+
+  val ordering = new Ordering[JvmType] {
+    def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
+  }
+}
+
 abstract class NumericType extends NativeType {
   // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
   // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
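TimestampType's Ordering is exactly what c2 resolves at runtime; it can also be used directly. A small sketch, assuming the patch is applied:

    import java.sql.Timestamp

    import org.apache.spark.sql.catalyst.types.TimestampType

    object TimestampOrderingExample extends App {
      val ord = TimestampType.ordering
      println(ord.lt(new Timestamp(12), new Timestamp(123)))       // true
      println(ord.compare(new Timestamp(123), new Timestamp(123))) // 0
    }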
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 52a205be3e9f4..aa50b7a7634f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.sql.catalyst.types._
@@ -191,5 +193,48 @@ class ExpressionEvaluationSuite extends FunSuite {
       evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**")))
     }
   }
+
+  test("data type casting") {
+
+    val sts = "1970-01-01 00:00:01.0"
+    val ts = Timestamp.valueOf(sts)
+
+    checkEvaluation("abdef" cast StringType, "abdef")
+    checkEvaluation("abdef" cast DecimalType, null)
+    checkEvaluation("abdef" cast TimestampType, null)
+    checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65))
+
+    checkEvaluation(Literal(1) cast LongType, 1)
+    checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1)
+    checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
+    checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
+
+    checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts)
+    checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts)
+
+    checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef")
+
+    checkEvaluation(Cast(Cast(Cast(Cast(
+      Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
+    checkEvaluation(Cast(Cast(Cast(Cast(
+      Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 5)
+    checkEvaluation(Cast(Cast(Cast(Cast(
+      Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null)
+    checkEvaluation(Cast(Cast(Cast(Cast(
+      Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 5)
+    checkEvaluation(Literal(true) cast IntegerType, 1)
+    checkEvaluation(Literal(false) cast IntegerType, 0)
+    checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
+    checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0)
+
+    intercept[Exception] { evaluate(Literal(1) cast BinaryType, null) }
+  }
+
+  test("timestamp") {
+    val ts1 = new Timestamp(12)
+    val ts2 = new Timestamp(123)
+
+    checkEvaluation(Literal("ab") < Literal("abc"), true)
+    checkEvaluation(Literal(ts1) < Literal(ts2), true)
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index b2b03bc790fcc..bb4f157b7e3d6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -300,14 +300,17 @@ object HiveQl {
   }
 
   protected def nodeToDataType(node: Node): DataType = node match {
-    case Token("TOK_BIGINT", Nil) => IntegerType
+    case Token("TOK_DECIMAL", Nil) => DecimalType
+    case Token("TOK_BIGINT", Nil) => LongType
     case Token("TOK_INT", Nil) => IntegerType
-    case Token("TOK_TINYINT", Nil) => IntegerType
-    case Token("TOK_SMALLINT", Nil) => IntegerType
+    case Token("TOK_TINYINT", Nil) => ByteType
+    case Token("TOK_SMALLINT", Nil) => ShortType
     case Token("TOK_BOOLEAN", Nil) => BooleanType
     case Token("TOK_STRING", Nil) => StringType
     case Token("TOK_FLOAT", Nil) => FloatType
-    case Token("TOK_DOUBLE", Nil) => FloatType
+    case Token("TOK_DOUBLE", Nil) => DoubleType
+    case Token("TOK_TIMESTAMP", Nil) => TimestampType
+    case Token("TOK_BINARY", Nil) => BinaryType
     case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
     case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) =>
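Net effect of the HiveQl change: four pre-existing mappings are corrected (TOK_BIGINT, TOK_TINYINT, TOK_SMALLINT, TOK_DOUBLE previously collapsed to IntegerType or FloatType) and three new ones are added. A standalone summary of the primitive-token mapping (illustrative; the real code pattern matches on AST Token nodes):

    import org.apache.spark.sql.catalyst.types._

    object HiveTypeTokens extends App {
      val mapping: Map[String, DataType] = Map(
        "TOK_TINYINT"   -> ByteType,      // was IntegerType
        "TOK_SMALLINT"  -> ShortType,     // was IntegerType
        "TOK_INT"       -> IntegerType,
        "TOK_BIGINT"    -> LongType,      // was IntegerType
        "TOK_FLOAT"     -> FloatType,
        "TOK_DOUBLE"    -> DoubleType,    // was FloatType
        "TOK_DECIMAL"   -> DecimalType,   // new
        "TOK_BOOLEAN"   -> BooleanType,
        "TOK_STRING"    -> StringType,
        "TOK_TIMESTAMP" -> TimestampType, // new
        "TOK_BINARY"    -> BinaryType)    // new

      println(mapping("TOK_TIMESTAMP"))
    }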