Skip to content

Commit

Permalink
[SPARK-18863][SQL] Output non-aggregate expressions without GROUP BY …
Browse files Browse the repository at this point in the history
…in a subquery does not yield an error

## What changes were proposed in this pull request?
This PR reports proper error messages when a subquery expression contains an invalid plan. The problem is fixed by calling CheckAnalysis on the plan inside the subquery.

## How was this patch tested?
Existing tests and two new test cases on 2 forms of subquery, namely, scalar subquery and in/exists subquery.

````
-- TC 01.01
-- The column t2b in the SELECT of the subquery is invalid
-- because it is neither an aggregate function nor a GROUP BY column.
select t1a, t2b
from   t1, t2
where  t1b = t2c
and    t2b = (select max(avg)
              from   (select   t2b, avg(t2b) avg
                      from     t2
                      where    t2a = t1.t1b
                     )
             )
;

-- TC 01.02
-- Invalid because the column t2b, referenced in the innermost HAVING clause,
-- is not part of the output of the aggregate over t2 (only min(t2a) and t2c are).
select *
from   t1
where  t1a in (select   min(t2a)
               from     t2
               group by t2c
               having   t2c in (select   max(t3c)
                                from     t3
                                group by t3b
                                having   t3b > t2b ))
;
````

Author: Nattavut Sutyanyong <[email protected]>

Closes #16572 from nsyca/18863.
  • Loading branch information
nsyca authored and hvanhovell committed Jan 25, 2017
1 parent 0e821ec commit f1ddca5
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,66 +117,72 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis(s"Window specification $s is not valid because $m")
case None => w
}
case s @ ScalarSubquery(query, conditions, _)

case s @ ScalarSubquery(query, conditions, _) =>
// If no correlation, the output must be exactly one column
if (conditions.isEmpty && query.output.size != 1) =>
if (conditions.isEmpty && query.output.size != 1) {
failAnalysis(
s"Scalar subquery must return only one column, but got ${query.output.size}")
}
else if (conditions.nonEmpty) {
// Collect the columns from the subquery for further checking.
var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)

def checkAggregate(agg: Aggregate): Unit = {
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates containing exactly one aggregate expression.
// The analyzer has already checked that subquery contained only one output column,
// and added all the grouping expressions to the aggregate.
val aggregates = agg.expressions.flatMap(_.collect {
case a: AggregateExpression => a
})
if (aggregates.isEmpty) {
failAnalysis("The output of a correlated scalar subquery must be aggregated")
}

case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>

// Collect the columns from the subquery for further checking.
var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)

def checkAggregate(agg: Aggregate): Unit = {
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates which contain exactly one aggregate expressions.
// The analyzer has already checked that subquery contained only one output column,
// and added all the grouping expressions to the aggregate.
val aggregates = agg.expressions.flatMap(_.collect {
case a: AggregateExpression => a
})
if (aggregates.isEmpty) {
failAnalysis("The output of a correlated scalar subquery must be aggregated")
// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
val correlatedCols = AttributeSet(subqueryColumns)
val invalidCols = groupByCols -- correlatedCols
// GROUP BY columns must be a subset of columns in the predicates
if (invalidCols.nonEmpty) {
failAnalysis(
"A GROUP BY clause in a scalar correlated subquery " +
"cannot contain non-correlated columns: " +
invalidCols.mkString(","))
}
}

// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
val correlatedCols = AttributeSet(subqueryColumns)
val invalidCols = groupByCols -- correlatedCols
// GROUP BY columns must be a subset of columns in the predicates
if (invalidCols.nonEmpty) {
failAnalysis(
"A GROUP BY clause in a scalar correlated subquery " +
"cannot contain non-correlated columns: " +
invalidCols.mkString(","))
}
}
// Skip subquery aliases added by the Analyzer and the SQLBuilder.
// For projects, do the necessary mapping and skip to its child.
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQuery(s.child)
case p: Project =>
// SPARK-18814: Map any aliases to their AttributeReference children
// for the checking in the Aggregate operators below this Project.
subqueryColumns = subqueryColumns.map {
xs => p.projectList.collectFirst {
case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
child
}.getOrElse(xs)
}

// Skip subquery aliases added by the Analyzer and the SQLBuilder.
// For projects, do the necessary mapping and skip to its child.
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQuery(s.child)
case p: Project =>
// SPARK-18814: Map any aliases to their AttributeReference children
// for the checking in the Aggregate operators below this Project.
subqueryColumns = subqueryColumns.map {
xs => p.projectList.collectFirst {
case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
child
}.getOrElse(xs)
}
cleanQuery(p.child)
case child => child
}

cleanQuery(p.child)
case child => child
cleanQuery(query) match {
case a: Aggregate => checkAggregate(a)
case Filter(_, a: Aggregate) => checkAggregate(a)
case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
}
}
checkAnalysis(query)
s

cleanQuery(query) match {
case a: Aggregate => checkAggregate(a)
case Filter(_, a: Aggregate) => checkAggregate(a)
case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
}
case s: SubqueryExpression =>
checkAnalysis(s.plan)
s
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
-- The test file contains negative test cases
-- of invalid queries where error messages are expected.

create temporary view t1 as select * from values
(1, 2, 3)
as t1(t1a, t1b, t1c);

create temporary view t2 as select * from values
(1, 0, 1)
as t2(t2a, t2b, t2c);

create temporary view t3 as select * from values
(3, 1, 2)
as t3(t3a, t3b, t3c);

-- TC 01.01
-- The column t2b in the SELECT of the subquery is invalid
-- because it is neither an aggregate function nor a GROUP BY column.
select t1a, t2b
from t1, t2
where t1b = t2c
and t2b = (select max(avg)
from (select t2b, avg(t2b) avg
from t2
where t2a = t1.t1b
)
)
;

-- TC 01.02
-- Invalid because the column t2b, referenced in the innermost HAVING clause,
-- is not part of the output of the aggregate over t2 (only min(t2a) and t2c are).
select *
from t1
where t1a in (select min(t2a)
from t2
group by t2c
having t2c in (select max(t3c)
from t3
group by t3b
having t3b > t2b ))
;

Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 5


-- !query 0
create temporary view t1 as select * from values
(1, 2, 3)
as t1(t1a, t1b, t1c)
-- !query 0 schema
struct<>
-- !query 0 output



-- !query 1
create temporary view t2 as select * from values
(1, 0, 1)
as t2(t2a, t2b, t2c)
-- !query 1 schema
struct<>
-- !query 1 output



-- !query 2
create temporary view t3 as select * from values
(3, 1, 2)
as t3(t3a, t3b, t3c)
-- !query 2 schema
struct<>
-- !query 2 output



-- !query 3
select t1a, t2b
from t1, t2
where t1b = t2c
and t2b = (select max(avg)
from (select t2b, avg(t2b) avg
from t2
where t2a = t1.t1b
)
)
-- !query 3 schema
struct<>
-- !query 3 output
org.apache.spark.sql.AnalysisException
expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;


-- !query 4
select *
from t1
where t1a in (select min(t2a)
from t2
group by t2c
having t2c in (select max(t3c)
from t3
group by t3b
having t3b > t2b ))
-- !query 4 schema
struct<>
-- !query 4 output
org.apache.spark.sql.AnalysisException
resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)];
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
} catch {
case a: AnalysisException if a.plan.nonEmpty =>
// Do not output the logical plan tree which contains expression IDs.
(StructType(Seq.empty), Seq(a.getClass.getName, a.getSimpleMessage))
// Also implement a crude way of masking expression IDs in the error message
// by replacing each "#<digits>" suffix with the generic pattern "#x".
(StructType(Seq.empty),
Seq(a.getClass.getName, a.getSimpleMessage.replaceAll("#\\d+", "#x")))
case NonFatal(e) =>
// If there is an exception, put the exception class followed by the message.
(StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage))
Expand Down

0 comments on commit f1ddca5

Please sign in to comment.