shrink the commits
chenghao-intel committed Apr 14, 2015
1 parent 77eeb10 commit ca5e7f4
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
Expand Down Expand Up @@ -59,6 +58,7 @@ class Analyzer(
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolveSortReferences ::
ResolveGenerate ::
ImplicitGenerate ::
ResolveFunctions ::
GlobalAggregates ::
Expand Down Expand Up @@ -473,10 +473,47 @@ class Analyzer(
object ImplicitGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Project(Seq(Alias(g: Generator, _)), child) =>
Generate(g, join = false, outer = false, None, child)
case Project(Seq(Alias(g: Generator, name)), child) =>
Generate(g, join = false, outer = false, child, qualifier = None, name :: Nil, Nil)
case Project(Seq(MultiAlias(g: Generator, names)), child) =>
Generate(g, join = false, outer = false, child, qualifier = None, names, Nil)

object ResolveGenerate extends Rule[LogicalPlan] {
// Construct the output attributes for the generator.
// The output attribute names can be either specified explicitly or
// auto-generated.
private def makeGeneratorOutput(
generator: Generator,
attributeNames: Seq[String],
qualifier: Option[String]): Array[Attribute] = {
val elementTypes = generator.elementTypes

val raw = if (attributeNames.size == elementTypes.size) { {
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
} else { {
// keep the default column names as Hive does _c0, _c1, _cN
case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
} => :: Nil))).getOrElse(raw).toArray[Attribute]

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Generate if !p.child.resolved || !p.generator.resolved => p
case p: Generate if p.resolved == false =>
// if the generator output names are not specified, we will use the default ones.
val gOutput = makeGeneratorOutput(p.generator, p.attributeNames, p.qualifier)
p.generator, p.join, p.outer, p.child, p.qualifier,, gOutput)


Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ trait CheckAnalysis {
throw new AnalysisException(msg)

/**
 * Reports whether any generator expression occurs in `exprs`.
 *
 * Each expression's `collect` is invoked with a partial function that yields `true` for every
 * node matching `Generator`; the per-expression results are flattened and counted.
 * (`collect` is presumably the tree-node traversal, so nested generators would be counted
 * as well -- confirm against TreeNode.)
 *
 * NOTE(review): despite the name, the threshold is `>= 1`, so a single generator already
 * makes this return true. The Project case in `checkAnalysis` uses this condition to report
 * "Only a single table generating function is allowed in a SELECT clause" -- confirm whether
 * `> 1` was intended, or whether any generator still present in a Project at this stage is
 * meant to be rejected.
 *
 * @param exprs expressions to inspect
 * @return true when at least one `Generator` node is found
 */
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
exprs.flatMap(_.collect {
case e: Generator => true
}).length >= 1

def checkAnalysis(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
Expand Down Expand Up @@ -107,6 +113,12 @@ trait CheckAnalysis {
s"unresolved operator ${operator.simpleString}")

case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${",")}""".stripMargin)

case _ => // Analysis successful!
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,13 @@ package object dsl {
seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)

// TODO specify the output column names
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None): LogicalPlan =
Generate(generator, join, outer, None, logicalPlan)
alias: Option[String] = None): Generate =
Generate(generator, join, outer, logicalPlan, alias)

def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
Original file line number Diff line number Diff line change
Expand Up @@ -42,47 +42,27 @@ abstract class Generator extends Expression {

override type EvaluatedType = TraversableOnce[Row]

override lazy val dataType =
ArrayType(StructType( => StructField(, a.dataType, a.nullable, a.metadata))))
override def dataType: DataType = ???

override def nullable: Boolean = false

* Should be overridden by specific generators. Called only once for each instance to ensure
* that rule application does not change the output schema of a generator.
* The output element data types in structure of Seq[(DataType, Nullable)]
protected def makeOutput(): Seq[Attribute]

private var _output: Seq[Attribute] = null

def output: Seq[Attribute] = {
if (_output == null) {
_output = makeOutput()
def elementTypes: Seq[(DataType, Boolean)]

/** Should be implemented by child classes to perform specific Generators. */
override def eval(input: Row): TraversableOnce[Row]

/** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
override def makeCopy(newArgs: Array[AnyRef]): this.type = {
val copy = super.makeCopy(newArgs)
copy._output = _output

* A generator that produces its output using the provided lambda function.
case class UserDefinedGenerator(
schema: Seq[Attribute],
elementTypes: Seq[(DataType, Boolean)],
function: Row => TraversableOnce[Row],
children: Seq[Expression])
extends Generator{

override protected def makeOutput(): Seq[Attribute] = schema
extends Generator {

override def eval(input: Row): TraversableOnce[Row] = {
val inputRow = new InterpretedProjection(children)
Expand All @@ -95,30 +75,18 @@ case class UserDefinedGenerator(
* Given an input array produces a sequence of rows for each value in the array.
case class Explode(attributeNames: Seq[String], child: Expression)
case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {

override lazy val resolved =
child.resolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

private lazy val elementTypes = child.dataType match {
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil

// TODO: Move this pattern into Generator.
protected def makeOutput() =
if (attributeNames.size == elementTypes.size) { {
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
} else { {
case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)()

override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
case ArrayType(_, _) =>
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)(
extends NamedExpression with trees.UnaryNode[Expression] {

override type EvaluatedType = Any
// Alias(Generator, xx) needs to be transformed into Generate(generator, ...)
override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]

override def eval(input: Row): Any = child.eval(input)

Original file line number Diff line number Diff line change
Expand Up @@ -477,16 +477,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition,
generate @ Generate(generator, join, outer, alias, grandChild)) =>
case filter @ Filter(condition, g: Generate) =>
// Predicates that reference attributes produced by the `Generate` operator cannot
// be pushed below the operator.
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
conjunct => conjunct.references subsetOf grandChild.outputSet
conjunct => conjunct.references subsetOf g.child.outputSet
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
Filter(pushDownPredicate, g.child), g.qualifier, g.attributeNames, g.gOutput)
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
} else {
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,41 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
* @param alias when set, this string is applied to the schema of the output of the transformation
* as a qualifier.
* @param child Children logical plan node
* @param qualifier Qualifier for the attributes of generator(UDTF)
* @param attributeNames the column names for the generator (UDTF); they will be _c0, _c1 .. _cN
* if left as default (empty)
* @param gOutput The output of Generator.
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
alias: Option[String],
child: LogicalPlan)
child: LogicalPlan,
qualifier: Option[String] = None,
attributeNames: Seq[String] = Nil,
gOutput: Seq[Attribute] = Nil)
extends UnaryNode {

protected def generatorOutput: Seq[Attribute] = {
val output = alias
.map(a => :: Nil)))
if (join && outer) {
} else {
override lazy val resolved: Boolean = {
generator.resolved &&
childrenResolved &&
attributeNames.length > 0 && == attributeNames

override def output: Seq[Attribute] =
if (join) child.output ++ generatorOutput else generatorOutput
// we don't want the gOutput to be taken as part of the expressions
// as that will cause exceptions like unresolved attributes etc.
override def expressions: Seq[Expression] = generator :: Nil

def output: Seq[Attribute] = {
if (join) child.output ++ gOutput else gOutput

case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {

assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)

val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)

assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
Original file line number Diff line number Diff line change
Expand Up @@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest {
test("generate: predicate referenced no generated column") {
val originalQuery = {
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('b >= 5) && ('a > 6))
val optimized = Optimize(originalQuery.analyze)
val correctAnswer = {
.where(('b >= 5) && ('a > 6))
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
.generate(Explode('c_arr), true, false, Some("arr")).analyze

comparePlans(optimized, correctAnswer)

test("generate: part of conjuncts referenced generated column") {
val generator = Explode(Seq("c"), 'c_arr)
val generator = Explode('c_arr)
val originalQuery = {
.generate(generator, true, false, Some("arr"))
Expand Down Expand Up @@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest {
test("generate: all conjuncts referenced generated column") {
val originalQuery = {
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('c > 6) || ('b > 5)).analyze
val optimized = Optimize(originalQuery)
17 changes: 12 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -711,12 +711,15 @@ class DataFrame private[sql](
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributes = schema.toAttributes

val elementTypes = { attr => (attr.dataType, attr.nullable) }
val names =

val rowFunction =
f.andThen(, schema).asInstanceOf[Row]))
val generator = UserDefinedGenerator(attributes, rowFunction,
val generator = UserDefinedGenerator(elementTypes, rowFunction,

Generate(generator, join = true, outer = false, None, logicalPlan)
Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)

Expand All @@ -733,12 +736,16 @@ class DataFrame private[sql](
: DataFrame = {
val dataType = ScalaReflection.schemaFor[B].dataType
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
// TODO handle the metadata?
val elementTypes = { attr => (attr.dataType, attr.nullable) }
val names =

def rowFunction(row: Row): TraversableOnce[Row] = {
f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)

Generate(generator, join = true, outer = false, None, logicalPlan)
Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)

Expand Down

