Skip to content

Commit

Permalink
Add TimestampType Support
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Apr 3, 2014
1 parent 47ebea5 commit ce4385e
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp

import scala.language.implicitConversions

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
Expand Down Expand Up @@ -72,6 +74,7 @@ package object dsl {

def like(other: Expression) = Like(expr, other)
def rlike(other: Expression) = RLike(expr, other)
def cast(to: DataType) = Cast(expr, to)

def asc = SortOrder(expr, Ascending)
def desc = SortOrder(expr, Descending)
Expand All @@ -84,15 +87,22 @@ package object dsl {
def expr = e
}

implicit def booleanToLiteral(b: Boolean) = Literal(b)
implicit def byteToLiteral(b: Byte) = Literal(b)
implicit def shortToLiteral(s: Short) = Literal(s)
implicit def intToLiteral(i: Int) = Literal(i)
implicit def longToLiteral(l: Long) = Literal(l)
implicit def floatToLiteral(f: Float) = Literal(f)
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
implicit def bytesToLiteral(a: Array[Byte]) = Literal(a)

implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)

implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
def expr: Expression = Literal(s)
def attr = analysis.UnresolvedAttribute(s)
Expand All @@ -103,11 +113,38 @@ package object dsl {
def expr = attr
def attr = analysis.UnresolvedAttribute(s)

/** Creates a new typed attributes of type boolean */
def boolean = AttributeReference(s, BooleanType, nullable = false)()

/** Creates a new typed attributes of type byte */
def byte = AttributeReference(s, ByteType, nullable = false)()

/** Creates a new typed attributes of type short */
def short = AttributeReference(s, ShortType, nullable = false)()

/** Creates a new typed attributes of type int */
def int = AttributeReference(s, IntegerType, nullable = false)()

/** Creates a new typed attributes of type long */
def long = AttributeReference(s, LongType, nullable = false)()

/** Creates a new typed attributes of type float */
def float = AttributeReference(s, FloatType, nullable = false)()

/** Creates a new typed attributes of type double */
def double = AttributeReference(s, DoubleType, nullable = false)()

/** Creates a new typed attributes of type string */
def string = AttributeReference(s, StringType, nullable = false)()

/** Creates a new typed attributes of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = false)()

/** Creates a new typed attributes of type timestamp */
def timestamp = AttributeReference(s, TimestampType, nullable = false)()

/** Creates a new typed attributes of type binary */
def binary = AttributeReference(s, BinaryType, nullable = false)()
}

implicit class DslAttribute(a: AttributeReference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp
import java.lang.{NumberFormatException => NFE}

import org.apache.spark.sql.catalyst.types._

/** Cast the child expression to the target data type. */
Expand All @@ -26,52 +29,169 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
override def toString = s"CAST($child, $dataType)"

type EvaluatedType = Any

def nullOrCast[T, B](a: Any, func: T => B): B = if(a == null) {
null.asInstanceOf[B]
} else {
func(a.asInstanceOf[T])
}

// UDFToString
def castToString: Any => String = child.dataType match {
case BinaryType => nullOrCast[Array[Byte], String](_, new String(_, "UTF-8"))
case _ => nullOrCast[Any, String](_, _.toString)
}

// BinaryConverter
def castToBinary: Any => Array[Byte] = child.dataType match {
case StringType => nullOrCast[String, Array[Byte]](_, _.getBytes("UTF-8"))
}

lazy val castingFunction: Any => Any = (child.dataType, dataType) match {
case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]])
case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes
case (_, StringType) => a: Any => a.toString
case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt)
case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble)
case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat)
case (StringType, LongType) => a: Any => castOrNull(a, _.toLong)
case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort)
case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte)
case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_))
case (BooleanType, ByteType) => {
case null => null
case true => 1.toByte
case false => 0.toByte
}
case (dt, IntegerType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a)
case (dt, DoubleType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)
case (dt, FloatType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a)
case (dt, LongType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a)
case (dt, ShortType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort
case (dt, ByteType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte
case (dt, DecimalType) =>
a: Any =>
BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a))
// UDFToBoolean
def castToBoolean: Any => Boolean = child.dataType match {
case StringType => nullOrCast[String, Boolean](_, _.length() != 0)
case TimestampType => nullOrCast[Timestamp, Boolean](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
case LongType => nullOrCast[Long, Boolean](_, _ != 0)
case IntegerType => nullOrCast[Int, Boolean](_, _ != 0)
case ShortType => nullOrCast[Short, Boolean](_, _ != 0)
case ByteType => nullOrCast[Byte, Boolean](_, _ != 0)
case DecimalType => nullOrCast[BigDecimal, Boolean](_, _ != 0)
case DoubleType => nullOrCast[Double, Boolean](_, _ != 0)
case FloatType => nullOrCast[Float, Boolean](_, _ != 0)
}

// TimestampConverter
def castToTimestamp: Any => Timestamp = child.dataType match {
case StringType => nullOrCast[String, Timestamp](_, s => {
// Throw away extra if more than 9 decimal places
val periodIdx = s.indexOf(".");
var n = s
if (periodIdx != -1) {
if (n.length() - periodIdx > 9) {
n = n.substring(0, periodIdx + 10)
}
}
try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
})
case BooleanType => nullOrCast[Boolean, Timestamp](_, b => new Timestamp((if(b) 1 else 0) * 1000))
case LongType => nullOrCast[Long, Timestamp](_, l => new Timestamp(l * 1000))
case IntegerType => nullOrCast[Int, Timestamp](_, i => new Timestamp(i * 1000))
case ShortType => nullOrCast[Short, Timestamp](_, s => new Timestamp(s * 1000))
case ByteType => nullOrCast[Byte, Timestamp](_, b => new Timestamp(b * 1000))
// TimestampWritable.decimalToTimestamp
case DecimalType => nullOrCast[BigDecimal, Timestamp](_, d => decimalToTimestamp(d))
// TimestampWritable.doubleToTimestamp
case DoubleType => nullOrCast[Double, Timestamp](_, d => decimalToTimestamp(d))
// TimestampWritable.floatToTimestamp
case FloatType => nullOrCast[Float, Timestamp](_, f => decimalToTimestamp(f))
}

@inline
protected def castOrNull[A](a: Any, f: String => A) =
try f(a.asInstanceOf[String]) catch {
case _: java.lang.NumberFormatException => null
}
private def decimalToTimestamp(d: BigDecimal) = {
val seconds = d.longValue()
val bd = (d - seconds) * (1000000000)
val nanos = bd.intValue()

// Convert to millis
val millis = seconds * 1000
val t = new Timestamp(millis)

// remaining fractional portion as nanos
t.setNanos(nanos)

t
}

private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000)

def castToLong: Any => Long = child.dataType match {
case StringType => nullOrCast[String, Long](_, s => try s.toLong catch {
case _: NFE => null.asInstanceOf[Long]
})
case BooleanType => nullOrCast[Boolean, Long](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Long](_, t => timestampToDouble(t).toLong)
case DecimalType => nullOrCast[BigDecimal, Long](_, _.toLong)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}

def castToInt: Any => Int = child.dataType match {
case StringType => nullOrCast[String, Int](_, s => try s.toInt catch {
case _: NFE => null.asInstanceOf[Int]
})
case BooleanType => nullOrCast[Boolean, Int](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Int](_, t => timestampToDouble(t).toInt)
case DecimalType => nullOrCast[BigDecimal, Int](_, _.toInt)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}

def castToShort: Any => Short = child.dataType match {
case StringType => nullOrCast[String, Short](_, s => try s.toShort catch {
case _: NFE => null.asInstanceOf[Short]
})
case BooleanType => nullOrCast[Boolean, Short](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Short](_, t => timestampToDouble(t).toShort)
case DecimalType => nullOrCast[BigDecimal, Short](_, _.toShort)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}

def castToByte: Any => Byte = child.dataType match {
case StringType => nullOrCast[String, Byte](_, s => try s.toByte catch {
case _: NFE => null.asInstanceOf[Byte]
})
case BooleanType => nullOrCast[Boolean, Byte](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Byte](_, t => timestampToDouble(t).toByte)
case DecimalType => nullOrCast[BigDecimal, Byte](_, _.toByte)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}

def castToDecimal: Any => BigDecimal = child.dataType match {
case StringType => nullOrCast[String, BigDecimal](_, s => try s.toDouble catch {
case _: NFE => null.asInstanceOf[BigDecimal]
})
case BooleanType => nullOrCast[Boolean, BigDecimal](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, BigDecimal](_, t => timestampToDouble(t))
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}

def castToDouble: Any => Double = child.dataType match {
case StringType => nullOrCast[String, Double](_, s => try s.toInt catch {
case _: NFE => null.asInstanceOf[Int]
})
case BooleanType => nullOrCast[Boolean, Double](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Double](_, t => timestampToDouble(t))
case DecimalType => nullOrCast[BigDecimal, Double](_, _.toDouble)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}

def castToFloat: Any => Float = child.dataType match {
case StringType => nullOrCast[String, Float](_, s => try s.toInt catch {
case _: NFE => null.asInstanceOf[Int]
})
case BooleanType => nullOrCast[Boolean, Float](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Float](_, t => timestampToDouble(t).toFloat)
case DecimalType => nullOrCast[BigDecimal, Float](_, _.toFloat)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
}

def cast: Any => Any = (child.dataType, dataType) match {
case (_, StringType) => castToString
case (_, BinaryType) => castToBinary
case (_, DecimalType) => castToDecimal
case (_, TimestampType) => castToTimestamp
case (_, BooleanType) => castToBoolean
case (_, ByteType) => castToByte
case (_, ShortType) => castToShort
case (_, IntegerType) => castToInt
case (_, FloatType) => castToFloat
case (_, LongType) => castToLong
case (_, DoubleType) => castToDouble
}

override def apply(input: Row): Any = {
val evaluated = child.apply(input)
if (evaluated == null) {
null
} else {
castingFunction(evaluated)
if(child.dataType == dataType) evaluated else cast(evaluated)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType}
import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampType, NativeType}

abstract class Expression extends TreeNode[Expression] {
self: Product =>
Expand Down Expand Up @@ -170,6 +171,34 @@ abstract class Expression extends TreeNode[Expression] {
}
}
}

@inline
protected final def c2(
i: Row,
e1: Expression,
e2: Expression,
f: ((Ordering[Any], Any, Any) => Any)): Any = {
if (e1.dataType != e2.dataType) {
throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
}

val evalE1 = e1.apply(i)
if(evalE1 == null) {
null
} else {
val evalE2 = e2.apply(i)
if (evalE2 == null) {
null
} else {
e1.dataType match {
case i: NativeType =>
f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean](
i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
case other => sys.error(s"Type $other does not support ordered operations")
}
}
}
}
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.types._

object Literal {
Expand All @@ -29,6 +31,9 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(s, StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(d, DecimalType)
case t: Timestamp => Literal(t, TimestampType)
case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
}
}
Expand Down
Loading

0 comments on commit ce4385e

Please sign in to comment.