diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 193dc6b6546b5..bded3b664d8c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -62,15 +62,15 @@ trait CheckAnalysis {
             val from = operator.inputSet.map(_.name).mkString(", ")
             a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
 
+          case e: Expression if !e.validInputTypes =>
+            e.failAnalysis(
+              s"cannot resolve '${e.prettyString}' due to data type mismatch: " +
+                e.typeMismatchErrorMessage.get)
+
           case c: Cast if !c.resolved =>
             failAnalysis(
               s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
 
-          case b: BinaryExpression if !b.resolved =>
-            failAnalysis(
-              s"invalid expression ${b.prettyString} " +
-                s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}")
-
           case WindowExpression(UnresolvedWindowFunction(name, _), _) =>
             failAnalysis(
               s"Could not resolve window function '$name'. " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index edcc918bfe921..ba39a0ef2e2be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -407,7 +407,7 @@ trait HiveTypeCoercion {
         Union(newLeft, newRight)
 
       // fix decimal precision for expressions
-      case q => q.transformExpressions {
+      case q => q.transformExpressionsUp {
        // Skip nodes whose children have not been resolved yet
        case e if !e.childrenResolved => e
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index d19928784442e..56246a2bdc6b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -86,12 +86,16 @@ abstract class Expression extends TreeNode[Expression] {
       case (i1, i2) => i1 == i2
     }
   }
+
+  def typeMismatchErrorMessage: Option[String] = None
+
+  def validInputTypes: Boolean = typeMismatchErrorMessage.isEmpty
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
   self: Product =>
 
-  def symbol: String
+  def symbol: String = sys.error(s"BinaryExpressions must either override toString or symbol")
 
   override def foldable: Boolean = left.foldable && right.foldable
 
@@ -106,6 +110,10 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
 
 abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
   self: Product =>
+
+  override def foldable: Boolean = child.foldable
+
+  override def nullable: Boolean = child.nullable
 }
 
 // TODO Semantically we probably not need GroupExpression
@@ -125,7 +133,9 @@ case class GroupExpression(children: Seq[Expression]) extends Expression {
  * so that the proper type conversions can be performed in the analyzer.
  */
 trait ExpectsInputTypes {
+  self: Expression =>
 
   def expectedChildTypes: Seq[DataType]
+  override def validInputTypes: Boolean = children.map(_.dataType) == expectedChildTypes
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index f2299d5db6e9f..e96ad4289d62b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -17,72 +17,87 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
-case class UnaryMinus(child: Expression) extends UnaryExpression {
-
-  override def dataType: DataType = child.dataType
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
-  override def toString: String = s"-$child"
-
-  lazy val numeric = dataType match {
-    case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
-    case other => sys.error(s"Type $other does not support numeric operations")
-  }
+abstract class UnaryArithmetic extends UnaryExpression {
+  self: Product =>
 
   override def eval(input: Row): Any = {
     val evalE = child.eval(input)
     if (evalE == null) {
       null
     } else {
-      numeric.negate(evalE)
+      evalInternal(evalE)
     }
   }
+
+  protected def evalInternal(evalE: Any): Any =
+    sys.error(s"UnaryArithmetics must either override eval or evalInternal")
 }
 
-case class Sqrt(child: Expression) extends UnaryExpression {
+case class UnaryMinus(child: Expression) extends UnaryArithmetic {
+  override def dataType: DataType = child.dataType
+  override def toString: String = s"-$child"
+
+  override def typeMismatchErrorMessage: Option[String] = {
+    TypeUtils.checkForNumericExpr(child.dataType, "todo")
+  }
+
+  private lazy val numeric = TypeUtils.getNumeric(dataType)
+  protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
+}
+
+case class Sqrt(child: Expression) extends UnaryArithmetic {
   override def dataType: DataType = DoubleType
-  override def foldable: Boolean = child.foldable
   override def nullable: Boolean = true
   override def toString: String = s"SQRT($child)"
 
-  lazy val numeric = child.dataType match {
-    case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
-    case other => sys.error(s"Type $other does not support non-negative numeric operations")
+  override def typeMismatchErrorMessage: Option[String] = {
+    TypeUtils.checkForNumericExpr(child.dataType, "todo")
   }
 
-  override def eval(input: Row): Any = {
-    val evalE = child.eval(input)
-    if (evalE == null) {
-      null
-    } else {
-      val value = numeric.toDouble(evalE)
-      if (value < 0) null
-      else math.sqrt(value)
-    }
+  private lazy val numeric = TypeUtils.getNumeric(child.dataType)
+
+  protected override def evalInternal(evalE: Any) = {
+    val value = numeric.toDouble(evalE)
+    if (value < 0) null
+    else math.sqrt(value)
+  }
+}
+
+/**
+ * A function that gets the absolute value of the numeric value.
+ */
+case class Abs(child: Expression) extends UnaryArithmetic {
+  override def dataType: DataType = child.dataType
+  override def toString: String = s"Abs($child)"
+
+  override def typeMismatchErrorMessage: Option[String] = {
+    TypeUtils.checkForNumericExpr(child.dataType, "todo")
   }
+
+  private lazy val numeric = TypeUtils.getNumeric(dataType)
+
+  protected override def evalInternal(evalE: Any) = numeric.abs(evalE)
 }
 
 abstract class BinaryArithmetic extends BinaryExpression {
   self: Product =>
 
-  override lazy val resolved =
-    left.resolved && right.resolved &&
-    left.dataType == right.dataType &&
-    !DecimalType.isFixed(left.dataType)
+  override def dataType: DataType = left.dataType
 
-  override def dataType: DataType = {
-    if (!resolved) {
-      throw new UnresolvedException(this,
-        s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
+  override def typeMismatchErrorMessage: Option[String] = {
+    if (left.dataType != right.dataType) {
+      Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}")
+    } else {
+      errorMessageInternal(left.dataType)
     }
-    left.dataType
   }
 
+  protected def errorMessageInternal(t: DataType): Option[String]
+
   override def eval(input: Row): Any = {
     val evalE1 = left.eval(input)
     if(evalE1 == null) {
@@ -97,88 +112,84 @@ abstract class BinaryArithmetic extends BinaryExpression {
     }
   }
 
-  def evalInternal(evalE1: Any, evalE2: Any): Any =
-    sys.error(s"BinaryExpressions must either override eval or evalInternal")
+  protected def evalInternal(evalE1: Any, evalE2: Any): Any =
+    sys.error(s"BinaryArithmetics must either override eval or evalInternal")
 }
 
 case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "+"
 
-  lazy val numeric = dataType match {
-    case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
-    case other => sys.error(s"Type $other does not support numeric operations")
-  }
-
-  override def eval(input: Row): Any = {
-    val evalE1 = left.eval(input)
-    if(evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
+  protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForNumericExpr(t, "todo").orElse {
+      if (DecimalType.isFixed(t)) {
+        Some("todo")
       } else {
-        numeric.plus(evalE1, evalE2)
+        None
       }
     }
   }
+
+  private lazy val numeric = TypeUtils.getNumeric(dataType)
+
+  protected override def evalInternal(evalE1: Any, evalE2: Any) =
+    numeric.plus(evalE1, evalE2)
 }
 
 case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "-"
 
-  lazy val numeric = dataType match {
-    case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
-    case other => sys.error(s"Type $other does not support numeric operations")
-  }
-
-  override def eval(input: Row): Any = {
-    val evalE1 = left.eval(input)
-    if(evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
+  protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForNumericExpr(t, "todo").orElse {
+      if (DecimalType.isFixed(t)) {
+        Some("todo")
       } else {
-        numeric.minus(evalE1, evalE2)
+        None
       }
     }
   }
+
+  private lazy val numeric = TypeUtils.getNumeric(dataType)
+
+  protected override def evalInternal(evalE1: Any, evalE2: Any) =
+    numeric.minus(evalE1, evalE2)
 }
 
 case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "*"
 
-  lazy val numeric = dataType match {
-    case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
-    case other => sys.error(s"Type $other does not support numeric operations")
-  }
-
-  override def eval(input: Row): Any = {
-    val evalE1 = left.eval(input)
-    if(evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
+  protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForNumericExpr(t, "todo").orElse {
+      if (DecimalType.isFixed(t)) {
+        Some("todo")
       } else {
-        numeric.times(evalE1, evalE2)
+        None
      }
     }
   }
+
+  private lazy val numeric = TypeUtils.getNumeric(dataType)
+
+  protected override def evalInternal(evalE1: Any, evalE2: Any) =
+    numeric.times(evalE1, evalE2)
 }
 
 case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "/"
-
   override def nullable: Boolean = true
 
-  lazy val div: (Any, Any) => Any = dataType match {
+  protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForNumericExpr(t, "todo").orElse {
+      if (DecimalType.isFixed(t)) {
+        Some("todo")
+      } else {
+        None
+      }
+    }
+  }
+
+  private lazy val div: (Any, Any) => Any = dataType match {
     case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
     case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
-    case other => sys.error(s"Type $other does not support numeric operations")
   }
 
   override def eval(input: Row): Any = {
@@ -198,13 +209,21 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
 
 case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "%"
-
   override def nullable: Boolean = true
 
-  lazy val integral = dataType match {
+  protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForNumericExpr(t, "todo").orElse {
+      if (DecimalType.isFixed(t)) {
+        Some("todo")
+      } else {
+        None
+      }
+    }
+  }
+
+  private lazy val integral = dataType match {
     case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
     case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
-    case other => sys.error(s"Type $other does not support numeric operations")
   }
 
   override def eval(input: Row): Any = {
@@ -228,7 +247,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
 case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "&"
 
-  lazy val and: (Any, Any) => Any = dataType match {
+  protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForBitwiseExpr(t, "todo")
+  }
+
+  private lazy val and: (Any, Any) => Any = dataType match {
     case ByteType =>
       ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any]
     case ShortType =>
@@ -237,10 +260,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
       ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
     case LongType =>
       ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
-    case other => sys.error(s"Unsupported bitwise & operation on $other")
   }
 
-  override def evalInternal(evalE1: Any, evalE2: Any): Any = and(evalE1, evalE2)
+  protected override def evalInternal(evalE1: Any, evalE2: Any) =
+    and(evalE1, evalE2)
 }
 
 /**
@@ -249,7 +272,11 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
 case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "|"
 
-  lazy val or: (Any, Any) => Any = dataType match {
+  protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForBitwiseExpr(t, "todo")
+  }
+
+  private lazy val or: (Any, Any) => Any = dataType match {
     case ByteType =>
       ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any]
     case ShortType =>
@@ -258,10 +285,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
       ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
     case LongType =>
       ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
-    case other => sys.error(s"Unsupported bitwise | operation on $other")
   }
 
-  override def evalInternal(evalE1: Any, evalE2: Any): Any = or(evalE1, evalE2)
+  protected override def evalInternal(evalE1: Any, evalE2: Any) =
+    or(evalE1, evalE2)
 }
 
 /**
@@ -270,7 +297,11 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
 case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "^"
 
-  lazy val xor: (Any, Any) => Any = dataType match {
+  protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForBitwiseExpr(t, "todo")
+  }
+
+  private lazy val xor: (Any, Any) => Any = dataType match {
     case ByteType =>
       ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any]
     case ShortType =>
@@ -279,23 +310,24 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
       ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
     case LongType =>
       ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
-    case other => sys.error(s"Unsupported bitwise ^ operation on $other")
   }
 
-  override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2)
+  protected override def evalInternal(evalE1: Any, evalE2: Any): Any =
+    xor(evalE1, evalE2)
 }
 
 /**
  * A function that calculates bitwise not(~) of a number.
  */
-case class BitwiseNot(child: Expression) extends UnaryExpression {
-
+case class BitwiseNot(child: Expression) extends UnaryArithmetic {
   override def dataType: DataType = child.dataType
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
   override def toString: String = s"~$child"
 
-  lazy val not: (Any) => Any = dataType match {
+  override def typeMismatchErrorMessage: Option[String] = {
+    TypeUtils.checkForBitwiseExpr(child.dataType, "todo")
+  }
+
+  private lazy val not: (Any) => Any = dataType match {
     case ByteType =>
       ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any]
     case ShortType =>
@@ -304,42 +336,23 @@ case class BitwiseNot(child: Expression) extends UnaryExpression {
       ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any]
     case LongType =>
       ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
-    case other => sys.error(s"Unsupported bitwise ~ operation on $other")
   }
 
-  override def eval(input: Row): Any = {
-    val evalE = child.eval(input)
-    if (evalE == null) {
-      null
-    } else {
-      not(evalE)
-    }
-  }
+  protected override def evalInternal(evalE: Any) = not(evalE)
 }
 
-case class MaxOf(left: Expression, right: Expression) extends Expression {
-
-  override def foldable: Boolean = left.foldable && right.foldable
-
+case class MaxOf(left: Expression, right: Expression) extends BinaryExpression {
   override def nullable: Boolean = left.nullable && right.nullable
+  override def dataType: DataType = left.dataType
 
-  override def children: Seq[Expression] = left :: right :: Nil
+  private lazy val ordering = TypeUtils.getOrdering(dataType)
 
-  override lazy val resolved =
-    left.resolved && right.resolved &&
-    left.dataType == right.dataType
-
-  override def dataType: DataType = {
-    if (!resolved) {
-      throw new UnresolvedException(this,
-        s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
+  override def typeMismatchErrorMessage: Option[String] = {
+    if (left.dataType != right.dataType) {
+      Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}")
+    } else {
+      TypeUtils.checkForOrderingExpr(dataType, "todo")
    }
-    left.dataType
-  }
-
-  lazy val ordering = left.dataType match {
-    case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
-    case other => sys.error(s"Type $other does not support ordered operations")
   }
 
   override def eval(input: Row): Any = {
@@ -361,29 +374,18 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
   override def toString: String = s"MaxOf($left, $right)"
 }
 
-case class MinOf(left: Expression, right: Expression) extends Expression {
-
-  override def foldable: Boolean = left.foldable && right.foldable
-
+case class MinOf(left: Expression, right: Expression) extends BinaryExpression {
   override def nullable: Boolean = left.nullable && right.nullable
+  override def dataType: DataType = left.dataType
 
-  override def children: Seq[Expression] = left :: right :: Nil
+  private lazy val ordering = TypeUtils.getOrdering(dataType)
 
-  override lazy val resolved =
-    left.resolved && right.resolved &&
-    left.dataType == right.dataType
-
-  override def dataType: DataType = {
-    if (!resolved) {
-      throw new UnresolvedException(this,
-        s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
+  override def typeMismatchErrorMessage: Option[String] = {
+    if (left.dataType != right.dataType) {
+      Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}")
+    } else {
+      TypeUtils.checkForOrderingExpr(dataType, "todo")
    }
-    left.dataType
-  }
-
-  lazy val ordering = left.dataType match {
-    case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
-    case other => sys.error(s"Type $other does not support ordered operations")
   }
 
   override def eval(input: Row): Any = {
@@ -404,28 +406,3 @@ case class MinOf(left: Expression, right: Expression) extends Expression {
 
   override def toString: String = s"MinOf($left, $right)"
 }
-
-/**
- * A function that get the absolute value of the numeric value.
- */
-case class Abs(child: Expression) extends UnaryExpression {
-
-  override def dataType: DataType = child.dataType
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
-  override def toString: String = s"Abs($child)"
-
-  lazy val numeric = dataType match {
-    case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
-    case other => sys.error(s"Type $other does not support numeric operations")
-  }
-
-  override def eval(input: Row): Any = {
-    val evalE = child.eval(input)
-    if (evalE == null) {
-      null
-    } else {
-      numeric.abs(evalE)
-    }
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
index 01f62ba0442e9..2b2a994f843c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
@@ -29,17 +29,10 @@ import org.apache.spark.sql.types._
 
 abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
   extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>
-  override def symbol: String = null
 
   override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
-  override def nullable: Boolean = left.nullable || right.nullable
   override def toString: String = s"$name($left, $right)"
 
-  override lazy val resolved =
-    left.resolved && right.resolved &&
-    left.dataType == right.dataType &&
-    !DecimalType.isFixed(left.dataType)
-
   override def dataType: DataType = DoubleType
 
   override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
index 41b422346a02d..ff235f0ab58eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
@@ -31,7 +30,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
 
   override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
   override def dataType: DataType = DoubleType
-  override def foldable: Boolean = child.foldable
   override def nullable: Boolean = true
   override def toString: String = s"$name($child)"
 
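Illustrative walkthrough (not part of the patch): how the typeMismatchErrorMessage / validInputTypes hook introduced above behaves for a mistyped arithmetic expression. The object name TypeCheckWalkthrough is invented for this sketch; Add, Literal, and the "differing types in BinaryArithmetics" message come from the patched code, so treat the expected output as an assumption to verify.

import org.apache.spark.sql.catalyst.expressions.{Add, Literal}

object TypeCheckWalkthrough {
  def main(args: Array[String]): Unit = {
    // Children resolve to different types: IntegerType vs StringType.
    val expr = Add(Literal(1), Literal("x"))

    // The mismatch is now reported as data instead of an UnresolvedException
    // being thrown when dataType is accessed.
    // Expected: Some(differing types in BinaryArithmetics, IntegerType, StringType)
    println(expr.typeMismatchErrorMessage)

    // validInputTypes is defined as typeMismatchErrorMessage.isEmpty, which is
    // what CheckAnalysis matches on to raise "cannot resolve ... due to data
    // type mismatch" at analysis time.
    println(expr.validInputTypes) // Expected: false
  }
}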
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 4f422d69c4382..03af95c508fe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,10 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType}
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType}
 
 object InterpretedPredicate {
   def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
@@ -70,8 +69,6 @@ trait PredicateHelper {
 
 case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
   override def toString: String = s"NOT $child"
 
   override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
 
@@ -171,6 +168,16 @@ case class Or(left: Expression, right: Expression)
 
 abstract class BinaryComparison extends BinaryExpression with Predicate {
   self: Product =>
+
+  override def typeMismatchErrorMessage: Option[String] = {
+    if (left.dataType != right.dataType) {
+      Some(s"differing types in BinaryComparisons, ${left.dataType}, ${right.dataType}")
+    } else {
+      errorMessageInternal(left.dataType)
+    }
+  }
+
+  protected def errorMessageInternal(t: DataType): Option[String] = None
 }
 
 case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
@@ -210,17 +217,12 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "<"
 
-  lazy val ordering: Ordering[Any] = {
-    if (left.dataType != right.dataType) {
-      throw new TreeNodeException(this,
-        s"Types do not match ${left.dataType} != ${right.dataType}")
-    }
-    left.dataType match {
-      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
-      case other => sys.error(s"Type $other does not support ordered operations")
-    }
+  override protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForOrderingExpr(t, "todo")
   }
 
+  private lazy val ordering = TypeUtils.getOrdering(left.dataType)
+
   override def eval(input: Row): Any = {
     val evalE1 = left.eval(input)
     if (evalE1 == null) {
@@ -239,17 +241,12 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
 case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "<="
 
-  lazy val ordering: Ordering[Any] = {
-    if (left.dataType != right.dataType) {
-      throw new TreeNodeException(this,
-        s"Types do not match ${left.dataType} != ${right.dataType}")
-    }
-    left.dataType match {
-      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
-      case other => sys.error(s"Type $other does not support ordered operations")
-    }
+  override protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForOrderingExpr(t, "todo")
   }
 
+  private lazy val ordering = TypeUtils.getOrdering(left.dataType)
+
   override def eval(input: Row): Any = {
     val evalE1 = left.eval(input)
     if (evalE1 == null) {
@@ -268,17 +265,12 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
 case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = ">"
 
-  lazy val ordering: Ordering[Any] = {
-    if (left.dataType != right.dataType) {
-      throw new TreeNodeException(this,
-        s"Types do not match ${left.dataType} != ${right.dataType}")
-    }
-    left.dataType match {
-      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
-      case other => sys.error(s"Type $other does not support ordered operations")
-    }
+  override protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForOrderingExpr(t, "todo")
   }
 
+  private lazy val ordering = TypeUtils.getOrdering(left.dataType)
+
   override def eval(input: Row): Any = {
     val evalE1 = left.eval(input)
     if(evalE1 == null) {
@@ -297,17 +289,12 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
 case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = ">="
 
-  lazy val ordering: Ordering[Any] = {
-    if (left.dataType != right.dataType) {
-      throw new TreeNodeException(this,
-        s"Types do not match ${left.dataType} != ${right.dataType}")
-    }
-    left.dataType match {
-      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
-      case other => sys.error(s"Type $other does not support ordered operations")
-    }
+  override protected def errorMessageInternal(t: DataType) = {
+    TypeUtils.checkForOrderingExpr(t, "todo")
   }
 
+  private lazy val ordering = TypeUtils.getOrdering(left.dataType)
+
   override def eval(input: Row): Any = {
     val evalE1 = left.eval(input)
     if (evalE1 == null) {
@@ -329,16 +316,16 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
   override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
   override def nullable: Boolean = trueValue.nullable || falseValue.nullable
 
-  override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
-  override def dataType: DataType = {
-    if (!resolved) {
-      throw new UnresolvedException(
-        this,
-        s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}")
+  override def typeMismatchErrorMessage: Option[String] = {
+    if (trueValue.dataType != falseValue.dataType) {
+      Some(s"differing types in If, ${trueValue.dataType}, ${falseValue.dataType}")
+    } else {
+      None
    }
-    trueValue.dataType
   }
 
+  override def dataType: DataType = trueValue.dataType
+
   override def eval(input: Row): Any = {
     if (true == predicate.eval(input)) {
       trueValue.eval(input)
@@ -368,12 +355,7 @@ trait CaseWhenLike extends Expression {
   def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
   def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
 
-  override def dataType: DataType = {
-    if (!resolved) {
-      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
-    }
-    valueTypes.head
-  }
+  override def dataType: DataType = valueTypes.head
 
   override def nullable: Boolean = {
     // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
@@ -395,10 +377,15 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
 
   override def children: Seq[Expression] = branches
 
-  override lazy val resolved: Boolean =
-    childrenResolved &&
-    whenList.forall(_.dataType == BooleanType) &&
-    valueTypesEqual
+  override def typeMismatchErrorMessage: Option[String] = {
+    if (!whenList.forall(_.dataType == BooleanType)) {
+      Some(s"WHEN expressions should all be boolean type")
+    } else if (!valueTypesEqual) {
+      Some("THEN and ELSE expressions should all be same type")
+    } else {
+      None
+    }
+  }
 
   /** Written in imperative fashion for performance considerations. */
   override def eval(input: Row): Any = {
@@ -441,9 +428,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
 
   override def children: Seq[Expression] = key +: branches
 
-  override lazy val resolved: Boolean =
-    childrenResolved && valueTypesEqual &&
-    (key +: whenList).map(_.dataType).distinct.size == 1
+  override def typeMismatchErrorMessage: Option[String] = {
+    if (!valueTypesEqual) {
+      Some("THEN and ELSE expressions should all be same type")
+    } else {
+      None
+    }
+  }
 
   /** Written in imperative fashion for performance considerations. */
   override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
new file mode 100644
index 0000000000000..511d876e87e26
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.types.{AtomicType, IntegralType, NumericType, DataType}
+
+/**
+ * Helper functions to check valid data types
+ */
+object TypeUtils {
+
+  def checkForNumericExpr(t: DataType, errorMsg: => String): Option[String] = {
+    if (t.isInstanceOf[NumericType]) {
+      None
+    } else {
+      Some(errorMsg)
+    }
+  }
+
+  def checkForBitwiseExpr(t: DataType, errorMsg: => String): Option[String] = {
+    if (t.isInstanceOf[IntegralType]) {
+      None
+    } else {
+      Some(errorMsg)
+    }
+  }
+
+  def checkForOrderingExpr(t: DataType, errorMsg: => String): Option[String] = {
+    if (t.isInstanceOf[AtomicType]) {
+      None
+    } else {
+      Some(errorMsg)
+    }
+  }
+
+  def getNumeric(t: DataType): Numeric[Any] =
+    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
+
+  def getOrdering(t: DataType): Ordering[Any] =
+    t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]]
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 1ba3a2686639f..74677ddfcad65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -107,7 +107,7 @@ protected[sql] abstract class AtomicType extends DataType {
 abstract class NumericType extends AtomicType {
   // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
   // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
-  // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
+  // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
   // desugared by the compiler into an argument to the objects constructor. This means there is no
   // longer an no argument constructor and thus the JVM cannot serialize the object anymore.
   private[sql] val numeric: Numeric[InternalType]
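Usage sketch (hypothetical, not part of the patch): with UnaryArithmetic and TypeUtils in place, a new expression only needs to supply typeMismatchErrorMessage and evalInternal to get both the analysis-time type check and the shared null handling. BitCount and its error message below are invented for illustration.

import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryArithmetic}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical expression, shown only to illustrate the new hooks.
case class BitCount(child: Expression) extends UnaryArithmetic {
  override def dataType: DataType = IntegerType
  override def toString: String = s"BitCount($child)"

  // Analysis-time check: a non-integral child makes CheckAnalysis fail with
  // "cannot resolve ... due to data type mismatch: <this message>".
  override def typeMismatchErrorMessage: Option[String] =
    TypeUtils.checkForBitwiseExpr(
      child.dataType, s"BitCount requires an integral type, not ${child.dataType}")

  // Null handling comes from UnaryArithmetic.eval; only the non-null case
  // has to be implemented here.
  protected override def evalInternal(evalE: Any) = evalE match {
    case b: Byte  => java.lang.Integer.bitCount(b & 0xff)
    case s: Short => java.lang.Integer.bitCount(s & 0xffff)
    case i: Int   => java.lang.Integer.bitCount(i)
    case l: Long  => java.lang.Long.bitCount(l)
  }
}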