diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 0d9a68342c643..f6d657abc96f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -107,9 +107,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: Project => projectToSQL(p, isDistinct = false) - case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) - if sameOutput(e.output, - p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) => + case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) => groupingSetToSQL(a, e, p) case p: Aggregate => @@ -208,10 +206,6 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi throw new UnsupportedOperationException(s"unsupported plan $node") } - private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) - /** * Turns a bunch of string segments into a single string and separate each segment by a space. * The segments are trimmed so only a single space appears in the separation. @@ -242,6 +236,16 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi ) } + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + + private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = { + assert(a.child == e && e.child == p) + a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && + sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) + } + private def groupingSetToSQL( agg: Aggregate, expand: Expand, @@ -253,16 +257,16 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val numOriginalOutput = project.child.output.length // Assumption: Aggregate's groupingExpressions is composed of - // 1) the group by attributes' aliases + // 1) the attributes of aliased group by expressions // 2) gid, which is always the last one val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) // Assumption: Project's projectList is composed of // 1) the original output (Project's child.output), - // 2) the aliases of the original group by attributes, which could be expressions + // 2) the aliased group by expressions. val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) val groupingSQL = groupByExprs.map(_.sql).mkString(", ") - // a map from the alias name to the original group by expresions/attributes + // a map from group by attributes to the original group by expressions. val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) val groupingSet = expand.projections.map { project =>