From 54a048945e2bb70f33f635723da0e0c8308ac4db Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 1 Apr 2014 16:21:39 +0800 Subject: [PATCH] fix bug mapping to 0 (which is supposed to be null) when NumberFormatException occurs --- .../spark/sql/catalyst/expressions/Cast.scala | 164 +++++++++--------- .../ExpressionEvaluationSuite.scala | 2 + 2 files changed, 84 insertions(+), 82 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f5380751c0840..d03dec60bbaba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -29,39 +29,39 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { type EvaluatedType = Any - def nullOrCast[T, B](a: Any, func: T => B): B = if(a == null) { - null.asInstanceOf[B] + def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) { + null } else { func(a.asInstanceOf[T]) } - + // UDFToString - def castToString: Any => String = child.dataType match { - case BinaryType => nullOrCast[Array[Byte], String](_, new String(_, "UTF-8")) - case _ => nullOrCast[Any, String](_, _.toString) + def castToString: Any => Any = child.dataType match { + case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8")) + case _ => nullOrCast[Any](_, _.toString) } // BinaryConverter - def castToBinary: Any => Array[Byte] = child.dataType match { - case StringType => nullOrCast[String, Array[Byte]](_, _.getBytes("UTF-8")) + def castToBinary: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, _.getBytes("UTF-8")) } // UDFToBoolean - def castToBoolean: Any => Boolean = child.dataType match { - case StringType => nullOrCast[String, Boolean](_, _.length() != 0) - case TimestampType => nullOrCast[Timestamp, Boolean](_, b => {(b.getTime() != 0 || b.getNanos() != 0)}) - case LongType => nullOrCast[Long, Boolean](_, _ != 0) - case IntegerType => nullOrCast[Int, Boolean](_, _ != 0) - case ShortType => nullOrCast[Short, Boolean](_, _ != 0) - case ByteType => nullOrCast[Byte, Boolean](_, _ != 0) - case DecimalType => nullOrCast[BigDecimal, Boolean](_, _ != 0) - case DoubleType => nullOrCast[Double, Boolean](_, _ != 0) - case FloatType => nullOrCast[Float, Boolean](_, _ != 0) + def castToBoolean: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, _.length() != 0) + case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)}) + case LongType => nullOrCast[Long](_, _ != 0) + case IntegerType => nullOrCast[Int](_, _ != 0) + case ShortType => nullOrCast[Short](_, _ != 0) + case ByteType => nullOrCast[Byte](_, _ != 0) + case DecimalType => nullOrCast[BigDecimal](_, _ != 0) + case DoubleType => nullOrCast[Double](_, _ != 0) + case FloatType => nullOrCast[Float](_, _ != 0) } // TimestampConverter - def castToTimestamp: Any => Timestamp = child.dataType match { - case StringType => nullOrCast[String, Timestamp](_, s => { + def castToTimestamp: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => { // Throw away extra if more than 9 decimal places val periodIdx = s.indexOf("."); var n = s @@ -72,17 +72,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null} }) - case BooleanType => nullOrCast[Boolean, Timestamp](_, b => new Timestamp((if(b) 1 else 0) * 1000)) - case LongType => nullOrCast[Long, Timestamp](_, l => new Timestamp(l * 1000)) - case IntegerType => nullOrCast[Int, Timestamp](_, i => new Timestamp(i * 1000)) - case ShortType => nullOrCast[Short, Timestamp](_, s => new Timestamp(s * 1000)) - case ByteType => nullOrCast[Byte, Timestamp](_, b => new Timestamp(b * 1000)) + case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000)) + case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000)) + case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000)) + case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000)) + case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000)) // TimestampWritable.decimalToTimestamp - case DecimalType => nullOrCast[BigDecimal, Timestamp](_, d => decimalToTimestamp(d)) + case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp - case DoubleType => nullOrCast[Double, Timestamp](_, d => decimalToTimestamp(d)) + case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d)) // TimestampWritable.floatToTimestamp - case FloatType => nullOrCast[Float, Timestamp](_, f => decimalToTimestamp(f)) + case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f)) } private def decimalToTimestamp(d: BigDecimal) = { @@ -102,87 +102,87 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000) - def castToLong: Any => Long = child.dataType match { - case StringType => nullOrCast[String, Long](_, s => try s.toLong catch { - case _: NumberFormatException => null.asInstanceOf[Long] + def castToLong: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toLong catch { + case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean, Long](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp, Long](_, t => timestampToDouble(t).toLong) - case DecimalType => nullOrCast[BigDecimal, Long](_, _.toLong) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong) + case DecimalType => nullOrCast[BigDecimal](_, _.toLong) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } - def castToInt: Any => Int = child.dataType match { - case StringType => nullOrCast[String, Int](_, s => try s.toInt catch { - case _: NumberFormatException => null.asInstanceOf[Int] + def castToInt: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toInt catch { + case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean, Int](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp, Int](_, t => timestampToDouble(t).toInt) - case DecimalType => nullOrCast[BigDecimal, Int](_, _.toInt) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt) + case DecimalType => nullOrCast[BigDecimal](_, _.toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } - def castToShort: Any => Short = child.dataType match { - case StringType => nullOrCast[String, Short](_, s => try s.toShort catch { - case _: NumberFormatException => null.asInstanceOf[Short] + def castToShort: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toShort catch { + case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean, Short](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp, Short](_, t => timestampToDouble(t).toShort) - case DecimalType => nullOrCast[BigDecimal, Short](_, _.toShort) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort) + case DecimalType => nullOrCast[BigDecimal](_, _.toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } - def castToByte: Any => Byte = child.dataType match { - case StringType => nullOrCast[String, Byte](_, s => try s.toByte catch { - case _: NumberFormatException => null.asInstanceOf[Byte] + def castToByte: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toByte catch { + case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean, Byte](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp, Byte](_, t => timestampToDouble(t).toByte) - case DecimalType => nullOrCast[BigDecimal, Byte](_, _.toByte) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte) + case DecimalType => nullOrCast[BigDecimal](_, _.toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } - def castToDecimal: Any => BigDecimal = child.dataType match { - case StringType => nullOrCast[String, BigDecimal](_, s => try s.toDouble catch { - case _: NumberFormatException => null.asInstanceOf[BigDecimal] + def castToDecimal: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch { + case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean, BigDecimal](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp, BigDecimal](_, t => timestampToDouble(t)) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) + case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0)) + case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) + case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) } - def castToDouble: Any => Double = child.dataType match { - case StringType => nullOrCast[String, Double](_, s => try s.toDouble catch { - case _: NumberFormatException => null.asInstanceOf[Int] + def castToDouble: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toDouble catch { + case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean, Double](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp, Double](_, t => timestampToDouble(t)) - case DecimalType => nullOrCast[BigDecimal, Double](_, _.toDouble) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t)) + case DecimalType => nullOrCast[BigDecimal](_, _.toDouble) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } - def castToFloat: Any => Float = child.dataType match { - case StringType => nullOrCast[String, Float](_, s => try s.toFloat catch { - case _: NumberFormatException => null.asInstanceOf[Int] + def castToFloat: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toFloat catch { + case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean, Float](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp, Float](_, t => timestampToDouble(t).toFloat) - case DecimalType => nullOrCast[BigDecimal, Float](_, _.toFloat) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat) + case DecimalType => nullOrCast[BigDecimal](_, _.toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - def cast: Any => Any = (child.dataType, dataType) match { - case (_, StringType) => castToString - case (_, BinaryType) => castToBinary - case (_, DecimalType) => castToDecimal - case (_, TimestampType) => castToTimestamp - case (_, BooleanType) => castToBoolean - case (_, ByteType) => castToByte - case (_, ShortType) => castToShort - case (_, IntegerType) => castToInt - case (_, FloatType) => castToFloat - case (_, LongType) => castToLong - case (_, DoubleType) => castToDouble + def cast: Any => Any = dataType match { + case StringType => castToString + case BinaryType => castToBinary + case DecimalType => castToDecimal + case TimestampType => castToTimestamp + case BooleanType => castToBoolean + case ByteType => castToByte + case ShortType => castToShort + case IntegerType => castToInt + case FloatType => castToFloat + case LongType => castToLong + case DoubleType => castToDouble } override def apply(input: Row): Any = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 0e57f03a30a28..43876033d327b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -232,6 +232,8 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("23" cast DecimalType, 23) checkEvaluation("23" cast ByteType, 23) checkEvaluation("23" cast ShortType, 23) + checkEvaluation("2012-12-11" cast DoubleType, null) + checkEvaluation(Literal(123) cast IntegerType, 123) intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} }