Skip to content

Commit

Permalink
address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
gatorsmile committed Mar 4, 2016
1 parent 6f609fb commit 9eaca51
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case p: Project =>
projectToSQL(p, isDistinct = false)

case a @ Aggregate(_, _, e @ Expand(_, _, p: Project))
if sameOutput(e.output,
p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) =>
case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
groupingSetToSQL(a, e, p)

case p: Aggregate =>
Expand Down Expand Up @@ -208,10 +206,6 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
throw new UnsupportedOperationException(s"unsupported plan $node")
}

/**
 * Checks whether two attribute sequences line up position by position.
 *
 * @param output1 the first sequence of attributes
 * @param output2 the second sequence of attributes
 * @return true only when both sequences have the same number of elements and each
 *         attribute of `output1` is `semanticEquals` to the attribute at the same
 *         index of `output2`
 */
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
  // `corresponds` is false for sequences of different length, so no separate size check.
  output1.corresponds(output2)((left, right) => left.semanticEquals(right))

/**
* Turns a bunch of string segments into a single string and separate each segment by a space.
* The segments are trimmed so only a single space appears in the separation.
Expand Down Expand Up @@ -242,6 +236,16 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
)
}

/**
 * Checks whether two attribute sequences line up position by position.
 *
 * @param output1 the first sequence of attributes
 * @param output2 the second sequence of attributes
 * @return true only when both sequences have the same number of elements and each
 *         attribute of `output1` is `semanticEquals` to the attribute at the same
 *         index of `output2`
 */
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
  // `corresponds` is false for sequences of different length, so no separate size check.
  output1.corresponds(output2)((left, right) => left.semanticEquals(right))

/**
 * Decides whether an Aggregate-over-Expand-over-Project chain has the shape that
 * `groupingSetToSQL` handles.
 *
 * The three operators must be directly nested (`a.child == e` and `e.child == p`);
 * this is asserted rather than tested.
 *
 * @param a the aggregate operator at the top of the chain
 * @param e the expand operator that is the child of `a`
 * @param p the project operator that is the child of `e`
 * @return true when every grouping expression of `a` is an [[Attribute]] and the output
 *         of `e` matches the output of `p`'s child followed by those grouping attributes
 */
private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
  assert(a.child == e && e.child == p)
  val groupingExprs = a.groupingExpressions
  // Keep only the grouping expressions that are plain attributes; if any were dropped,
  // the sizes differ and the plan is not treated as a grouping set.
  val groupingAttrs = groupingExprs.collect { case attr: Attribute => attr }
  val allAttributes = groupingAttrs.size == groupingExprs.size
  allAttributes && sameOutput(e.output, p.child.output ++ groupingAttrs)
}

private def groupingSetToSQL(
agg: Aggregate,
expand: Expand,
Expand All @@ -253,16 +257,16 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

val numOriginalOutput = project.child.output.length
// Assumption: Aggregate's groupingExpressions is composed of
// 1) the group by attributes' aliases
// 1) the attributes of aliased group by expressions
// 2) gid, which is always the last one
val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
// Assumption: Project's projectList is composed of
// 1) the original output (Project's child.output),
// 2) the aliases of the original group by attributes, which could be expressions
// 2) the aliased group by expressions.
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
val groupingSQL = groupByExprs.map(_.sql).mkString(", ")

// a map from the alias name to the original group by expressions or attributes
// a map from group by attributes to the original group by expressions.
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))

val groupingSet = expand.projections.map { project =>
Expand Down

0 comments on commit 9eaca51

Please sign in to comment.