diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 59cb26d5e6c36..efb2eba655e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -301,6 +302,8 @@ package object dsl { def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan) + def filter[T : Encoder](func: FilterFunction[T]): LogicalPlan = TypedFilter(func, logicalPlan) + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 56f096f3ecf8c..5fc99a3a57c0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -38,18 +39,19 @@ class TypedFilterOptimizationSuite extends PlanTest { implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + val testRelation = LocalRelation('_1.int, '_2.int) + test("filter after serialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -58,10 +60,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter after serialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze @@ -70,17 +71,16 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -89,10 +89,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze @@ -101,21 +100,89 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("back to back filter with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: (Int, Int)) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 1) } test("back to back filter with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: OtherTuple) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("back to back FilterFunction with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("back to back FilterFunction with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("FilterFunction and filter with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("FilterFunction and filter with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: OtherTuple) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("filter and FilterFunction with the same object type") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("filter and FilterFunction with different object types") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 2) }