Add new query hint NO_COLLAPSE.
ptkool committed Apr 20, 2017
1 parent 86d251c commit 3f1e6a1
Showing 6 changed files with 61 additions and 15 deletions.
8 changes: 8 additions & 0 deletions python/pyspark/sql/functions.py
@@ -466,6 +466,14 @@ def nanvl(col1, col2):
return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2)))


@since(2.2)
def no_collapse(df):
"""Marks a DataFrame as small enough for use in broadcast joins."""

sc = SparkContext._active_spark_context
return DataFrame(sc._jvm.functions.no_collapse(df._jdf), df.sql_ctx)


@since(1.4)
def rand(seed=None):
"""Generates a random column with independent and identically distributed (i.i.d.) samples
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -386,6 +386,13 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
child.stats(conf).copy(isBroadcastable = true)
}

/**
* A hint for the optimizer that we should not merge two projections.
*/
case class NoCollapseHint(child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
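Why a bare UnaryNode is enough to stop the optimizer: CollapseProject only rewrites two directly nested Project operators, so any node standing between them breaks the pattern match. A simplified sketch of the shape involved, with `mergeProjects` as a stand-in for the real merging logic (the actual rule also handles alias substitution and determinism checks):

// Sketch only, not the actual CollapseProject rule.
plan transformUp {
  case upper @ Project(_, lower: Project) =>
    mergeProjects(upper, lower)  // stand-in helper, not a real API
  // Project(NoCollapseHint(Project(...))) does not match the case above,
  // so both projections survive until the hint is erased at planning time.
}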

/**
* A general hint for the child. This node will be eliminated post analysis.
* A pair of (name, parameters).
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoCollapseHint}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseProjectSuite extends PlanTest {
@@ -119,4 +119,14 @@

comparePlans(optimized, correctAnswer)
}

test("do not collapse projects with onceOnly expressions") {
val query = NoCollapseHint(testRelation.select(('a * 10).as('a_times_10)))
.select(('a_times_10 + 1).as('a_times_10_plus_1), ('a_times_10 + 2).as('a_times_10_plus_2))

val optimized = Optimize.execute(query.analyze)
val correctAnswer = query.analyze

comparePlans(optimized, correctAnswer)
}
}
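For contrast, without the hint CollapseProject would merge the two projections by substituting the `a_times_10` alias into every reference, so `'a * 10` would be evaluated twice per row. Roughly, the optimized plan would have been:

// What the optimizer would produce for the same query without the hint
// (a sketch for illustration, not a test from this commit):
testRelation.select(
  (('a * 10) + 1).as('a_times_10_plus_1),
  (('a * 10) + 2).as('a_times_10_plus_2))

That duplication is harmless for cheap expressions but not for expensive or evaluate-once ones, which is what the test name's "onceOnly" refers to.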
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -537,5 +537,10 @@ class PlanParserSuite extends PlanTest {
comparePlans(
parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"),
Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc))

comparePlans(
parsePlan("SELECT a FROM (SELECT /*+ NO_COLLAPSE */ * FROM t) t1"),
SubqueryAlias("t1", Hint("NO_COLLAPSE", Seq.empty, table("t").select(star())))
.select('a))
}
}
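Note that the test only asserts parsing: the SQL comment syntax produces a generic `Hint("NO_COLLAPSE", ...)` node, and resolving that generic hint happens during analysis, outside the files in this commit. A hedged spark-shell sketch of issuing the hint from SQL, assuming `t` is a registered table or temporary view:

// Sketch: inspect the pre-analysis plan to see the generic Hint node.
val parsed = spark.sql("SELECT a FROM (SELECT /*+ NO_COLLAPSE */ * FROM t) t1")
println(parsed.queryExecution.logical)  // contains Hint("NO_COLLAPSE", ...)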
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -433,6 +433,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case NoCollapseHint(child) => planLater(child) :: Nil
case _ => Nil
}
}
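The one-line strategy change above is what makes the hint vanish from the physical plan: by the time BasicOperators runs, the optimizer has already been kept from collapsing the projections, so the hint node is simply planned as its child. A hedged spark-shell check, assuming the commit is applied:

// Sketch: the hint should leave no trace in the executed plan.
val df = no_collapse(spark.range(5).toDF("a"))
assert(!df.queryExecution.executedPlan.toString.contains("NoCollapse"))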
43 changes: 29 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -19,17 +19,16 @@ package org.apache.spark.sql

import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import scala.util.Try
import scala.util.control.NonFatal

import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, NoCollapseHint}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.internal.SQLConf
@@ -1007,21 +1006,37 @@ object functions {
def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }

/**
* Marks a DataFrame as small enough for use in broadcast joins.
*
* The following example marks the right DataFrame for broadcast hash join using `joinKey`.
* {{{
* // left and right are DataFrames
* left.join(broadcast(right), "joinKey")
* }}}
*
* @group normal_funcs
* @since 1.5.0
*/
def broadcast[T](df: Dataset[T]): Dataset[T] = {
Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc)
}

/**
* Marks a DataFrame so that the optimizer will not collapse its projection
* into the projections built on top of it.
*
* The following example keeps the multiplication from being inlined into
* later projections, so it is evaluated only once per row.
* {{{
*   // df is a DataFrame with a column "a"
*   val projected = no_collapse(df.select((df("a") * 10).as("x")))
* }}}
*
* @group normal_funcs
* @since 2.2.0
*/
def no_collapse[T](df: Dataset[T]): Dataset[T] = {
Dataset[T](df.sparkSession, NoCollapseHint(df.logicalPlan))(df.exprEnc)
}
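A hypothetical use of the new function, with `expensiveUdf` standing in for any costly or evaluate-once expression (the UDF and column names are illustrative, not part of this commit):

// Hypothetical usage sketch, assuming import org.apache.spark.sql.functions._
// Force expensiveUdf to run once per row; both derived columns reuse "x".
val projected = no_collapse(df.select(expensiveUdf(df("a")).as("x")))
val result = projected.select((col("x") + 1).as("y1"), (col("x") + 2).as("y2"))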

/**
* Returns the first column that is not null, or null if all inputs are null.
*
