Adds stricter rules for Parquet filters with null
liancheng committed Nov 19, 2014
1 parent 397d3aa commit de7de28
Showing 3 changed files with 173 additions and 22 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

@@ -41,6 +41,12 @@ object Literal {
   }
 }
 
+object NonNullLiteral {
+  def unapply(literal: Literal): Option[(Any, DataType)] = {
+    Option(literal.value).map(_ => (literal.value, literal.dataType))
+  }
+}
+
 /**
  * Extractor for retrieving Int literals.
  */
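The extractor complements `Literal`: `Option(literal.value)` collapses to `None` when the literal's value is null, so a match on `NonNullLiteral(value, dataType)` binds `value` only when it is safe to use. The old `dataType != NullType` guard in ParquetFilters (removed below) caught `Literal(null, NullType)` but not a typed null such as `Literal(null, IntegerType)`; the extractor closes that hole. A minimal sketch of the behavior, assuming Catalyst's `Literal(value, dataType)` case class and type imports are in scope (hypothetical REPL session, not part of the commit):

  Literal(1, IntegerType) match {
    case NonNullLiteral(v, t) => s"bound $v: $t"  // matches: value is non-null
  }

  Literal(null, IntegerType) match {
    case NonNullLiteral(v, t) => "unreachable"    // typed null: extractor yields None
    case _                    => "skipped"        // falls through as intended
  }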
sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala

@@ -51,11 +51,34 @@ private[sql] object ParquetFilters {
     case DoubleType =>
       (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
-      (n: String, v: Any) =>
-        FilterApi.eq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+      (n: String, v: Any) => FilterApi.eq(
+        binaryColumn(n),
+        Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
     case BinaryType =>
-      (n: String, v: Any) =>
-        FilterApi.eq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
+      (n: String, v: Any) => FilterApi.eq(
+        binaryColumn(n),
+        Option(v).map(b => Binary.fromByteArray(b.asInstanceOf[Array[Byte]])).orNull)
   }

+  val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
+    case BooleanType =>
+      (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
+    case IntegerType =>
+      (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer])
+    case LongType =>
+      (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long])
+    case FloatType =>
+      (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+    case DoubleType =>
+      (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
+    case StringType =>
+      (n: String, v: Any) => FilterApi.notEq(
+        binaryColumn(n),
+        Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+    case BinaryType =>
+      (n: String, v: Any) => FilterApi.notEq(
+        binaryColumn(n),
+        Option(v).map(b => Binary.fromByteArray(b.asInstanceOf[Array[Byte]])).orNull)
+  }

   val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {

@@ -126,30 +149,45 @@ private[sql] object ParquetFilters {
       FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
   }
 
+    // NOTE:
+    //
+    // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
+    // which can be cast to `false` implicitly. Please refer to the `eval` method of these
+    // operators and the `SimplifyFilters` rule for details.
     predicate match {
-      case EqualTo(NamedExpression(name, _), Literal(value, dataType)) if dataType != NullType =>
+      case IsNull(NamedExpression(name, dataType)) =>
+        makeEq.lift(dataType).map(_(name, null))
+      case IsNotNull(NamedExpression(name, dataType)) =>
+        makeNotEq.lift(dataType).map(_(name, null))
+
+      case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
        makeEq.lift(dataType).map(_(name, value))
-      case EqualTo(Literal(value, dataType), NamedExpression(name, _)) if dataType != NullType =>
+      case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
        makeEq.lift(dataType).map(_(name, value))
 
-      case LessThan(NamedExpression(name, _), Literal(value, dataType)) =>
+      case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) =>
+        makeNotEq.lift(dataType).map(_(name, value))
+      case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) =>
+        makeNotEq.lift(dataType).map(_(name, value))
+
+      case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
        makeLt.lift(dataType).map(_(name, value))
-      case LessThan(Literal(value, dataType), NamedExpression(name, _)) =>
+      case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
        makeGt.lift(dataType).map(_(name, value))
 
-      case LessThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
+      case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
        makeLtEq.lift(dataType).map(_(name, value))
-      case LessThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
+      case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
        makeGtEq.lift(dataType).map(_(name, value))
 
-      case GreaterThan(NamedExpression(name, _), Literal(value, dataType)) =>
+      case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
        makeGt.lift(dataType).map(_(name, value))
-      case GreaterThan(Literal(value, dataType), NamedExpression(name, _)) =>
+      case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
        makeLt.lift(dataType).map(_(name, value))
 
-      case GreaterThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
+      case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
        makeGtEq.lift(dataType).map(_(name, value))
-      case GreaterThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
+      case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
        makeLtEq.lift(dataType).map(_(name, value))
 
      case And(lhs, rhs) =>
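The dispatch above leans on two properties of Parquet's filter2 API that are easy to miss. First, a null comparison value is meaningful to Parquet: `eq(column, null)` keeps exactly the rows where the column is null, and `notEq(column, null)` the rows where it is not, which is why `IsNull`/`IsNotNull` can reuse `makeEq`/`makeNotEq` with a `null` value. Second, string and binary values must be converted to `Binary` only when non-null, hence `Option(v).map(...).orNull`; calling `Binary.fromString(null)` directly would throw before Parquet ever saw the null. A short sketch under those assumptions, using the pre-Apache `parquet` package names this code imports (`toBinary` is a hypothetical helper, not in the commit):

  import parquet.filter2.predicate.FilterApi
  import parquet.filter2.predicate.FilterApi.{binaryColumn, intColumn}
  import parquet.io.api.Binary

  // IS NULL / IS NOT NULL ride on eq/notEq with a null value, which is
  // exactly what the IsNull and IsNotNull cases above generate.
  val isNull    = FilterApi.eq(intColumn("a"), null: Integer)
  val isNotNull = FilterApi.notEq(intColumn("a"), null: Integer)

  // Null-safe String-to-Binary conversion mirroring the makeEq/makeNotEq
  // lambdas: a null value stays null instead of hitting Binary.fromString.
  def toBinary(v: Any): Binary =
    Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull

  FilterApi.eq(binaryColumn("s"), toBinary(null))  // IS NULL on a string column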
sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql.parquet
 
-import _root_.parquet.filter2.predicate.{FilterPredicate, Operators}
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.mapreduce.Job
 import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
+import parquet.filter2.predicate.{FilterPredicate, Operators}
 import parquet.hadoop.ParquetFileWriter
 import parquet.hadoop.util.ContextUtil
+import parquet.io.api.Binary
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions._
@@ -85,6 +86,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
   TestData // Load test data tables.
 
   var testRDD: SchemaRDD = null
+  var originalParquetFilterPushdownEnabled = TestSQLContext.parquetFilterPushDown
 
   override def beforeAll() {
     ParquetTestData.writeFile()
@@ -109,13 +111,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
     Utils.deleteRecursively(ParquetTestData.testNestedDir3)
     Utils.deleteRecursively(ParquetTestData.testNestedDir4)
     // here we should also unregister the table??
+
+    setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, originalParquetFilterPushdownEnabled.toString)
   }
 
   test("Read/Write All Types") {
     val tempDir = getTempFilePath("parquetTest").getCanonicalPath
     val range = (0 to 255)
-    val data = sparkContext.parallelize(range)
-      .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
+    val data = sparkContext.parallelize(range).map { x =>
+      parquet.AllDataTypes(
+        s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)
+    }
 
     data.saveAsParquetFile(tempDir)

@@ -260,14 +266,15 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
   test("Read/Write All Types with non-primitive type") {
     val tempDir = getTempFilePath("parquetTest").getCanonicalPath
     val range = (0 to 255)
-    val data = sparkContext.parallelize(range)
-      .map(x => AllDataTypesWithNonPrimitiveType(
+    val data = sparkContext.parallelize(range).map { x =>
+      parquet.AllDataTypesWithNonPrimitiveType(
         s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
         (0 until x),
         (0 until x).map(Option(_).filter(_ % 3 == 0)),
         (0 until x).map(i => i -> i.toLong).toMap,
         (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None),
-        Data((0 until x), Nested(x, s"$x"))))
+        parquet.Data((0 until x), parquet.Nested(x, s"$x")))
+    }
     data.saveAsParquetFile(tempDir)
 
     checkAnswer(
@@ -420,7 +427,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
   }
 
   test("save and load case class RDD with nulls as parquet") {
-    val data = NullReflectData(null, null, null, null, null)
+    val data = parquet.NullReflectData(null, null, null, null, null)
     val rdd = sparkContext.parallelize(data :: Nil)
 
     val file = getTempFilePath("parquet")
@@ -435,7 +442,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
   }
 
   test("save and load case class RDD with Nones as parquet") {
-    val data = OptionalReflectData(None, None, None, None, None)
+    val data = parquet.OptionalReflectData(None, None, None, None, None)
     val rdd = sparkContext.parallelize(data :: Nil)
 
     val file = getTempFilePath("parquet")
@@ -938,4 +945,104 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
       checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
     }
   }
+
+  def checkFilter(predicate: Predicate, filterClass: Class[_ <: FilterPredicate]): Unit = {
+    val filter = ParquetFilters.createFilter(predicate)
+    assert(filter.isDefined)
+    assert(filter.get.getClass == filterClass)
+  }
+
+  test("Pushdown IsNull predicate") {
+    checkFilter('a.int.isNull, classOf[Operators.Eq[Integer]])
+    checkFilter('a.long.isNull, classOf[Operators.Eq[java.lang.Long]])
+    checkFilter('a.float.isNull, classOf[Operators.Eq[java.lang.Float]])
+    checkFilter('a.double.isNull, classOf[Operators.Eq[java.lang.Double]])
+    checkFilter('a.string.isNull, classOf[Operators.Eq[Binary]])
+    checkFilter('a.binary.isNull, classOf[Operators.Eq[Binary]])
+  }
+
+  test("Pushdown IsNotNull predicate") {
+    checkFilter('a.int.isNotNull, classOf[Operators.NotEq[Integer]])
+    checkFilter('a.long.isNotNull, classOf[Operators.NotEq[java.lang.Long]])
+    checkFilter('a.float.isNotNull, classOf[Operators.NotEq[java.lang.Float]])
+    checkFilter('a.double.isNotNull, classOf[Operators.NotEq[java.lang.Double]])
+    checkFilter('a.string.isNotNull, classOf[Operators.NotEq[Binary]])
+    checkFilter('a.binary.isNotNull, classOf[Operators.NotEq[Binary]])
+  }
+
+  test("Pushdown EqualTo predicate") {
+    checkFilter('a.int === 0, classOf[Operators.Eq[Integer]])
+    checkFilter('a.long === 0.toLong, classOf[Operators.Eq[java.lang.Long]])
+    checkFilter('a.float === 0.toFloat, classOf[Operators.Eq[java.lang.Float]])
+    checkFilter('a.double === 0.toDouble, classOf[Operators.Eq[java.lang.Double]])
+    checkFilter('a.string === "foo", classOf[Operators.Eq[Binary]])
+    checkFilter('a.binary === "foo".getBytes, classOf[Operators.Eq[Binary]])
+  }
+
+  test("Pushdown Not(EqualTo) predicate") {
+    checkFilter(!('a.int === 0), classOf[Operators.NotEq[Integer]])
+    checkFilter(!('a.long === 0.toLong), classOf[Operators.NotEq[java.lang.Long]])
+    checkFilter(!('a.float === 0.toFloat), classOf[Operators.NotEq[java.lang.Float]])
+    checkFilter(!('a.double === 0.toDouble), classOf[Operators.NotEq[java.lang.Double]])
+    checkFilter(!('a.string === "foo"), classOf[Operators.NotEq[Binary]])
+    checkFilter(!('a.binary === "foo".getBytes), classOf[Operators.NotEq[Binary]])
+  }
+
+  test("Pushdown LessThan predicate") {
+    checkFilter('a.int < 0, classOf[Operators.Lt[Integer]])
+    checkFilter('a.long < 0.toLong, classOf[Operators.Lt[java.lang.Long]])
+    checkFilter('a.float < 0.toFloat, classOf[Operators.Lt[java.lang.Float]])
+    checkFilter('a.double < 0.toDouble, classOf[Operators.Lt[java.lang.Double]])
+    checkFilter('a.string < "foo", classOf[Operators.Lt[Binary]])
+    checkFilter('a.binary < "foo".getBytes, classOf[Operators.Lt[Binary]])
+  }
+
+  test("Pushdown LessThanOrEqual predicate") {
+    checkFilter('a.int <= 0, classOf[Operators.LtEq[Integer]])
+    checkFilter('a.long <= 0.toLong, classOf[Operators.LtEq[java.lang.Long]])
+    checkFilter('a.float <= 0.toFloat, classOf[Operators.LtEq[java.lang.Float]])
+    checkFilter('a.double <= 0.toDouble, classOf[Operators.LtEq[java.lang.Double]])
+    checkFilter('a.string <= "foo", classOf[Operators.LtEq[Binary]])
+    checkFilter('a.binary <= "foo".getBytes, classOf[Operators.LtEq[Binary]])
+  }
+
+  test("Pushdown GreaterThan predicate") {
+    checkFilter('a.int > 0, classOf[Operators.Gt[Integer]])
+    checkFilter('a.long > 0.toLong, classOf[Operators.Gt[java.lang.Long]])
+    checkFilter('a.float > 0.toFloat, classOf[Operators.Gt[java.lang.Float]])
+    checkFilter('a.double > 0.toDouble, classOf[Operators.Gt[java.lang.Double]])
+    checkFilter('a.string > "foo", classOf[Operators.Gt[Binary]])
+    checkFilter('a.binary > "foo".getBytes, classOf[Operators.Gt[Binary]])
+  }
+
+  test("Pushdown GreaterThanOrEqual predicate") {
+    checkFilter('a.int >= 0, classOf[Operators.GtEq[Integer]])
+    checkFilter('a.long >= 0.toLong, classOf[Operators.GtEq[java.lang.Long]])
+    checkFilter('a.float >= 0.toFloat, classOf[Operators.GtEq[java.lang.Float]])
+    checkFilter('a.double >= 0.toDouble, classOf[Operators.GtEq[java.lang.Double]])
+    checkFilter('a.string >= "foo", classOf[Operators.GtEq[Binary]])
+    checkFilter('a.binary >= "foo".getBytes, classOf[Operators.GtEq[Binary]])
+  }
+
+  test("Comparison with null should not be pushed down") {
+    val predicates = Seq(
+      'a.int === null,
+      !('a.int === null),
+
+      Literal(null) === 'a.int,
+      !(Literal(null) === 'a.int),
+
+      'a.int < null,
+      'a.int <= null,
+      'a.int > null,
+      'a.int >= null,
+
+      Literal(null) < 'a.int,
+      Literal(null) <= 'a.int,
+      Literal(null) > 'a.int,
+      Literal(null) >= 'a.int
+    )
+
+    predicates.foreach(p => assert(ParquetFilters.createFilter(p).isEmpty))
+  }
 }
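The new tests pin down both sides of the contract: `checkFilter` asserts that a supported predicate survives pushdown and compiles to the expected Parquet operator class, while the last test asserts that `createFilter` declines any comparison against a null literal, leaving those predicates to Catalyst, which (per the NOTE in ParquetFilters.scala above) evaluates them to NULL and simplifies them to `false`. A hypothetical REPL-style sketch of that contract, assuming the Catalyst expression DSL is in scope (result comments are approximate, not actual output):

  ParquetFilters.createFilter('a.int.isNull)    // Some(Operators.Eq)  -- pushed down as eq(a, null)
  ParquetFilters.createFilter('a.int === 0)     // Some(Operators.Eq)  -- pushed down as eq(a, 0)
  ParquetFilters.createFilter('a.int === null)  // None                -- left for Catalyst to simplify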
