Skip to content

Commit

Permalink
[SPARK-21896][SQL] Fix StackOverflow caused by window functions insid…
Browse files Browse the repository at this point in the history
…e aggregate functions

## What changes were proposed in this pull request?

This PR explicitly prohibits window functions inside aggregate functions. Currently, a query that nests a window function inside an aggregate function causes a StackOverflowError during analysis. See PR #19193 for previous discussion.

## How was this patch tested?

This PR comes with a dedicated unit test.

Author: aokolnychyi <[email protected]>

Closes #21473 from aokolnychyi/fix-stackoverflow-window-funcs.
  • Loading branch information
aokolnychyi authored and cloud-fan committed Jun 4, 2018
1 parent 0be5aa2 commit 7297ae0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1744,10 +1744,10 @@ class Analyzer(
* it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
projectList.exists(hasWindowFunction)
private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
exprs.exists(hasWindowFunction)

private def hasWindowFunction(expr: NamedExpression): Boolean = {
private def hasWindowFunction(expr: Expression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
Expand Down Expand Up @@ -1830,6 +1830,10 @@ class Analyzer(
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)

case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
failAnalysis("It is not allowed to use a window function inside an aggregate " +
"function. Please use the inner window function in a sub-query.")

// Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
// we need to extract SUM(x).
case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ package org.apache.spark.sql

import scala.util.Random

import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.scalatest.Matchers.the

import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
Expand Down Expand Up @@ -687,4 +687,34 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-21896: Window functions inside aggregate functions") {
  // Fragment of the analysis error message every rejected query must produce.
  val expectedFragment = "not allowed to use a window function"

  // Asserts that analysing `query` fails with an AnalysisException carrying the expected
  // message. `query` is by-name so that it is first evaluated inside the `thrownBy` block:
  // an AnalysisException raised while the DataFrame is being built is captured as well.
  def assertRejected(query: => DataFrame): Unit = {
    val failure = the [AnalysisException] thrownBy {
      query.queryExecution.analyzed
    }
    assert(failure.message.contains(expectedFragment))
  }

  // DataFrame API: a window function nested inside an aggregate function must be rejected.
  assertRejected(testData2.select(min(avg('b).over(Window.partitionBy('a)))))
  assertRejected(testData2.agg(sum('b), max(rank().over(Window.orderBy('a)))))
  assertRejected(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b)))))
  assertRejected(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a)))))
  assertRejected(
    testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3))

  // A window function next to (not inside) the aggregate functions still returns rows.
  checkAnswer(
    testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3),
    Seq(Row(1, 2, 3, 3), Row(2, 2, 3, 3), Row(3, 2, 3, 3)))

  // SQL: the same kinds of queries expressed as SQL text.
  assertRejected(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2"))
  assertRejected(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2"))
  assertRejected(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a"))
  assertRejected(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a"))
  assertRejected(
    sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3"))

  checkAnswer(
    sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"),
    Seq(Row(1, 2, 1), Row(2, 2, 2), Row(3, 2, 3)))
}

}

0 comments on commit 7297ae0

Please sign in to comment.