Skip to content

Commit

Permalink
[SPARK-10195] [SQL] Data sources Filter should not expose internal types
Browse files Browse the repository at this point in the history
Spark SQL's data sources API exposes Catalyst's internal types through its Filter interfaces. This is a problem because types like UTF8String are not stable developer APIs and should not be exposed to third-parties.

This issue caused incompatibilities when upgrading our `spark-redshift` library to work against Spark 1.5.0.  To avoid these issues in the future we should only expose public types through these Filter objects. This patch accomplishes this by using CatalystTypeConverters to add the appropriate conversions.

Author: Josh Rosen <[email protected]>

Closes #8403 from JoshRosen/datasources-internal-vs-external-types.
  • Loading branch information
JoshRosen authored and rxin committed Aug 25, 2015
1 parent 2f493f7 commit 7bc9a8c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
import org.apache.spark.sql.catalyst.{InternalRow, expressions}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
Expand Down Expand Up @@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
*/
protected[sql] def selectFilters(filters: Seq[Expression]) = {
def translate(predicate: Expression): Option[Filter] = predicate match {
case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
Some(sources.EqualTo(a.name, v))
case expressions.EqualTo(Literal(v, _), a: Attribute) =>
Some(sources.EqualTo(a.name, v))

case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
Some(sources.EqualNullSafe(a.name, v))
case expressions.EqualNullSafe(Literal(v, _), a: Attribute) =>
Some(sources.EqualNullSafe(a.name, v))

case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
Some(sources.GreaterThan(a.name, v))
case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
Some(sources.LessThan(a.name, v))

case expressions.LessThan(a: Attribute, Literal(v, _)) =>
Some(sources.LessThan(a.name, v))
case expressions.LessThan(Literal(v, _), a: Attribute) =>
Some(sources.GreaterThan(a.name, v))

case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
Some(sources.GreaterThanOrEqual(a.name, v))
case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
Some(sources.LessThanOrEqual(a.name, v))

case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
Some(sources.LessThanOrEqual(a.name, v))
case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
Some(sources.GreaterThanOrEqual(a.name, v))
case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
Some(sources.EqualTo(a.name, convertToScala(v, t)))
case expressions.EqualTo(Literal(v, t), a: Attribute) =>
Some(sources.EqualTo(a.name, convertToScala(v, t)))

case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))

case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
Some(sources.GreaterThan(a.name, convertToScala(v, t)))
case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
Some(sources.LessThan(a.name, convertToScala(v, t)))

case expressions.LessThan(a: Attribute, Literal(v, t)) =>
Some(sources.LessThan(a.name, convertToScala(v, t)))
case expressions.LessThan(Literal(v, t), a: Attribute) =>
Some(sources.GreaterThan(a.name, convertToScala(v, t)))

case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))

case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))

case expressions.InSet(a: Attribute, set) =>
Some(sources.In(a.name, set.toArray))
val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
Some(sources.In(a.name, set.toArray.map(toScala)))

// Because we only convert In to InSet in Optimizer when there are more than certain
// items. So it is possible we still get an In expression here that needs to be pushed
// down.
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(EmptyRow))
Some(sources.In(a.name, hSet.toArray))
val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
Some(sources.In(a.name, hSet.toArray.map(toScala)))

case expressions.IsNull(a: Attribute) =>
Some(sources.IsNull(a.name))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ private[sql] class JDBCRDD(
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
case stringValue: String => s"'${escapeSql(stringValue)}'"
case _ => value
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

private[sql] object ParquetFilters {
val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
Expand Down Expand Up @@ -65,7 +64,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Expand All @@ -86,7 +85,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
Expand All @@ -104,7 +103,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
FilterApi.lt(binaryColumn(n),
Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
Expand All @@ -121,7 +121,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
FilterApi.ltEq(binaryColumn(n),
Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
Expand All @@ -138,7 +139,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
FilterApi.gt(binaryColumn(n),
Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
Expand All @@ -155,7 +157,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
FilterApi.gtEq(binaryColumn(n),
Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
Expand All @@ -177,7 +180,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Set[Any]) =>
FilterApi.userDefined(binaryColumn(n),
SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes))))
SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8")))))
case BinaryType =>
(n: String, v: Set[Any]) =>
FilterApi.userDefined(binaryColumn(n),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import scala.language.existentials

import org.apache.spark.rdd.RDD
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
case StringStartsWith("c", v) => _.startsWith(v)
case StringEndsWith("c", v) => _.endsWith(v)
case StringContains("c", v) => _.contains(v)
case EqualTo("c", v: String) => _.equals(v)
case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters")
case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
case _ => (c: String) => true
}

Expand Down Expand Up @@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)

testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1)
testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1)

def testPushDown(sqlString: String, expectedCount: Int): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
val queryExecution = sql(sqlString).queryExecution
Expand Down

0 comments on commit 7bc9a8c

Please sign in to comment.