Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-1360] Add Timestamp Support for SQL #275

Closed
wants to merge 12 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
Expand Down Expand Up @@ -54,14 +56,15 @@ object ScalaReflection {
val TypeRef(_, _, Seq(keyType, valueType)) = t
MapType(schemaFor(keyType), schemaFor(valueType))
case t if t <:< typeOf[String] => StringType
case t if t <:< typeOf[Timestamp] => TimestampType
case t if t <:< typeOf[BigDecimal] => DecimalType
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
case t if t <:< definitions.FloatTpe => FloatType
case t if t <:< definitions.DoubleTpe => DoubleType
case t if t <:< definitions.FloatTpe => FloatType
case t if t <:< definitions.ShortTpe => ShortType
case t if t <:< definitions.ByteTpe => ByteType
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< typeOf[BigDecimal] => DecimalType
}

implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp

import scala.language.implicitConversions

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
Expand Down Expand Up @@ -72,6 +74,7 @@ package object dsl {

def like(other: Expression) = Like(expr, other)
def rlike(other: Expression) = RLike(expr, other)
def cast(to: DataType) = Cast(expr, to)

def asc = SortOrder(expr, Ascending)
def desc = SortOrder(expr, Descending)
Expand All @@ -84,15 +87,22 @@ package object dsl {
def expr = e
}

// Implicit lifts from plain Scala/Java values to Literal expressions.
implicit def booleanToLiteral(b: Boolean): Literal = Literal(b)
implicit def byteToLiteral(b: Byte): Literal = Literal(b)
implicit def shortToLiteral(s: Short): Literal = Literal(s)
implicit def intToLiteral(i: Int): Literal = Literal(i)
implicit def longToLiteral(l: Long): Literal = Literal(l)
implicit def floatToLiteral(f: Float): Literal = Literal(f)
implicit def doubleToLiteral(d: Double): Literal = Literal(d)
implicit def stringToLiteral(s: String): Literal = Literal(s)
implicit def decimalToLiteral(d: BigDecimal): Literal = Literal(d)
implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a)

/** Lets a Scala symbol stand in for an unresolved attribute carrying the symbol's name. */
implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute =
  analysis.UnresolvedAttribute(s.name)

/** Exposes the ImplicitAttribute helpers on symbols, using the symbol's name as `s`. */
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute {
  def s: String = sym.name
}
// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
def expr: Expression = Literal(s)
def attr = analysis.UnresolvedAttribute(s)
Expand All @@ -103,11 +113,38 @@ package object dsl {
def expr = attr
def attr = analysis.UnresolvedAttribute(s)

/** Creates a new AttributeReference of type boolean (declared with nullable = false) */
def boolean = AttributeReference(s, BooleanType, nullable = false)()

/** Creates a new AttributeReference of type byte (declared with nullable = false) */
def byte = AttributeReference(s, ByteType, nullable = false)()

/** Creates a new AttributeReference of type short (declared with nullable = false) */
def short = AttributeReference(s, ShortType, nullable = false)()

/** Creates a new AttributeReference of type int (declared with nullable = false) */
def int = AttributeReference(s, IntegerType, nullable = false)()

/** Creates a new AttributeReference of type long (declared with nullable = false) */
def long = AttributeReference(s, LongType, nullable = false)()

/** Creates a new AttributeReference of type float (declared with nullable = false) */
def float = AttributeReference(s, FloatType, nullable = false)()

/** Creates a new AttributeReference of type double (declared with nullable = false) */
def double = AttributeReference(s, DoubleType, nullable = false)()

/** Creates a new AttributeReference of type string (declared with nullable = false) */
def string = AttributeReference(s, StringType, nullable = false)()

/** Creates a new AttributeReference of type decimal (declared with nullable = false) */
def decimal = AttributeReference(s, DecimalType, nullable = false)()

/** Creates a new AttributeReference of type timestamp (declared with nullable = false) */
def timestamp = AttributeReference(s, TimestampType, nullable = false)()

/** Creates a new AttributeReference of type binary (declared with nullable = false) */
def binary = AttributeReference(s, BinaryType, nullable = false)()
}

implicit class DslAttribute(a: AttributeReference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.types._

/** Cast the child expression to the target data type. */
Expand All @@ -26,52 +28,169 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
override def toString = s"CAST($child, $dataType)"

type EvaluatedType = Any

/**
 * Applies `func` to `a` viewed as a `T`; a null input is passed straight through as null
 * without invoking `func`.
 */
def nullOrCast[T](a: Any, func: T => Any): Any = a match {
  case null => null
  case value => func(value.asInstanceOf[T])
}

lazy val castingFunction: Any => Any = (child.dataType, dataType) match {
case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]])
case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes
case (_, StringType) => a: Any => a.toString
case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt)
case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble)
case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat)
case (StringType, LongType) => a: Any => castOrNull(a, _.toLong)
case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort)
case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte)
case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_))
case (BooleanType, ByteType) => {
case null => null
case true => 1.toByte
case false => 0.toByte
}
case (dt, IntegerType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a)
case (dt, DoubleType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)
case (dt, FloatType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a)
case (dt, LongType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a)
case (dt, ShortType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort
case (dt, ByteType) =>
a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte
case (dt, DecimalType) =>
a: Any =>
BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a))
// UDFToString
/** Conversion to String: binary input is decoded as UTF-8, anything else uses toString. */
def castToString: Any => Any = child.dataType match {
  case BinaryType =>
    (value: Any) =>
      nullOrCast[Array[Byte]](value, (bytes: Array[Byte]) => new String(bytes, "UTF-8"))
  case _ =>
    (value: Any) => nullOrCast[Any](value, (v: Any) => v.toString)
}

// BinaryConverter
/** Conversion to binary: only String input is handled, encoded as UTF-8 bytes. */
def castToBinary: Any => Any = child.dataType match {
  case StringType =>
    (value: Any) => nullOrCast[String](value, (text: String) => text.getBytes("UTF-8"))
}

/**
 * Runs the string parser `f` on `a` (which must be a String) and yields null when the
 * parser rejects the text with a NumberFormatException.
 */
@inline
protected def castOrNull[A](a: Any, f: String => A) = {
  try {
    val text = a.asInstanceOf[String]
    f(text)
  } catch {
    case _: java.lang.NumberFormatException => null
  }
}
// UDFToBoolean
/**
 * Conversion to Boolean: strings are true when non-empty, timestamps when either the
 * millisecond time or the nanos field is non-zero, and numbers when they differ from zero.
 */
def castToBoolean: Any => Any = child.dataType match {
  case StringType =>
    (value: Any) => nullOrCast[String](value, (text: String) => text.length() != 0)
  case TimestampType =>
    (value: Any) =>
      nullOrCast[Timestamp](value, (ts: Timestamp) => !(ts.getTime() == 0 && ts.getNanos() == 0))
  case LongType =>
    (value: Any) => nullOrCast[Long](value, (n: Long) => n != 0)
  case IntegerType =>
    (value: Any) => nullOrCast[Int](value, (n: Int) => n != 0)
  case ShortType =>
    (value: Any) => nullOrCast[Short](value, (n: Short) => n != 0)
  case ByteType =>
    (value: Any) => nullOrCast[Byte](value, (n: Byte) => n != 0)
  case DecimalType =>
    (value: Any) => nullOrCast[BigDecimal](value, (n: BigDecimal) => n != 0)
  case DoubleType =>
    (value: Any) => nullOrCast[Double](value, (n: Double) => n != 0)
  case FloatType =>
    (value: Any) => nullOrCast[Float](value, (n: Float) => n != 0)
}

// TimestampConverter
/**
 * Conversion to Timestamp. Strings are parsed with `Timestamp.valueOf` (null when the text
 * is rejected); integral and boolean inputs are taken as seconds; decimal, double and float
 * inputs are taken as fractional seconds.
 */
def castToTimestamp: Any => Any = child.dataType match {
  case StringType => nullOrCast[String](_, s => {
    // Throw away extra if more than 9 decimal places
    val periodIdx = s.indexOf(".")
    val n =
      if (periodIdx != -1 && s.length() - periodIdx > 9) s.substring(0, periodIdx + 10) else s
    try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
  })
  case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000))
  case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
  // Scale with a Long factor: Int/Short/Byte operands would otherwise multiply in 32-bit
  // arithmetic and wrap around before being widened for the Timestamp constructor.
  case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000L))
  case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000L))
  case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000L))
  // TimestampWritable.decimalToTimestamp
  case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
  // TimestampWritable.doubleToTimestamp
  case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
  // TimestampWritable.floatToTimestamp
  case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
}

/**
 * Builds a Timestamp from a number of seconds since the epoch given as a decimal, keeping
 * the fractional part as nanoseconds.
 *
 * `Timestamp.setNanos` only accepts values in [0, 999999999], so for negative inputs with a
 * fractional part one second is borrowed to keep the nanos field non-negative.
 */
private def decimalToTimestamp(d: BigDecimal): Timestamp = {
  val nanosPerSecond = 1000000000
  // longValue truncates toward zero, so the fraction carries the sign of d.
  val truncatedSeconds = d.longValue()
  val fractionNanos = ((d - truncatedSeconds) * nanosPerSecond).intValue()

  val (seconds, nanos) =
    if (fractionNanos < 0) (truncatedSeconds - 1, fractionNanos + nanosPerSecond)
    else (truncatedSeconds, fractionNanos)

  // Whole seconds as millis, then the remaining fractional portion as nanos
  val t = new Timestamp(seconds * 1000)
  t.setNanos(nanos)
  t
}

/**
 * Converts a Timestamp to fractional seconds since the epoch.
 *
 * The whole seconds come from `getTime` (milliseconds since the epoch), not from the
 * deprecated `getSeconds`, which only reports the seconds-of-minute field. The fractional
 * part comes from `getNanos`, scaled by one billion nanoseconds per second.
 */
private def timestampToDouble(t: Timestamp) = {
  val millis = t.getTime
  // Floor rather than truncate: getNanos is never negative, so pre-epoch instants must
  // round their whole-second part downwards for the sum to be correct.
  val wholeSeconds = if (millis >= 0) millis / 1000 else (millis - 999) / 1000
  wholeSeconds + t.getNanos.toDouble / 1000000000
}

/**
 * Conversion to Long. Unparseable strings become null, booleans map to 1/0, timestamps to
 * their whole seconds, and other numeric types go through their `Numeric` instance.
 */
def castToLong: Any => Any = child.dataType match {
  case StringType => nullOrCast[String](_, s => try s.toLong catch {
    case _: NumberFormatException => null
  })
  // Long literals: the function result is typed Any, so plain 1/0 would be boxed as Integer.
  case BooleanType => nullOrCast[Boolean](_, b => if (b) 1L else 0L)
  case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong)
  case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
  case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}

/**
 * Conversion to Int. Unparseable strings become null, booleans map to 1/0, timestamps to
 * their whole seconds, and other numeric types go through their `Numeric` instance.
 */
def castToInt: Any => Any = child.dataType match {
  case StringType =>
    (value: Any) =>
      nullOrCast[String](value, (text: String) =>
        try text.toInt catch { case _: NumberFormatException => null })
  case BooleanType =>
    (value: Any) => nullOrCast[Boolean](value, (flag: Boolean) => if (flag) 1 else 0)
  case TimestampType =>
    (value: Any) => nullOrCast[Timestamp](value, (ts: Timestamp) => timestampToDouble(ts).toInt)
  case DecimalType =>
    (value: Any) => nullOrCast[BigDecimal](value, (dec: BigDecimal) => dec.toInt)
  case numericType: NumericType =>
    (value: Any) => numericType.numeric.asInstanceOf[Numeric[Any]].toInt(value)
}

/**
 * Conversion to Short. Unparseable strings become null, booleans map to 1/0, timestamps to
 * their whole seconds, and other numeric types are narrowed from their Int value.
 */
def castToShort: Any => Any = child.dataType match {
  case StringType => nullOrCast[String](_, s => try s.toShort catch {
    case _: NumberFormatException => null
  })
  // Explicit Short values: the function result is typed Any, so plain 1/0 would be boxed
  // as Integer.
  case BooleanType => nullOrCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
  case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort)
  case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
  case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}

/**
 * Conversion to Byte. Unparseable strings become null, booleans map to 1/0, timestamps to
 * their whole seconds, and other numeric types are narrowed from their Int value.
 */
def castToByte: Any => Any = child.dataType match {
  case StringType => nullOrCast[String](_, s => try s.toByte catch {
    case _: NumberFormatException => null
  })
  // Explicit Byte values: the function result is typed Any, so plain 1/0 would be boxed
  // as Integer.
  case BooleanType => nullOrCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
  case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte)
  case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
  case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}

/**
 * Conversion to BigDecimal. Strings are parsed as doubles first (null when unparseable),
 * booleans map to 1/0, timestamps to their fractional seconds, and other numeric types go
 * through their double value.
 */
def castToDecimal: Any => Any = child.dataType match {
  case StringType =>
    (value: Any) =>
      nullOrCast[String](value, (text: String) =>
        try BigDecimal(text.toDouble) catch { case _: NumberFormatException => null })
  case BooleanType =>
    (value: Any) =>
      nullOrCast[Boolean](value, (flag: Boolean) => if (flag) BigDecimal(1) else BigDecimal(0))
  case TimestampType =>
    (value: Any) =>
      nullOrCast[Timestamp](value, (ts: Timestamp) => BigDecimal(timestampToDouble(ts)))
  case numericType: NumericType =>
    (value: Any) => BigDecimal(numericType.numeric.asInstanceOf[Numeric[Any]].toDouble(value))
}

/**
 * Conversion to Double. Unparseable strings become null, booleans map to 1.0/0.0,
 * timestamps to their fractional seconds, and other numeric types go through their
 * `Numeric` instance.
 */
def castToDouble: Any => Any = child.dataType match {
  case StringType => nullOrCast[String](_, s => try s.toDouble catch {
    case _: NumberFormatException => null
  })
  // Double literals: the function result is typed Any, so plain 1/0 would be boxed as Integer.
  case BooleanType => nullOrCast[Boolean](_, b => if (b) 1.0 else 0.0)
  case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
  case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
  case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}

/**
 * Conversion to Float. Unparseable strings become null, booleans map to 1.0f/0.0f,
 * timestamps to their fractional seconds, and other numeric types go through their
 * `Numeric` instance.
 */
def castToFloat: Any => Any = child.dataType match {
  case StringType => nullOrCast[String](_, s => try s.toFloat catch {
    case _: NumberFormatException => null
  })
  // Float literals: the function result is typed Any, so plain 1/0 would be boxed as Integer.
  case BooleanType => nullOrCast[Boolean](_, b => if (b) 1.0f else 0.0f)
  case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
  case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
  case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
}

/**
 * The conversion function for this expression, selected by the target `dataType`.
 *
 * Held in a lazy val so the dispatch on the target and child types runs once per
 * expression instance instead of on every evaluated row.
 */
lazy val cast: Any => Any = dataType match {
  case StringType => castToString
  case BinaryType => castToBinary
  case DecimalType => castToDecimal
  case TimestampType => castToTimestamp
  case BooleanType => castToBoolean
  case ByteType => castToByte
  case ShortType => castToShort
  case IntegerType => castToInt
  case FloatType => castToFloat
  case LongType => castToLong
  case DoubleType => castToDouble
}

/**
 * Evaluates the child expression against `input` and converts the result to `dataType`.
 * A null child value yields null without invoking the conversion.
 */
override def apply(input: Row): Any = {
  val evaluated = child.apply(input)
  if (evaluated == null) {
    null
  } else {
    // Single conversion path; no other casting function is invoked for its side effects.
    cast(evaluated)
  }
}
}
Loading