diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index da9320ffb61c3..2187f54862ca9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -480,10 +480,120 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
         Some(sources.StringContains(a.name, v.toString))
 
+      case expressions.BinaryComparison(BinaryArithmetic(left, right), Literal(v, t)) =>
+        translateArithmeticOPFilter(predicate)
+      case expressions.BinaryComparison(Literal(v, t), BinaryArithmetic(left, right)) =>
+        translateArithmeticOPFilter(predicate)
+
       case _ => None
     }
   }
 
+  // Translates a comparison between a binary arithmetic expression (+, -, *, /) and a literal
+  // into a data source Filter. When the literal appears on the left-hand side, the comparison
+  // is flipped, e.g. `literal > expr` becomes an ArithmeticOPLessThan filter.
+  private def translateArithmeticOPFilter(predicate: Expression): Option[Filter] = {
+    predicate match {
+      case expressions.EqualTo(Add(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPEqualTo(Add(left, right), convertToScala(v, t)))
+      case expressions.EqualTo(Literal(v, t), Add(left, right)) =>
+        Some(sources.ArithmeticOPEqualTo(Add(left, right), convertToScala(v, t)))
+
+      case expressions.EqualTo(Subtract(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPEqualTo(Subtract(left, right), convertToScala(v, t)))
+      case expressions.EqualTo(Literal(v, t), Subtract(left, right)) =>
+        Some(sources.ArithmeticOPEqualTo(Subtract(left, right), convertToScala(v, t)))
+
+      case expressions.EqualTo(Multiply(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPEqualTo(Multiply(left, right), convertToScala(v, t)))
+      case expressions.EqualTo(Literal(v, t), Multiply(left, right)) =>
+        Some(sources.ArithmeticOPEqualTo(Multiply(left, right), convertToScala(v, t)))
+
+      case expressions.EqualTo(Divide(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPEqualTo(Divide(left, right), convertToScala(v, t)))
+      case expressions.EqualTo(Literal(v, t), Divide(left, right)) =>
+        Some(sources.ArithmeticOPEqualTo(Divide(left, right), convertToScala(v, t)))
+
+      case expressions.GreaterThan(Add(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPGreaterThan(Add(left, right), convertToScala(v, t)))
+      case expressions.GreaterThan(Literal(v, t), Add(left, right)) =>
+        Some(sources.ArithmeticOPLessThan(Add(left, right), convertToScala(v, t)))
+
+      case expressions.GreaterThan(Subtract(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPGreaterThan(Subtract(left, right), convertToScala(v, t)))
+      case expressions.GreaterThan(Literal(v, t), Subtract(left, right)) =>
+        Some(sources.ArithmeticOPLessThan(Subtract(left, right), convertToScala(v, t)))
+
+      case expressions.GreaterThan(Multiply(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPGreaterThan(Multiply(left, right), convertToScala(v, t)))
+      case expressions.GreaterThan(Literal(v, t), Multiply(left, right)) =>
+        Some(sources.ArithmeticOPLessThan(Multiply(left, right), convertToScala(v, t)))
+
+      case expressions.GreaterThan(Divide(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPGreaterThan(Divide(left, right), convertToScala(v, t)))
+      case expressions.GreaterThan(Literal(v, t), Divide(left, right)) =>
+        Some(sources.ArithmeticOPLessThan(Divide(left, right), convertToScala(v, t)))
+
+      case expressions.LessThan(Add(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPLessThan(Add(left, right), convertToScala(v, t)))
+      case expressions.LessThan(Literal(v, t), Add(left, right)) =>
+        Some(sources.ArithmeticOPGreaterThan(Add(left, right), convertToScala(v, t)))
+
+      case expressions.LessThan(Subtract(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPLessThan(Subtract(left, right), convertToScala(v, t)))
+      case expressions.LessThan(Literal(v, t), Subtract(left, right)) =>
+        Some(sources.ArithmeticOPGreaterThan(Subtract(left, right), convertToScala(v, t)))
+
+      case expressions.LessThan(Multiply(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPLessThan(Multiply(left, right), convertToScala(v, t)))
+      case expressions.LessThan(Literal(v, t), Multiply(left, right)) =>
+        Some(sources.ArithmeticOPGreaterThan(Multiply(left, right), convertToScala(v, t)))
+
+      case expressions.LessThan(Divide(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPLessThan(Divide(left, right), convertToScala(v, t)))
+      case expressions.LessThan(Literal(v, t), Divide(left, right)) =>
+        Some(sources.ArithmeticOPGreaterThan(Divide(left, right), convertToScala(v, t)))
+
+      case expressions.GreaterThanOrEqual(Add(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPGreaterThanOrEqual(Add(left, right), convertToScala(v, t)))
+      case expressions.GreaterThanOrEqual(Literal(v, t), Add(left, right)) =>
+        Some(sources.ArithmeticOPLessThanOrEqual(Add(left, right), convertToScala(v, t)))
+
+      case expressions.GreaterThanOrEqual(Subtract(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPGreaterThanOrEqual(Subtract(left, right), convertToScala(v, t)))
+      case expressions.GreaterThanOrEqual(Literal(v, t), Subtract(left, right)) =>
+        Some(sources.ArithmeticOPLessThanOrEqual(Subtract(left, right), convertToScala(v, t)))
+
+      case expressions.GreaterThanOrEqual(Multiply(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPGreaterThanOrEqual(Multiply(left, right), convertToScala(v, t)))
+      case expressions.GreaterThanOrEqual(Literal(v, t), Multiply(left, right)) =>
+        Some(sources.ArithmeticOPLessThanOrEqual(Multiply(left, right), convertToScala(v, t)))
+
+      case expressions.GreaterThanOrEqual(Divide(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPGreaterThanOrEqual(Divide(left, right), convertToScala(v, t)))
+      case expressions.GreaterThanOrEqual(Literal(v, t), Divide(left, right)) =>
+        Some(sources.ArithmeticOPLessThanOrEqual(Divide(left, right), convertToScala(v, t)))
+
+      case expressions.LessThanOrEqual(Add(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPLessThanOrEqual(Add(left, right), convertToScala(v, t)))
+      case expressions.LessThanOrEqual(Literal(v, t), Add(left, right)) =>
+        Some(sources.ArithmeticOPGreaterThanOrEqual(Add(left, right), convertToScala(v, t)))
+
+      case expressions.LessThanOrEqual(Subtract(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPLessThanOrEqual(Subtract(left, right), convertToScala(v, t)))
+      case expressions.LessThanOrEqual(Literal(v, t), Subtract(left, right)) =>
+        Some(sources.ArithmeticOPGreaterThanOrEqual(Subtract(left, right), convertToScala(v, t)))
+
+      case expressions.LessThanOrEqual(Multiply(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPLessThanOrEqual(Multiply(left, right), convertToScala(v, t)))
+      case expressions.LessThanOrEqual(Literal(v, t), Multiply(left, right)) =>
+        Some(sources.ArithmeticOPGreaterThanOrEqual(Multiply(left, right), convertToScala(v, t)))
+
+      case expressions.LessThanOrEqual(Divide(left, right), Literal(v, t)) =>
+        Some(sources.ArithmeticOPLessThanOrEqual(Divide(left, right), convertToScala(v, t)))
+      case expressions.LessThanOrEqual(Literal(v, t), Divide(left, right)) =>
+        Some(sources.ArithmeticOPGreaterThanOrEqual(Divide(left, right), convertToScala(v, t)))
+
+      // Any other comparison (e.g. EqualNullSafe) or arithmetic operator is not supported yet.
+      case _ => None
+    }
+  }
+
   /**
    * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
    * and can be handled by `relation`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index d867e144e517f..90c3383cda613 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -26,8 +26,8 @@ import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SpecificMutableRow}
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
@@ -223,11 +223,57 @@ private[sql] object JDBCRDD extends Logging {
         } else {
           null
         }
+      // If the arithmetic expression cannot be compiled into SQL, fall back to null so the
+      // filter is simply not pushed down instead of throwing.
+      case ArithmeticOPEqualTo(operation, value) =>
+        getArithmeticString(operation).map(p => s"$p = ${compileValue(value)}").orNull
+      case ArithmeticOPGreaterThan(operation, value) =>
+        getArithmeticString(operation).map(p => s"$p > ${compileValue(value)}").orNull
+      case ArithmeticOPGreaterThanOrEqual(operation, value) =>
+        getArithmeticString(operation).map(p => s"$p >= ${compileValue(value)}").orNull
+      case ArithmeticOPLessThan(operation, value) =>
+        getArithmeticString(operation).map(p => s"$p < ${compileValue(value)}").orNull
+      case ArithmeticOPLessThanOrEqual(operation, value) =>
+        getArithmeticString(operation).map(p => s"$p <= ${compileValue(value)}").orNull
       case _ => null
     })
   }
-
+
+  // Recursively compiles a binary arithmetic expression over column attributes into its SQL
+  // string representation, e.g. (A + C) * D becomes "((A) + (C)) * (D)". Returns None if the
+  // expression contains anything other than +, -, *, / and column references.
+  private def getArithmeticString(predicate: Expression): Option[String] = {
+    predicate match {
+      case expressions.Add(left, right) =>
+        val add = Seq(left, right).flatMap(getArithmeticString)
+        if (add.size == 2) {
+          Some(add.map(p => s"($p)").mkString(" + "))
+        } else {
+          None
+        }
+      case expressions.Subtract(left, right) =>
+        val subtract = Seq(left, right).flatMap(getArithmeticString)
+        if (subtract.size == 2) {
+          Some(subtract.map(p => s"($p)").mkString(" - "))
+        } else {
+          None
+        }
+      case expressions.Multiply(left, right) =>
+        val multiply = Seq(left, right).flatMap(getArithmeticString)
+        if (multiply.size == 2) {
+          Some(multiply.map(p => s"($p)").mkString(" * "))
+        } else {
+          None
+        }
+      case expressions.Divide(left, right) =>
+        val divide = Seq(left, right).flatMap(getArithmeticString)
+        if (divide.size == 2) {
+          Some(divide.map(p => s"($p)").mkString(" / "))
+        } else {
+          None
+        }
+      case a: Attribute => Some(a.name)
+      // Anything else (e.g. a literal operand) is not supported yet.
+      case _ => None
+    }
+  }
+
   /**
    * Build and return JDBCRDD from the given information.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 3780cbbcc9631..d818e6a167ac5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.sources
 
+import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines all the filters that we can push down to the data sources.
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -142,3 +144,43 @@ case class StringEndsWith(attribute: String, value: String) extends Filter
  * @since 1.3.1
  */
 case class StringContains(attribute: String, value: String) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the arithmetic operation evaluates to a value
+ * equal to `value`.
+ *
+ * @since 2.0.0
+ */
+case class ArithmeticOPEqualTo(operation: BinaryArithmetic, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the arithmetic operation evaluates to a value
+ * greater than `value`.
+ *
+ * @since 2.0.0
+ */
+case class ArithmeticOPGreaterThan(operation: BinaryArithmetic, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the arithmetic operation evaluates to a value
+ * greater than or equal to `value`.
+ *
+ * @since 2.0.0
+ */
+case class ArithmeticOPGreaterThanOrEqual(operation: BinaryArithmetic, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the arithmetic operation evaluates to a value
+ * less than `value`.
+ *
+ * @since 2.0.0
+ */
+case class ArithmeticOPLessThan(operation: BinaryArithmetic, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the arithmetic operation evaluates to a value
+ * less than or equal to `value`.
+ *
+ * @since 2.0.0
+ */
+case class ArithmeticOPLessThanOrEqual(operation: BinaryArithmetic, value: Any) extends Filter
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 1fa22e2933318..c2c3380be6ad0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -203,6 +203,7 @@ class JDBCSuite extends SparkFunSuite
     assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%re%'")).collect().size == 1)
     assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1)
     assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0)
+    assert(stripSparkFilter(sql("SELECT * FROM inttypes WHERE (A+C)*D-A = 15")).collect().size == 1)
 
     // This is a test to reflect discussion in SPARK-12218.
     // The older versions of spark have this kind of bugs in parquet data source.
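
Rough usage sketch (not part of the patch): with these changes, a comparison between an arithmetic expression over columns and a literal is translated into one of the new ArithmeticOP* filters and compiled by JDBCRDD into the WHERE clause sent to the database. For the predicate used in the new test, the pushed-down condition would be roughly `(((A) + (C)) * (D)) - (A) = 15` instead of being evaluated by Spark after a full scan. The connection URL, table name, and column names below are hypothetical.

// Hypothetical example: assumes a JDBC table "inttypes" with integer columns A, C and D.
val df = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:h2:mem:testdb0")   // assumed in-memory H2 URL for illustration
  .option("dbtable", "inttypes")
  .load()
// The predicate uses only +, -, * over column attributes compared with a literal,
// so it matches the new translateArithmeticOPFilter cases and can be pushed down.
df.filter("(A + C) * D - A = 15").collect()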