Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12506][SPARK-12126][SQL]use CatalystScan for JDBCRelation #11005

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ package org.apache.spark.sql.execution.datasources.jdbc
import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
import java.util.Properties

import org.apache.spark.sql.catalyst.expressions._
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

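Please fix the import order: this `org.apache.spark` import should be grouped with the other `org.apache.spark` imports below, after the `scala` and third-party imports.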


import scala.util.control.NonFatal

import org.apache.commons.lang3.StringUtils

import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
Expand Down Expand Up @@ -186,48 +187,227 @@ private[sql] object JDBCRDD extends Logging {
if (value == null) null else StringUtils.replace(value, "'", "''")

/**
* Turns a single Filter into a String representing a SQL expression.
* Returns None for an unhandled filter.
* Turns a single predicate into a String representing a SQL expression.
* Returns None for an unhandled predicate.
*/
private def compileFilter(f: Filter): Option[String] = {
Option(f match {
case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
case EqualNullSafe(attr, value) =>
s"(NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " +
s"${compileValue(value)} IS NULL) OR ($attr IS NULL AND ${compileValue(value)} IS NULL))"
case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
case IsNull(attr) => s"$attr IS NULL"
case IsNotNull(attr) => s"$attr IS NOT NULL"
case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'"
case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'"
case StringContains(attr, value) => s"${attr} LIKE '%${value}%'"
case In(attr, value) => s"$attr IN (${compileValue(value)})"
case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null)
case Or(f1, f2) =>
// We can't compile Or filter unless both sub-filters are compiled successfully.
// It applies too for the following And filter.
// If we can make sure compileFilter supports all filters, we can remove this check.
val or = Seq(f1, f2).map(compileFilter(_)).flatten
if (or.size == 2) {
or.map(p => s"($p)").mkString(" OR ")
private def compilePredicate(predicate: Expression): Option[String] = {
  // Translates one Catalyst predicate into a SQL boolean expression.
  // Returns None when the predicate (or any part of an And/Or/Not tree) has no translation here.

  // Renders a Catalyst literal as SQL text using the same value compiler as the other filters.
  def sqlValue(lit: Literal): String =
    s"${compileValue(CatalystTypeConverters.convertToScala(lit.value, lit.dataType))}"

  // Renders `attribute <op> literal`. Callers with the literal on the LEFT of the original
  // comparison must pass the mirrored operator, because the attribute is always emitted first.
  def comparison(a: Attribute, op: String, lit: Literal): String =
    s"${a.name} $op ${sqlValue(lit)}"

  // Null-safe equality spelled out with plain comparisons and IS NULL checks.
  def nullSafeEquals(a: Attribute, lit: Literal): String = {
    val s = sqlValue(lit)
    s"(NOT ( ${a.name} != $s OR ${a.name} IS NULL OR $s IS NULL ) " +
      s"OR ( ${a.name} IS NULL AND $s IS NULL))"
  }

  // Text placed inside a quoted LIKE pattern. Single quotes are doubled so that a value such as
  // o'brien cannot terminate the SQL string literal early.
  def likeBody(v: UTF8String): String = StringUtils.replace(v.toString, "'", "''")

  // Renders `attribute IN (v1, v2, ...)` from Catalyst-internal values.
  def inList(a: Attribute, values: Array[Any]): String = {
    val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
    s"${a.name} IN (${compileValue(values.map(toScala))})"
  }

  // Combines two sub-predicates; yields null unless BOTH sides could be compiled, because
  // pushing down only one side of an AND/OR here would change the meaning of the clause.
  def combine(left: Expression, right: Expression, keyword: String): String = {
    val both = for {
      l <- compilePredicate(left)
      r <- compilePredicate(right)
    } yield s"($l) $keyword ($r)"
    both.orNull
  }

  Option(predicate match {
    case expressions.EqualTo(a: Attribute, lit: Literal) => comparison(a, "=", lit)
    case expressions.EqualTo(lit: Literal, a: Attribute) => comparison(a, "=", lit)

    case expressions.EqualNullSafe(a: Attribute, lit: Literal) => nullSafeEquals(a, lit)
    case expressions.EqualNullSafe(lit: Literal, a: Attribute) => nullSafeEquals(a, lit)

    // For the ordering comparisons, `literal <op> attribute` is rewritten as
    // `attribute <mirrored op> literal`, e.g. `5 > a` becomes `a < 5`.
    case expressions.GreaterThan(a: Attribute, lit: Literal) => comparison(a, ">", lit)
    case expressions.GreaterThan(lit: Literal, a: Attribute) => comparison(a, "<", lit)

    case expressions.LessThan(a: Attribute, lit: Literal) => comparison(a, "<", lit)
    case expressions.LessThan(lit: Literal, a: Attribute) => comparison(a, ">", lit)

    case expressions.GreaterThanOrEqual(a: Attribute, lit: Literal) => comparison(a, ">=", lit)
    case expressions.GreaterThanOrEqual(lit: Literal, a: Attribute) => comparison(a, "<=", lit)

    case expressions.LessThanOrEqual(a: Attribute, lit: Literal) => comparison(a, "<=", lit)
    case expressions.LessThanOrEqual(lit: Literal, a: Attribute) => comparison(a, ">=", lit)

    case expressions.IsNull(a: Attribute) => s"${a.name} IS NULL"
    case expressions.IsNotNull(a: Attribute) => s"${a.name} IS NOT NULL"

    case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
      s"${a.name} LIKE '${likeBody(v)}%'"
    case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
      s"${a.name} LIKE '%${likeBody(v)}'"
    case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
      s"${a.name} LIKE '%${likeBody(v)}%'"

    case expressions.InSet(a: Attribute, set) => inList(a, set.toArray)

    // The optimizer only converts In to InSet when the list exceeds a certain number of items,
    // so an In whose items are all literals can still arrive here and needs to be pushed down.
    case expressions.In(a: Attribute, list)
        if list.forall { case _: Literal => true; case _ => false } =>
      inList(a, list.map(e => e.eval(EmptyRow)).toArray)

    case expressions.And(left, right) => combine(left, right, "AND")
    case expressions.Or(left, right) => combine(left, right, "OR")

    case expressions.Not(f) => compilePredicate(f).map(p => s"(NOT ($p))").orNull

    // Comparisons between an arithmetic expression and a literal are delegated.
    case BinaryComparison(BinaryArithmetic(_, _), Literal(_, _)) =>
      translateArithemiticOPFilter(predicate)
    case BinaryComparison(Literal(_, _), BinaryArithmetic(_, _)) =>
      translateArithemiticOPFilter(predicate)

    // No translation available; Option(null) below turns this into None.
    case _ => null
  })
}

/**
 * Turns a comparison between an arithmetic expression and a literal into a SQL string of the
 * form `<arithmetic> <op> <literal>`.
 *
 * Supported comparisons are =, >, <, >= and <=; the arithmetic side must have Add, Subtract,
 * Multiply or Divide at its root. When the literal is on the left of the comparison the
 * operator is mirrored (for example `10 > a + b` becomes `(a) + (b) < 10`), because the
 * arithmetic side is always emitted first.
 *
 * @param predicate the comparison expression to translate
 * @return the SQL text, or null when the predicate cannot be translated (the caller wraps the
 *         result in Option(...), so null becomes None)
 */
private def translateArithemiticOPFilter(predicate: Expression): String = {
  // Only these operators are accepted at the root of the arithmetic side.
  def isSupportedArithmetic(e: Expression): Boolean = e match {
    case _: Add | _: Subtract | _: Multiply | _: Divide => true
    case _ => false
  }

  // SQL operator to place between the arithmetic text and the literal text.
  def sqlOperator(literalOnLeft: Boolean): Option[String] = predicate match {
    case _: expressions.EqualTo => Some("=")
    case _: expressions.GreaterThan => Some(if (literalOnLeft) "<" else ">")
    case _: expressions.LessThan => Some(if (literalOnLeft) ">" else "<")
    case _: expressions.GreaterThanOrEqual => Some(if (literalOnLeft) "<=" else ">=")
    case _: expressions.LessThanOrEqual => Some(if (literalOnLeft) ">=" else "<=")
    case _ => None
  }

  // Assembles the final text; None if the operator or any arithmetic operand is unsupported.
  def render(arithmetic: Expression, lit: Literal, literalOnLeft: Boolean): Option[String] =
    if (isSupportedArithmetic(arithmetic)) {
      for {
        op <- sqlOperator(literalOnLeft)
        lhs <- getArithmeticString(arithmetic)
      } yield {
        val rhs = compileValue(CatalystTypeConverters.convertToScala(lit.value, lit.dataType))
        s"$lhs $op $rhs"
      }
    } else {
      None
    }

  val compiled = predicate match {
    case BinaryComparison(arithmetic: BinaryArithmetic, lit: Literal) =>
      render(arithmetic, lit, literalOnLeft = false)
    case BinaryComparison(lit: Literal, arithmetic: BinaryArithmetic) =>
      render(arithmetic, lit, literalOnLeft = true)
    case _ => None
  }
  compiled.orNull
}

/**
 * Renders an arithmetic expression tree as a SQL fragment.
 *
 * Add, Subtract, Multiply and Divide nodes are rendered recursively with each operand wrapped
 * in parentheses, e.g. `(a) + ((b) * (c))`; an AttributeReference is rendered as its name.
 *
 * @param predicate the arithmetic expression (or one of its operands) to render
 * @return the SQL text, or None if any node in the tree is of a kind not listed above
 */
private def getArithmeticString(predicate: Expression): Option[String] = {
  // Joins the two rendered operands with `symbol`; None if either operand is unsupported.
  def binary(left: Expression, right: Expression, symbol: String): Option[String] =
    for {
      l <- getArithmeticString(left)
      r <- getArithmeticString(right)
    } yield s"($l) $symbol ($r)"

  predicate match {
    case expressions.Add(left, right) => binary(left, right, "+")
    case expressions.Subtract(left, right) => binary(left, right, "-")
    case expressions.Multiply(left, right) => binary(left, right, "*")
    case expressions.Divide(left, right) => binary(left, right, "/")
    case a: AttributeReference => Some(a.name)
    // Any other node (a literal, a function call, ...) has no rendering here. Report that as
    // None rather than failing with a MatchError, as the Option return type promises.
    case _ => None
  }
}

/**
* Build and return JDBCRDD from the given information.
Expand All @@ -237,7 +417,7 @@ private[sql] object JDBCRDD extends Logging {
* @param url - The JDBC url to connect to.
* @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
* @param requiredColumns - The names of the columns to SELECT.
* @param filters - The filters to include in all WHERE clauses.
* @param predicates - The predicates to include in all WHERE clauses.
* @param parts - An array of JDBCPartitions specifying partition ids and
* per-partition WHERE clauses.
*
Expand All @@ -249,18 +429,19 @@ private[sql] object JDBCRDD extends Logging {
url: String,
properties: Properties,
fqTable: String,
requiredColumns: Array[String],
filters: Array[Filter],
requiredColumns: Seq[Attribute],
predicates: Seq[Expression],
parts: Array[Partition]): RDD[InternalRow] = {
val dialect = JdbcDialects.get(url)
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
val columns: Array[String] = (requiredColumns map (_.prettyString)).toArray
val quotedColumns= columns.map(colName => dialect.quoteIdentifier(colName))
new JDBCRDD(
sc,
JdbcUtils.createConnectionFactory(url, properties),
pruneSchema(schema, requiredColumns),
pruneSchema(schema, columns),
fqTable,
quotedColumns,
filters,
predicates,
parts,
url,
properties)
Expand All @@ -278,7 +459,7 @@ private[sql] class JDBCRDD(
schema: StructType,
fqTable: String,
columns: Array[String],
filters: Array[Filter],
predicates: Seq[Expression],
partitions: Array[Partition],
url: String,
properties: Properties)
Expand All @@ -302,7 +483,7 @@ private[sql] class JDBCRDD(
* `predicates`, but as a WHERE clause suitable for injection into a SQL query.
*/
private val filterWhereClause: String =
filters.map(JDBCRDD.compileFilter).flatten.mkString(" AND ")
predicates.map(JDBCRDD.compilePredicate).flatten.mkString(" AND ")

/**
* A WHERE clause representing both `predicates`, if any, and the current partition.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.jdbc

import java.util.Properties

import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Partition
Expand Down Expand Up @@ -83,14 +85,14 @@ private[sql] case class JDBCRelation(
parts: Array[Partition],
properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
extends BaseRelation
with PrunedFilteredScan
with CatalystScan
with InsertableRelation {

override val needConversion: Boolean = false

override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
override def buildScan(requiredColumns: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sqlContext.sparkContext,
Expand All @@ -99,7 +101,7 @@ private[sql] case class JDBCRelation(
properties,
table,
requiredColumns,
filters,
predicates,
parts).asInstanceOf[RDD[Row]]
}

Expand Down
Loading