Skip to content

Commit

Permalink
opt: fix some aggregate scoping issues
Browse files Browse the repository at this point in the history
Prior to this commit, some aggregate functions were either incorrectly
rejected or incorrectly accepted when they were scoped at a higher
level than their position in the query. For example, aggregate functions
are not normally allowed in WHERE, but if the aggregate is actually scoped
at a higher level, then the aggregate should be allowed. Prior to this
commit, these aggregate functions were rejected and caused an error.

This commit fixes the issue by validating the context of the aggregate's
scope rather than the aggregate's position in the query. In order to
avoid adding another field to the scope struct, this commit re-uses
the existing `context` field which was previously only used for error
messages. To make comparisons more efficient, the field is now an enum
rather than a string.

Fixes cockroachdb#44724
Fixes cockroachdb#45838
Fixes cockroachdb#30652

Release justification: This bug fix is a low risk, high benefit change
to existing functionality, since it fixes internal errors and increases
compatibility with Postgres.

Release note (bug fix): Fixed an internal error that could occur when
an aggregate inside the right-hand side of a LATERAL join was scoped at
the level of the left-hand side.

Release note (bug fix): Fixed an error that incorrectly occurred when
an aggregate was used inside the WHERE or ON clause of a subquery but
was scoped at an outer level of the query.
  • Loading branch information
rytaft committed Mar 25, 2020
1 parent 930c0fa commit b6320d4
Show file tree
Hide file tree
Showing 16 changed files with 279 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pkg/sql/logictest/testdata/logic_test/join
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ NULL NULL
query error generator functions are not allowed in ON
SELECT * FROM foo JOIN bar ON generate_series(0, 1) < 2

query error aggregate functions are not allowed in ON
query error aggregate functions are not allowed in JOIN conditions
SELECT * FROM foo JOIN bar ON max(foo.c) < 2

# Regression test for #44029 (outer join on two single-row clauses, with two
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/optbuilder/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ func (b *Builder) buildAlterTableSplit(split *tree.Split, inScope *scope) (outSc
// Build the expiration scalar.
var expiration opt.ScalarExpr
if split.ExpireExpr != nil {
emptyScope.context = "ALTER TABLE SPLIT AT"
emptyScope.context = exprKindAlterTableSplitAt
// We need to save and restore the previous value of the field in
// semaCtx in case we are recursively called within a subquery
// context.
defer b.semaCtx.Properties.Restore(b.semaCtx.Properties)
b.semaCtx.Properties.Require(emptyScope.context, tree.RejectSpecial)
b.semaCtx.Properties.Require(emptyScope.context.String(), tree.RejectSpecial)

texpr := emptyScope.resolveType(split.ExpireExpr, types.String)
expiration = b.buildScalar(texpr, emptyScope, nil /* outScope */, nil /* outCol */, nil /* colRefs */)
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/optbuilder/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ func (b *Builder) analyzeDistinctOnArgs(
// semaCtx in case we are recursively called within a subquery
// context.
defer b.semaCtx.Properties.Restore(b.semaCtx.Properties)
b.semaCtx.Properties.Require("DISTINCT ON", tree.RejectGenerators)
inScope.context = "DISTINCT ON"
b.semaCtx.Properties.Require(exprKindDistinctOn.String(), tree.RejectGenerators)
inScope.context = exprKindDistinctOn

for i := range distinctOn {
b.analyzeExtraArgument(distinctOn[i], inScope, projectionsScope, distinctOnScope)
Expand Down
10 changes: 6 additions & 4 deletions pkg/sql/opt/optbuilder/groupby.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,10 @@ func (b *Builder) analyzeHaving(having *tree.Where, fromScope *scope) tree.Typed
// We need to save and restore the previous value of the field in semaCtx
// in case we are recursively called within a subquery context.
defer b.semaCtx.Properties.Restore(b.semaCtx.Properties)
b.semaCtx.Properties.Require("HAVING", tree.RejectWindowApplications|tree.RejectGenerators)
fromScope.context = "HAVING"
b.semaCtx.Properties.Require(
exprKindHaving.String(), tree.RejectWindowApplications|tree.RejectGenerators,
)
fromScope.context = exprKindHaving
return fromScope.resolveAndRequireType(having.Expr, types.Bool)
}

Expand Down Expand Up @@ -583,8 +585,8 @@ func (b *Builder) buildGrouping(
defer b.semaCtx.Properties.Restore(b.semaCtx.Properties)

// Make sure the GROUP BY columns have no special functions.
b.semaCtx.Properties.Require("GROUP BY", tree.RejectSpecial)
fromScope.context = "GROUP BY"
b.semaCtx.Properties.Require(exprKindGroupBy.String(), tree.RejectSpecial)
fromScope.context = exprKindGroupBy

// Resolve types, expand stars, and flatten tuples.
exprs := b.expandStarAndResolveType(groupBy, fromScope)
Expand Down
6 changes: 4 additions & 2 deletions pkg/sql/opt/optbuilder/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ func (b *Builder) buildJoin(join *tree.JoinTableExpr, inScope *scope) (outScope
var filters memo.FiltersExpr
if on, ok := cond.(*tree.OnJoinCond); ok {
// Do not allow special functions in the ON clause.
b.semaCtx.Properties.Require("ON", tree.RejectSpecial)
outScope.context = "ON"
b.semaCtx.Properties.Require(
exprKindOn.String(), tree.RejectGenerators|tree.RejectWindowApplications,
)
outScope.context = exprKindOn
filter := b.buildScalar(
outScope.resolveAndRequireType(on.Expr, types.Bool), outScope, nil, nil, nil,
)
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/optbuilder/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ func (b *Builder) buildLimit(limit *tree.Limit, parentScope, inScope *scope) {
if limit.Offset != nil {
input := inScope.expr.(memo.RelExpr)
offset := b.resolveAndBuildScalar(
limit.Offset, types.Int, "OFFSET", tree.RejectSpecial, parentScope,
limit.Offset, types.Int, exprKindOffset, tree.RejectSpecial, parentScope,
)
inScope.expr = b.factory.ConstructOffset(input, offset, inScope.makeOrderingChoice())
}
if limit.Count != nil {
input := inScope.expr.(memo.RelExpr)
limit := b.resolveAndBuildScalar(
limit.Count, types.Int, "LIMIT", tree.RejectSpecial, parentScope,
limit.Count, types.Int, exprKindLimit, tree.RejectSpecial, parentScope,
)
inScope.expr = b.factory.ConstructLimit(input, limit, inScope.makeOrderingChoice())
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/sql/opt/optbuilder/orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func (b *Builder) analyzeOrderBy(
// semaCtx in case we are recursively called within a subquery
// context.
defer b.semaCtx.Properties.Restore(b.semaCtx.Properties)
b.semaCtx.Properties.Require("ORDER BY", tree.RejectGenerators)
inScope.context = "ORDER BY"
b.semaCtx.Properties.Require(exprKindOrderBy.String(), tree.RejectGenerators)
inScope.context = exprKindOrderBy

for i := range orderBy {
b.analyzeOrderByArg(orderBy[i], inScope, projectionsScope, orderByScope)
Expand Down Expand Up @@ -231,12 +231,12 @@ func (b *Builder) analyzeExtraArgument(
// e.g. SELECT a, b FROM t ORDER by a+b

// First, deal with projection aliases.
idx := colIdxByProjectionAlias(expr, inScope.context, projectionsScope)
idx := colIdxByProjectionAlias(expr, inScope.context.String(), projectionsScope)

// If the expression does not refer to an alias, deal with
// column ordinals.
if idx == -1 {
idx = colIndex(len(projectionsScope.cols), expr, inScope.context)
idx = colIndex(len(projectionsScope.cols), expr, inScope.context.String())
}

var exprs tree.TypedExprs
Expand Down
8 changes: 4 additions & 4 deletions pkg/sql/opt/optbuilder/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func (b *Builder) analyzeProjectionList(
defer b.semaCtx.Properties.Restore(b.semaCtx.Properties)
defer func(replaceSRFs bool) { inScope.replaceSRFs = replaceSRFs }(inScope.replaceSRFs)

b.semaCtx.Properties.Require("SELECT", tree.RejectNestedGenerators)
inScope.context = "SELECT"
b.semaCtx.Properties.Require(exprKindSelect.String(), tree.RejectNestedGenerators)
inScope.context = exprKindSelect
inScope.replaceSRFs = true

b.analyzeSelectList(selects, desiredTypes, inScope, outScope)
Expand All @@ -89,8 +89,8 @@ func (b *Builder) analyzeReturningList(
defer b.semaCtx.Properties.Restore(b.semaCtx.Properties)

// Ensure there are no special functions in the RETURNING clause.
b.semaCtx.Properties.Require("RETURNING", tree.RejectSpecial)
inScope.context = "RETURNING"
b.semaCtx.Properties.Require(exprKindReturning.String(), tree.RejectSpecial)
inScope.context = exprKindReturning

b.analyzeSelectList(tree.SelectExprs(returning), desiredTypes, inScope, outScope)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ func (b *Builder) checkSubqueryOuterCols(
aggCols := inScope.groupby.aggregateResultCols()
for i := range aggCols {
if subqueryOuterCols.Contains(aggCols[i].id) {
panic(tree.NewInvalidFunctionUsageError(tree.AggregateClass, inScope.context))
panic(tree.NewInvalidFunctionUsageError(tree.AggregateClass, inScope.context.String()))
}
}
}
Expand Down
85 changes: 79 additions & 6 deletions pkg/sql/opt/optbuilder/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ type scope struct {
ctes map[string]*cteSource

// context is the current context in the SQL query (e.g., "SELECT" or
// "HAVING"). It is used for error messages.
context string
// "HAVING"). It is used for error messages and to identify scoping errors
// (e.g., aggregates are not allowed in the FROM clause of their own query
// level).
context exprKind
}

// cteSource represents a CTE in the given query.
Expand All @@ -112,6 +114,57 @@ type cteSource struct {
id opt.WithID
}

// exprKind is used to represent the kind of the current expression in the
// SQL query.
type exprKind int8

const (
exprKindNone exprKind = iota
exprKindAlterTableSplitAt
exprKindDistinctOn
exprKindFrom
exprKindGroupBy
exprKindHaving
exprKindLateralJoin
exprKindLimit
exprKindOffset
exprKindOn
exprKindOrderBy
exprKindReturning
exprKindSelect
exprKindValues
exprKindWhere
exprKindWindowFrameStart
exprKindWindowFrameEnd
)

var exprKindName = [...]string{
exprKindNone: "",
exprKindAlterTableSplitAt: "ALTER TABLE SPLIT AT",
exprKindDistinctOn: "DISTINCT ON",
exprKindFrom: "FROM",
exprKindGroupBy: "GROUP BY",
exprKindHaving: "HAVING",
exprKindLateralJoin: "LATERAL JOIN",
exprKindLimit: "LIMIT",
exprKindOffset: "OFFSET",
exprKindOn: "ON",
exprKindOrderBy: "ORDER BY",
exprKindReturning: "RETURNING",
exprKindSelect: "SELECT",
exprKindValues: "VALUES",
exprKindWhere: "WHERE",
exprKindWindowFrameStart: "WINDOW FRAME START",
exprKindWindowFrameEnd: "WINDOW FRAME END",
}

func (k exprKind) String() string {
if k < 0 || k > exprKind(len(exprKindName)-1) {
return fmt.Sprintf("exprKind(%d)", k)
}
return exprKindName[k]
}

// initGrouping initializes the groupby information for this scope.
func (s *scope) initGrouping() {
if s.groupby != nil {
Expand Down Expand Up @@ -329,7 +382,7 @@ func (s *scope) resolveType(expr tree.Expr, desired *types.T) tree.TypedExpr {
// desired type.
func (s *scope) resolveAndRequireType(expr tree.Expr, desired *types.T) tree.TypedExpr {
expr = s.walkExprTree(expr)
texpr, err := tree.TypeCheckAndRequire(expr, s.builder.semaCtx, desired, s.context)
texpr, err := tree.TypeCheckAndRequire(expr, s.builder.semaCtx, desired, s.context.String())
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -510,6 +563,7 @@ func (s *scope) endAggFunc(cols opt.ColSet) (g *groupby) {

for curr := s; curr != nil; curr = curr.parent {
if cols.Len() == 0 || cols.Intersects(curr.colSet()) {
curr.verifyAggregateContext()
if curr.groupby == nil {
curr.initGrouping()
}
Expand All @@ -520,6 +574,25 @@ func (s *scope) endAggFunc(cols opt.ColSet) (g *groupby) {
panic(errors.AssertionFailedf("aggregate function is not allowed in this context"))
}

// verifyAggregateContext checks that the current scope is allowed to contain
// aggregate functions.
func (s *scope) verifyAggregateContext() {
switch s.context {
case exprKindLateralJoin:
panic(pgerror.Newf(pgcode.Grouping,
"aggregate functions are not allowed in FROM clause of their own query level",
))

case exprKindOn:
panic(pgerror.Newf(pgcode.Grouping,
"aggregate functions are not allowed in JOIN conditions",
))

case exprKindWhere:
panic(tree.NewInvalidFunctionUsageError(tree.AggregateClass, s.context.String()))
}
}

// scope implements the tree.Visitor interface so that it can walk through
// a tree.Expr tree, perform name resolution, and replace unresolved column
// names with a scopeColumn. The info stored in scopeColumn is necessary for
Expand Down Expand Up @@ -858,7 +931,7 @@ func (s *scope) replaceSRF(f *tree.FuncExpr, def *tree.FunctionDefinition) *srf
// context.
defer s.builder.semaCtx.Properties.Restore(s.builder.semaCtx.Properties)

s.builder.semaCtx.Properties.Require(s.context,
s.builder.semaCtx.Properties.Require(s.context.String(),
tree.RejectAggregates|tree.RejectWindowApplications|tree.RejectNestedGenerators)

expr := f.Walk(s)
Expand Down Expand Up @@ -1104,13 +1177,13 @@ func analyzeWindowFrame(s *scope, windowDef *tree.WindowDef) error {
}
if startBound != nil && startBound.OffsetExpr != nil {
oldContext := s.context
s.context = "WINDOW FRAME START"
s.context = exprKindWindowFrameStart
startBound.OffsetExpr = s.resolveAndRequireType(startBound.OffsetExpr, requiredType)
s.context = oldContext
}
if endBound != nil && endBound.OffsetExpr != nil {
oldContext := s.context
s.context = "WINDOW FRAME END"
s.context = exprKindWindowFrameEnd
endBound.OffsetExpr = s.resolveAndRequireType(endBound.OffsetExpr, requiredType)
s.context = oldContext
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/sql/opt/optbuilder/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,13 @@ func (b *Builder) buildWhere(where *tree.Where, inScope *scope) {
return
}

filter := b.resolveAndBuildScalar(where.Expr, types.Bool, "WHERE", tree.RejectSpecial, inScope)
filter := b.resolveAndBuildScalar(
where.Expr,
types.Bool,
exprKindWhere,
tree.RejectGenerators|tree.RejectWindowApplications,
inScope,
)

// Wrap the filter in a FiltersOp.
inScope.expr = b.factory.ConstructSelect(
Expand Down Expand Up @@ -1017,6 +1023,7 @@ func (b *Builder) buildFromWithLateral(tables tree.TableExprs, inScope *scope) (
// have been built already.
if b.exprIsLateral(tables[i]) {
scope = outScope
scope.context = exprKindLateralJoin
}
tableScope := b.buildDataSource(tables[i], nil /* indexFlags */, scope)

Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/optbuilder/srfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ func (b *Builder) buildZip(exprs tree.Exprs, inScope *scope) (outScope *scope) {
// semaCtx in case we are recursively called within a subquery
// context.
defer b.semaCtx.Properties.Restore(b.semaCtx.Properties)
b.semaCtx.Properties.Require("FROM",
b.semaCtx.Properties.Require(exprKindFrom.String(),
tree.RejectAggregates|tree.RejectWindowApplications|tree.RejectNestedGenerators)
inScope.context = "FROM"
inScope.context = exprKindFrom

// Build each of the provided expressions.
zip := make(memo.ZipExpr, len(exprs))
Expand Down
Loading

0 comments on commit b6320d4

Please sign in to comment.