Skip to content

Commit

Permalink
resolve comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
gatorsmile committed Mar 3, 2016
1 parent ae768fe commit 6cea658
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case p: Project =>
projectToSQL(p, isDistinct = false)

case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) =>
case a @ Aggregate(_, _, e @ Expand(_, _, p: Project))
if sameOutput(e.output,
p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) =>
groupingSetToSQL(a, e, p)

case p: Aggregate =>
Expand Down Expand Up @@ -185,6 +187,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
throw new UnsupportedOperationException(s"unsupported plan $node")
}

private def sameOutput(left: Seq[Attribute], right: Seq[Attribute]): Boolean =
left.forall(a => right.exists(_.semanticEquals(a))) &&
right.forall(a => left.exists(_.semanticEquals(a)))

/**
* Turns a bunch of string segments into a single string and separate each segment by a space.
* The segments are trimmed so only a single space appears in the separation.
Expand Down Expand Up @@ -231,6 +237,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
val groupingSQL = groupByExprs.map(_.sql).mkString(", ")

val groupingSet = expand.projections.map { project =>
// Assumption: expand.projections are composed of
// 1) the original output (project.child.output),
// 2) group by attributes(or null literal)
// 3) gid, which is always the last one in each project
project.dropRight(1).collect {
case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
}
Expand All @@ -241,9 +251,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

val aggExprs = plan.aggregateExpressions.map { case expr =>
expr.transformDown {
case a @ Alias(child: AttributeReference, name) if child eq gid =>
// grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
Alias(GroupingID(Nil), name)()
// grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
case ar: AttributeReference if ar eq gid => GroupingID(Nil)
case a @ Alias(_ @ Cast(BitwiseAnd(
ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)),
Literal(1, IntegerType)), ByteType), name) if ar == gid =>
Expand All @@ -255,11 +264,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
} else {
throw new UnsupportedOperationException(s"unsupported operator $a")
}
case a @ Alias(child: AttributeReference, _) if groupByAttrMap.contains(child) =>
groupByAttrMap(child)
case a @ Alias(ar: AttributeReference, _) if groupByAttrMap.contains(ar) =>
groupByAttrMap(ar)
case ar: AttributeReference if groupByAttrMap.contains(ar) =>
groupByAttrMap(ar)
case o => o
}
}

Expand Down

0 comments on commit 6cea658

Please sign in to comment.