[SPARK-22384][SQL] Refine partition pruning when attribute is wrapped in Cast

## What changes were proposed in this pull request?

The SQL below fetches all partitions from the metastore, which puts a heavy burden on the metastore:
```
CREATE TABLE `partition_test`(`col` int) PARTITIONED BY (`pt` byte)
SELECT * FROM partition_test WHERE CAST(pt AS INT)=1
```
The reason is that the analyzed attribute `pt` is wrapped in `Cast`, so `HiveShim` fails to generate a proper partition filter.
This PR takes `Cast` into consideration when generating the partition filter.
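
To illustrate the idea, here is a minimal self-contained sketch with a simplified AST (not Spark's actual classes; the `safe` flag stands in for the `!Cast.mayTruncate(child.dataType, dt)` guard in the diff below):
```
sealed trait Expr
case class Attr(name: String) extends Expr
// `safe` models Spark's !Cast.mayTruncate(child.dataType, dt) check.
case class CastExpr(child: Expr, safe: Boolean) extends Expr

object ExtractAttribute {
  // Recursively unwrap casts that cannot truncate, so a predicate like
  // CAST(pt AS INT) = 1 still exposes `pt` to the partition filter generator.
  def unapply(expr: Expr): Option[Attr] = expr match {
    case attr: Attr => Some(attr)
    case CastExpr(child, true) => unapply(child)
    case _ => None
  }
}

object Demo extends App {
  assert(ExtractAttribute.unapply(CastExpr(Attr("pt"), safe = true)).contains(Attr("pt")))
  // A cast that may truncate is not unwrapped, so pruning falls back safely.
  assert(ExtractAttribute.unapply(CastExpr(Attr("pt"), safe = false)).isEmpty)
}
```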

## How was this patch tested?
Tests added.
This PR also changes `HiveClientSuite` to build its filters from analyzed expressions rather than parsed strings.
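
For context, a hedged sketch of the old versus new way a test filter is built (assuming the Catalyst imports used in the diff below; `ds` here is hand-built for illustration, whereas the suite derives it from the table's partition schema):
```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.{IntegerType, LongType}

object FilterStyles {
  // Old style: parsing a string yields an *unresolved* attribute `ds`.
  val parsed: Expression = CatalystSqlParser.parseExpression("ds=20170101")

  // New style: a resolved attribute (as the suite's attr() helper returns)
  // can be wrapped in casts, e.g. cast(ds as long) = 20170101L.
  val ds = AttributeReference("ds", IntegerType)()
  val byDs: Expression = ds === 20170101
  val byCastDs: Expression = ds.cast(LongType) === 20170101L
}
```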

Author: jinxing <[email protected]>

Closes #19602 from jinxing64/SPARK-22384.
jinxing authored and cloud-fan committed Jun 5, 2018
1 parent 2c2a86b commit 93df3cd
Showing 2 changed files with 86 additions and 39 deletions.
HiveShim.scala
@@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
-import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
@@ -657,17 +656,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
 
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
+    object ExtractAttribute {
+      def unapply(expr: Expression): Option[Attribute] = {
+        expr match {
+          case attr: Attribute => Some(attr)
+          case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child)
+          case _ => None
+        }
+      }
+    }
+
     def convert(expr: Expression): Option[String] = expr match {
-      case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced =>
+      case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values))
+          if useAdvanced =>
         Some(convertInToOr(name, values))
 
-      case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced =>
+      case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values))
+          if useAdvanced =>
         Some(convertInToOr(name, values))
 
-      case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) =>
+      case op @ SpecialBinaryComparison(
+          ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) =>
         Some(s"$name ${op.symbol} $value")
 
-      case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) =>
+      case op @ SpecialBinaryComparison(
+          ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) =>
         Some(s"$value ${op.symbol} $name")
 
       case And(expr1, expr2) if useAdvanced =>
HiveClientSuite.scala
@@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet}
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.LongType
 
 // TODO: Refactor this to `HivePartitionFilteringSuite`
 class HiveClientSuite(version: String)
     extends HiveVersionSuite(version) with BeforeAndAfterAll {
-  import CatalystSqlParser._
 
   private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname
 
@@ -46,8 +46,7 @@ class HiveClientSuite(version: String)
     val hadoopConf = new Configuration()
     hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
     val client = buildClient(hadoopConf)
-    client
-      .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
+    client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
 
     val partitions =
       for {
@@ -66,6 +65,15 @@ class HiveClientSuite(version: String)
     client
   }
 
+  private def attr(name: String): Attribute = {
+    client.getTable("default", "test").partitionSchema.fields
+        .find(field => field.name.equals(name)) match {
+      case Some(field) => AttributeReference(field.name, field.dataType)()
+      case None =>
+        fail(s"Illegal name of partition attribute: $name")
+    }
+  }
+
   override def beforeAll() {
     super.beforeAll()
     client = init(true)
@@ -74,23 +82,23 @@ class HiveClientSuite(version: String)
   test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
     val client = init(false)
     val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
-      Seq(parseExpression("ds=20170101")))
+      Seq(attr("ds") === 20170101))
 
     assert(filteredPartitions.size == testPartitionCount)
   }
 
   test("getPartitionsByFilter: ds<=>20170101") {
     // Should return all partitions where <=> is not supported
     testMetastorePartitionFiltering(
-      "ds<=>20170101",
+      attr("ds") <=> 20170101,
       20170101 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
   test("getPartitionsByFilter: ds=20170101") {
     testMetastorePartitionFiltering(
-      "ds=20170101",
+      attr("ds") === 20170101,
       20170101 to 20170101,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -100,55 +108,83 @@ class HiveClientSuite(version: String)
     // Should return all partitions where h=0 because getPartitionsByFilter does not support
     // comparisons to non-literal values
     testMetastorePartitionFiltering(
-      "ds=(20170101 + 1) and h=0",
+      attr("ds") === (Literal(20170101) + 1) && attr("h") === 0,
       20170101 to 20170103,
       0 to 0,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
   test("getPartitionsByFilter: chunk='aa'") {
     testMetastorePartitionFiltering(
-      "chunk='aa'",
+      attr("chunk") === "aa",
       20170101 to 20170103,
       0 to 23,
       "aa" :: Nil)
   }
 
   test("getPartitionsByFilter: 20170101=ds") {
     testMetastorePartitionFiltering(
-      "20170101=ds",
+      Literal(20170101) === attr("ds"),
       20170101 to 20170101,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
   test("getPartitionsByFilter: ds=20170101 and h=10") {
     testMetastorePartitionFiltering(
-      "ds=20170101 and h=10",
+      attr("ds") === 20170101 && attr("h") === 10,
       20170101 to 20170101,
       10 to 10,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
+  test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType) === 20170101L && attr("h") === 10,
+      20170101 to 20170101,
+      10 to 10,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
   test("getPartitionsByFilter: ds=20170101 or ds=20170102") {
     testMetastorePartitionFiltering(
-      "ds=20170101 or ds=20170102",
+      attr("ds") === 20170101 || attr("ds") === 20170102,
       20170101 to 20170102,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") {
     testMetastorePartitionFiltering(
-      "ds in (20170102, 20170103)",
+      attr("ds").in(20170102, 20170103),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
   }
 
+  test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType).in(20170102L, 20170103L),
+      20170102 to 20170103,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") {
     testMetastorePartitionFiltering(
-      "ds in (20170102, 20170103)",
+      attr("ds").in(20170102, 20170103),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
         case expr @ In(v, list) if expr.inSetConvertible =>
           InSet(v, list.map(_.eval(EmptyRow)).toSet)
       })
   }
 
+  test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)")
+  {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType).in(20170102L, 20170103L),
+      20170102 to 20170103,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
@@ -159,15 +195,15 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") {
     testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba')",
+      attr("chunk").in("ab", "ba"),
       20170101 to 20170103,
       0 to 23,
       "ab" :: "ba" :: Nil)
   }
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") {
     testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba')",
+      attr("chunk").in("ab", "ba"),
       20170101 to 20170103,
       0 to 23,
       "ab" :: "ba" :: Nil, {
@@ -179,26 +215,24 @@ class HiveClientSuite(version: String)
   test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
     val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb"))
-    testMetastorePartitionFiltering(
-      "(ds=20170101 and h>=8) or (ds=20170102 and h<8)",
-      day1 :: day2 :: Nil)
+    testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) ||
+      (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil)
   }
 
   test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
     // Day 2 should include all hours because we can't build a filter for h<(7+1)
     val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb"))
-    testMetastorePartitionFiltering(
-      "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))",
-      day1 :: day2 :: Nil)
+    testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) ||
+      (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil)
   }
 
   test("getPartitionsByFilter: " +
       "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"))
     val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"))
-    testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))",
+    testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") &&
+      ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)),
       day1 :: day2 :: Nil)
   }
 
@@ -207,41 +241,41 @@ class HiveClientSuite(version: String)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedDs: Seq[Int],
       expectedH: Seq[Int],
       expectedChunks: Seq[String]): Unit = {
     testMetastorePartitionFiltering(
-      filterString,
+      filterExpr,
       (expectedDs, expectedH, expectedChunks) :: Nil,
       identity)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedDs: Seq[Int],
       expectedH: Seq[Int],
       expectedChunks: Seq[String],
       transform: Expression => Expression): Unit = {
     testMetastorePartitionFiltering(
-      filterString,
+      filterExpr,
       (expectedDs, expectedH, expectedChunks) :: Nil,
-      identity)
+      transform)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = {
-    testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity)
+    testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])],
       transform: Expression => Expression): Unit = {
     val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
       Seq(
-        transform(parseExpression(filterString))
+        transform(filterExpr)
       ))
 
     val expectedPartitionCount = expectedPartitionCubes.map {
