Skip to content

Commit

Permalink
Fix bug mapping to 0 (which is supposed to be null) when NumberFormatException occurs
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Apr 3, 2014
1 parent 9cb505c commit 54a0489
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,39 +29,39 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {

type EvaluatedType = Any

def nullOrCast[T, B](a: Any, func: T => B): B = if(a == null) {
null.asInstanceOf[B]
def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
null
} else {
func(a.asInstanceOf[T])
}

// UDFToString
def castToString: Any => String = child.dataType match {
case BinaryType => nullOrCast[Array[Byte], String](_, new String(_, "UTF-8"))
case _ => nullOrCast[Any, String](_, _.toString)
def castToString: Any => Any = child.dataType match {
case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
case _ => nullOrCast[Any](_, _.toString)
}

// BinaryConverter
def castToBinary: Any => Array[Byte] = child.dataType match {
case StringType => nullOrCast[String, Array[Byte]](_, _.getBytes("UTF-8"))
def castToBinary: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
}

// UDFToBoolean
def castToBoolean: Any => Boolean = child.dataType match {
case StringType => nullOrCast[String, Boolean](_, _.length() != 0)
case TimestampType => nullOrCast[Timestamp, Boolean](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
case LongType => nullOrCast[Long, Boolean](_, _ != 0)
case IntegerType => nullOrCast[Int, Boolean](_, _ != 0)
case ShortType => nullOrCast[Short, Boolean](_, _ != 0)
case ByteType => nullOrCast[Byte, Boolean](_, _ != 0)
case DecimalType => nullOrCast[BigDecimal, Boolean](_, _ != 0)
case DoubleType => nullOrCast[Double, Boolean](_, _ != 0)
case FloatType => nullOrCast[Float, Boolean](_, _ != 0)
def castToBoolean: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, _.length() != 0)
case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
case LongType => nullOrCast[Long](_, _ != 0)
case IntegerType => nullOrCast[Int](_, _ != 0)
case ShortType => nullOrCast[Short](_, _ != 0)
case ByteType => nullOrCast[Byte](_, _ != 0)
case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
case DoubleType => nullOrCast[Double](_, _ != 0)
case FloatType => nullOrCast[Float](_, _ != 0)
}

// TimestampConverter
def castToTimestamp: Any => Timestamp = child.dataType match {
case StringType => nullOrCast[String, Timestamp](_, s => {
def castToTimestamp: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, s => {
// Throw away extra if more than 9 decimal places
val periodIdx = s.indexOf(".");
var n = s
Expand All @@ -72,17 +72,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
})
case BooleanType => nullOrCast[Boolean, Timestamp](_, b => new Timestamp((if(b) 1 else 0) * 1000))
case LongType => nullOrCast[Long, Timestamp](_, l => new Timestamp(l * 1000))
case IntegerType => nullOrCast[Int, Timestamp](_, i => new Timestamp(i * 1000))
case ShortType => nullOrCast[Short, Timestamp](_, s => new Timestamp(s * 1000))
case ByteType => nullOrCast[Byte, Timestamp](_, b => new Timestamp(b * 1000))
case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
// TimestampWritable.decimalToTimestamp
case DecimalType => nullOrCast[BigDecimal, Timestamp](_, d => decimalToTimestamp(d))
case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
// TimestampWritable.doubleToTimestamp
case DoubleType => nullOrCast[Double, Timestamp](_, d => decimalToTimestamp(d))
case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
// TimestampWritable.floatToTimestamp
case FloatType => nullOrCast[Float, Timestamp](_, f => decimalToTimestamp(f))
case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
}

private def decimalToTimestamp(d: BigDecimal) = {
Expand All @@ -102,87 +102,87 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {

/**
 * Converts a timestamp to seconds since the epoch, with the sub-second part as the fraction.
 *
 * `Timestamp.getSeconds` is the deprecated `java.util.Date` accessor that returns the
 * seconds-within-the-minute field (0-59), and dividing nanoseconds by 1000 yields
 * microseconds, so neither is suitable here.
 *
 * @param t the timestamp to convert; must not be null
 * @return whole epoch seconds plus `getNanos` expressed as a fraction of a second
 */
private def timestampToDouble(t: Timestamp): Double = {
  // getTime already includes the millisecond part of getNanos; strip it first so the
  // remainder is an exact multiple of 1000 and the division is exact for pre-epoch values too.
  val wholeSeconds = (t.getTime - t.getNanos / 1000000) / 1000
  wholeSeconds + t.getNanos.toDouble / 1000000000
}

def castToLong: Any => Long = child.dataType match {
case StringType => nullOrCast[String, Long](_, s => try s.toLong catch {
case _: NumberFormatException => null.asInstanceOf[Long]
def castToLong: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, s => try s.toLong catch {
case _: NumberFormatException => null
})
case BooleanType => nullOrCast[Boolean, Long](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Long](_, t => timestampToDouble(t).toLong)
case DecimalType => nullOrCast[BigDecimal, Long](_, _.toLong)
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong)
case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}

def castToInt: Any => Int = child.dataType match {
case StringType => nullOrCast[String, Int](_, s => try s.toInt catch {
case _: NumberFormatException => null.asInstanceOf[Int]
def castToInt: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, s => try s.toInt catch {
case _: NumberFormatException => null
})
case BooleanType => nullOrCast[Boolean, Int](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Int](_, t => timestampToDouble(t).toInt)
case DecimalType => nullOrCast[BigDecimal, Int](_, _.toInt)
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt)
case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}

def castToShort: Any => Short = child.dataType match {
case StringType => nullOrCast[String, Short](_, s => try s.toShort catch {
case _: NumberFormatException => null.asInstanceOf[Short]
def castToShort: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, s => try s.toShort catch {
case _: NumberFormatException => null
})
case BooleanType => nullOrCast[Boolean, Short](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Short](_, t => timestampToDouble(t).toShort)
case DecimalType => nullOrCast[BigDecimal, Short](_, _.toShort)
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort)
case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}

def castToByte: Any => Byte = child.dataType match {
case StringType => nullOrCast[String, Byte](_, s => try s.toByte catch {
case _: NumberFormatException => null.asInstanceOf[Byte]
def castToByte: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, s => try s.toByte catch {
case _: NumberFormatException => null
})
case BooleanType => nullOrCast[Boolean, Byte](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Byte](_, t => timestampToDouble(t).toByte)
case DecimalType => nullOrCast[BigDecimal, Byte](_, _.toByte)
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte)
case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}

def castToDecimal: Any => BigDecimal = child.dataType match {
case StringType => nullOrCast[String, BigDecimal](_, s => try s.toDouble catch {
case _: NumberFormatException => null.asInstanceOf[BigDecimal]
def castToDecimal: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch {
case _: NumberFormatException => null
})
case BooleanType => nullOrCast[Boolean, BigDecimal](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, BigDecimal](_, t => timestampToDouble(t))
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
}

def castToDouble: Any => Double = child.dataType match {
case StringType => nullOrCast[String, Double](_, s => try s.toDouble catch {
case _: NumberFormatException => null.asInstanceOf[Int]
def castToDouble: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, s => try s.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType => nullOrCast[Boolean, Double](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Double](_, t => timestampToDouble(t))
case DecimalType => nullOrCast[BigDecimal, Double](_, _.toDouble)
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}

def castToFloat: Any => Float = child.dataType match {
case StringType => nullOrCast[String, Float](_, s => try s.toFloat catch {
case _: NumberFormatException => null.asInstanceOf[Int]
def castToFloat: Any => Any = child.dataType match {
case StringType => nullOrCast[String](_, s => try s.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType => nullOrCast[Boolean, Float](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp, Float](_, t => timestampToDouble(t).toFloat)
case DecimalType => nullOrCast[BigDecimal, Float](_, _.toFloat)
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
}

def cast: Any => Any = (child.dataType, dataType) match {
case (_, StringType) => castToString
case (_, BinaryType) => castToBinary
case (_, DecimalType) => castToDecimal
case (_, TimestampType) => castToTimestamp
case (_, BooleanType) => castToBoolean
case (_, ByteType) => castToByte
case (_, ShortType) => castToShort
case (_, IntegerType) => castToInt
case (_, FloatType) => castToFloat
case (_, LongType) => castToLong
case (_, DoubleType) => castToDouble
def cast: Any => Any = dataType match {
case StringType => castToString
case BinaryType => castToBinary
case DecimalType => castToDecimal
case TimestampType => castToTimestamp
case BooleanType => castToBoolean
case ByteType => castToByte
case ShortType => castToShort
case IntegerType => castToInt
case FloatType => castToFloat
case LongType => castToLong
case DoubleType => castToDouble
}

override def apply(input: Row): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("23" cast DecimalType, 23)
checkEvaluation("23" cast ByteType, 23)
checkEvaluation("23" cast ShortType, 23)
checkEvaluation("2012-12-11" cast DoubleType, null)
checkEvaluation(Literal(123) cast IntegerType, 123)

intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
}
Expand Down

0 comments on commit 54a0489

Please sign in to comment.