From 93df3cd03503fca7745141fbd2676b8bf70fe92f Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 5 Jun 2018 11:32:42 -0700 Subject: [PATCH] [SPARK-22384][SQL] Refine partition pruning when attribute is wrapped in Cast ## What changes were proposed in this pull request? Sql below will get all partitions from metastore, which put much burden on metastore; ``` CREATE TABLE `partition_test`(`col` int) PARTITIONED BY (`pt` byte) SELECT * FROM partition_test WHERE CAST(pt AS INT)=1 ``` The reason is that the the analyzed attribute `dt` is wrapped in `Cast` and `HiveShim` fails to generate a proper partition filter. This pr proposes to take `Cast` into consideration when generate partition filter. ## How was this patch tested? Test added. This pr proposes to use analyzed expressions in `HiveClientSuite` Author: jinxing Closes #19602 from jinxing64/SPARK-22384. --- .../spark/sql/hive/client/HiveShim.scala | 23 +++- .../sql/hive/client/HiveClientSuite.scala | 102 ++++++++++++------ 2 files changed, 86 insertions(+), 39 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 948ba542b5733..130e258e78ca2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -657,17 +656,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled + object ExtractAttribute { + def unapply(expr: Expression): Option[Attribute] = { + expr match { + case attr: Attribute => Some(attr) + case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child) + case _ => None + } + } + } + def convert(expr: Expression): Option[String] = expr match { - case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced => + case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced => + case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) => + case op @ SpecialBinaryComparison( + ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) => Some(s"$name ${op.symbol} $value") - case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) => + case op @ SpecialBinaryComparison( + ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) => Some(s"$value ${op.symbol} $name") case And(expr1, expr2) if useAdvanced => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index f991352b207d4..55275f6b37945 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.LongType // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) extends HiveVersionSuite(version) with BeforeAndAfterAll { - import CatalystSqlParser._ private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname @@ -46,8 +46,7 @@ class HiveClientSuite(version: String) val hadoopConf = new Configuration() hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) val client = buildClient(hadoopConf) - client - .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") + client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") val partitions = for { @@ -66,6 +65,15 @@ class HiveClientSuite(version: String) client } + private def attr(name: String): Attribute = { + client.getTable("default", "test").partitionSchema.fields + .find(field => field.name.equals(name)) match { + case Some(field) => AttributeReference(field.name, field.dataType)() + case None => + fail(s"Illegal name of partition attribute: $name") + } + } + override def beforeAll() { super.beforeAll() client = init(true) @@ -74,7 +82,7 @@ class HiveClientSuite(version: String) test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { val client = init(false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq(parseExpression("ds=20170101"))) + Seq(attr("ds") === 20170101)) assert(filteredPartitions.size == testPartitionCount) } @@ -82,7 +90,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds<=>20170101") { // Should return all partitions where <=> is not supported testMetastorePartitionFiltering( - "ds<=>20170101", + attr("ds") <=> 20170101, 20170101 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -90,7 +98,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101") { testMetastorePartitionFiltering( - "ds=20170101", + attr("ds") === 20170101, 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -100,7 +108,7 @@ class HiveClientSuite(version: String) // Should return all partitions where h=0 because getPartitionsByFilter does not support // comparisons to non-literal values testMetastorePartitionFiltering( - "ds=(20170101 + 1) and h=0", + attr("ds") === (Literal(20170101) + 1) && attr("h") === 0, 20170101 to 20170103, 0 to 0, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -108,7 +116,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk='aa'") { testMetastorePartitionFiltering( - "chunk='aa'", + attr("chunk") === "aa", 20170101 to 20170103, 0 to 23, "aa" :: Nil) @@ -116,7 +124,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( - "20170101=ds", + Literal(20170101) === attr("ds"), 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -124,7 +132,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 and h=10") { testMetastorePartitionFiltering( - "ds=20170101 and h=10", + attr("ds") === 20170101 && attr("h") === 10, + 20170101 to 20170101, + 10 to 10, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, 10 to 10, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -132,7 +148,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 or ds=20170102") { testMetastorePartitionFiltering( - "ds=20170101 or ds=20170102", + attr("ds") === 20170101 || attr("ds") === 20170102, 20170101 to 20170102, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -140,7 +156,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -148,7 +172,19 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil, { + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) + }) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)") + { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { @@ -159,7 +195,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil) @@ -167,7 +203,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { @@ -179,26 +215,24 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil) } test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) // Day 2 should include all hours because we can't build a filter for h<(7+1) val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil) } test("getPartitionsByFilter: " + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) - testMetastorePartitionFiltering( - "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", + testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") && + ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)), day1 :: day2 :: Nil) } @@ -207,41 +241,41 @@ class HiveClientSuite(version: String) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String]): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String], transform: Expression => Expression): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, - identity) + transform) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { - testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], transform: Expression => Expression): Unit = { val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), Seq( - transform(parseExpression(filterString)) + transform(filterExpr) )) val expectedPartitionCount = expectedPartitionCubes.map {