Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gen4: Distinct #8479

Merged
merged 14 commits into from
Jul 19, 2021
Merged
1 change: 1 addition & 0 deletions go/mysql/sql_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ var stateToMysqlCode = map[vterrors.State]struct {
vterrors.WrongNumberOfColumnsInSelect: {num: ERWrongNumberOfColumnsInSelect, state: SSWrongNumberOfColumns},
vterrors.WrongTypeForVar: {num: ERWrongTypeForVar, state: SSClientError},
vterrors.WrongValueForVar: {num: ERWrongValueForVar, state: SSClientError},
vterrors.WrongFieldWithGroup: {num: ERWrongFieldWithGroup, state: SSClientError},
vterrors.ServerNotAvailable: {num: ERServerIsntAvailable, state: SSNetError},
vterrors.CantDoThisInTransaction: {num: ERCantDoThisDuringAnTransaction, state: SSCantDoThisDuringAnTransaction},
vterrors.RequiresPrimaryKey: {num: ERRequiresPrimaryKey, state: SSClientError},
Expand Down
19 changes: 12 additions & 7 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1371,17 +1371,22 @@ func ToString(exprs []TableExpr) string {
func ContainsAggregation(e Expr) bool {
hasAggregates := false
_ = Walk(func(node SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *FuncExpr:
if node.IsAggregate() {
hasAggregates = true
return false, nil
}
case *GroupConcatExpr:
if IsAggregation(node) {
hasAggregates = true
return false, nil
}
return true, nil
}, e)
return hasAggregates
}

// IsAggregation reports whether node is an aggregation expression.
// A *FuncExpr counts only when its IsAggregate method says so; a
// *GroupConcatExpr always counts; every other node does not.
func IsAggregation(node SQLNode) bool {
	if fn, isFunc := node.(*FuncExpr); isFunc {
		return fn.IsAggregate()
	}
	_, isGroupConcat := node.(*GroupConcatExpr)
	return isGroupConcat
}
1 change: 1 addition & 0 deletions go/vt/vterrors/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
NonUniqTable
NonUpdateableTable
SyntaxError
WrongFieldWithGroup
WrongGroupField
WrongTypeForVar
WrongValueForVar
Expand Down
111 changes: 87 additions & 24 deletions go/vt/vtgate/planbuilder/abstract/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type (
// If you change the contents here, please update the toString() method
SelectExprs []SelectExpr
HasAggr bool
Distinct bool
GroupByExprs []GroupBy
OrderExprs []OrderBy
CanPushDownSorting bool
Expand All @@ -58,39 +59,28 @@ type (

// CreateQPFromSelect creates the QueryProjection for the input *sqlparser.Select
func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) {
qp := &QueryProjection{}
qp := &QueryProjection{
Distinct: sel.Distinct,
}

hasNonAggr := false
for _, selExp := range sel.SelectExprs {
exp, ok := selExp.(*sqlparser.AliasedExpr)
if !ok {
return nil, semantics.Gen4NotSupportedF("%T in select list", selExp)
}
fExpr, ok := exp.Expr.(*sqlparser.FuncExpr)
if ok && fExpr.IsAggregate() {
if len(fExpr.Exprs) != 1 {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.SyntaxError, "aggregate functions take a single argument '%s'", sqlparser.String(fExpr))
}
if fExpr.Distinct {
return nil, semantics.Gen4NotSupportedF("distinct aggregation")
}
qp.HasAggr = true
qp.SelectExprs = append(qp.SelectExprs, SelectExpr{
Col: exp,
Aggr: true,
})
continue
}

if err := checkForInvalidAggregations(exp); err != nil {
return nil, err
}
col := SelectExpr{
Col: exp,
}
if sqlparser.ContainsAggregation(exp.Expr) {
return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression")
col.Aggr = true
qp.HasAggr = true
}
hasNonAggr = true
qp.SelectExprs = append(qp.SelectExprs, SelectExpr{Col: exp})
}

if hasNonAggr && qp.HasAggr && sel.GroupBy == nil {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.MixOfGroupFuncAndFields, "Mixing of aggregation and non-aggregation columns is not allowed if there is no GROUP BY clause")
qp.SelectExprs = append(qp.SelectExprs, col)
}

for _, group := range sel.GroupBy {
Expand All @@ -104,6 +94,18 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) {
qp.GroupByExprs = append(qp.GroupByExprs, GroupBy{Inner: expr, WeightStrExpr: weightStrExpr})
}

if qp.HasAggr {
expr := qp.getNonAggrExprNotMatchingGroupByExprs()
// if we have aggregation functions, non aggregating columns and GROUP BY,
// the non-aggregating expressions must all be listed in the GROUP BY list
if expr != nil {
if len(qp.GroupByExprs) > 0 {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongFieldWithGroup, "Expression of SELECT list is not in GROUP BY clause and contains nonaggregated column '%s' which is not functionally dependent on columns in GROUP BY clause; this is incompatible with sql_mode=only_full_group_by", sqlparser.String(expr))
}
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.MixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by", sqlparser.String(expr))
}
}

canPushDownSorting := true
for _, order := range sel.OrderBy {
expr, weightStrExpr, err := qp.getSimplifiedExpr(order.Expr, "order clause")
Expand All @@ -119,11 +121,47 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) {
})
canPushDownSorting = canPushDownSorting && !sqlparser.ContainsAggregation(weightStrExpr)
}
if qp.Distinct && !qp.HasAggr {
qp.GroupByExprs = nil
}
qp.CanPushDownSorting = canPushDownSorting

return qp, nil
}

// checkForInvalidAggregations visits exp.Expr with sqlparser.Walk and rejects
// aggregate function calls that are not accepted here:
//   - an aggregate whose argument list does not have exactly one entry
//   - an aggregate marked Distinct
//
// The visitor stops at the first offending node and hands its error to Walk;
// the value returned is whatever Walk returns.
func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error {
	validate := func(node sqlparser.SQLNode) (bool, error) {
		aggr, isFunc := node.(*sqlparser.FuncExpr)
		if !isFunc || !aggr.IsAggregate() {
			// Not an aggregate function call: nothing to check, keep walking.
			return true, nil
		}
		switch {
		case len(aggr.Exprs) != 1:
			return false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.SyntaxError, "aggregate functions take a single argument '%s'", sqlparser.String(aggr))
		case aggr.Distinct:
			return false, semantics.Gen4NotSupportedF("distinct aggregation")
		}
		return true, nil
	}
	return sqlparser.Walk(validate, exp.Expr)
}

// getNonAggrExprNotMatchingGroupByExprs returns the first select expression
// that is not flagged as an aggregate and that is not equal (per
// sqlparser.EqualsExpr) to the WeightStrExpr of any GROUP BY entry.
// It returns nil when every non-aggregate select expression has such a match.
func (qp *QueryProjection) getNonAggrExprNotMatchingGroupByExprs() sqlparser.Expr {
	// coveredByGroupBy reports whether some grouping expression equals candidate.
	coveredByGroupBy := func(candidate sqlparser.Expr) bool {
		for i := range qp.GroupByExprs {
			if sqlparser.EqualsExpr(qp.GroupByExprs[i].WeightStrExpr, candidate) {
				return true
			}
		}
		return false
	}

	for _, sel := range qp.SelectExprs {
		if !sel.Aggr && !coveredByGroupBy(sel.Col.Expr) {
			return sel.Col.Expr
		}
	}
	return nil
}

// getSimplifiedExpr takes an expression used in ORDER BY or GROUP BY, which can reference both aliased columns and
// column offsets, and returns an expression that is simpler to evaluate
func (qp *QueryProjection) getSimplifiedExpr(e sqlparser.Expr, caller string) (expr sqlparser.Expr, weightStrExpr sqlparser.Expr, err error) {
Expand Down Expand Up @@ -169,11 +207,13 @@ func (qp *QueryProjection) toString() string {
Select []string
Grouping []string
OrderBy []string
Distinct bool
}
out := output{
Select: []string{},
Grouping: []string{},
OrderBy: []string{},
Distinct: qp.NeedsDistinct(),
}

for _, expr := range qp.SelectExprs {
Expand Down Expand Up @@ -204,3 +244,26 @@ func (qp *QueryProjection) toString() string {
func (qp *QueryProjection) NeedsAggregation() bool {
	// An aggregate in the select list is enough on its own.
	if qp.HasAggr {
		return true
	}
	// Otherwise the answer depends on whether any GROUP BY entries exist.
	return len(qp.GroupByExprs) != 0
}

// onlyAggr reports whether the projection has aggregation and every select
// expression is flagged as an aggregate.
// It returns false when HasAggr is unset, and true when HasAggr is set and no
// entry of SelectExprs has Aggr == false (an empty SelectExprs included).
//
// A pointer receiver is used so the method matches the other QueryProjection
// methods and does not copy the struct on every call.
func (qp *QueryProjection) onlyAggr() bool {
	if !qp.HasAggr {
		return false
	}
	for i := range qp.SelectExprs {
		if !qp.SelectExprs[i].Aggr {
			return false
		}
	}
	return true
}

// NeedsDistinct reports whether the query needs an explicit distinct step.
// It is false when Distinct is unset, and also false when the select list
// consists solely of aggregates and there are no GROUP BY expressions;
// in every other case it is true.
func (qp *QueryProjection) NeedsDistinct() bool {
	return qp.Distinct && (len(qp.GroupByExprs) > 0 || !qp.onlyAggr())
}
42 changes: 33 additions & 9 deletions go/vt/vtgate/planbuilder/abstract/queryprojection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestQP(t *testing.T) {
},
{
sql: "select 1, count(1) from user",
expErr: "Mixing of aggregation and non-aggregation columns is not allowed if there is no GROUP BY clause",
expErr: "In aggregated query without GROUP BY, expression of SELECT list contains nonaggregated column '1'; this is incompatible with sql_mode=only_full_group_by",
},
{
sql: "select max(id) from user",
Expand All @@ -55,13 +55,9 @@ func TestQP(t *testing.T) {
sql: "select max(a, b) from user",
expErr: "aggregate functions take a single argument 'max(a, b)'",
},
{
sql: "select func(max(id)) from user",
expErr: "unsupported: in scatter query: complex aggregate expression",
},
{
sql: "select 1, count(1) from user order by 1",
expErr: "Mixing of aggregation and non-aggregation columns is not allowed if there is no GROUP BY clause",
expErr: "In aggregated query without GROUP BY, expression of SELECT list contains nonaggregated column '1'; this is incompatible with sql_mode=only_full_group_by",
},
{
sql: "select id from user order by col, id, 1",
Expand Down Expand Up @@ -132,7 +128,8 @@ func TestQPSimplifiedExpr(t *testing.T) {
"Grouping": [
"intcol"
],
"OrderBy": []
"OrderBy": [],
"Distinct": false
}`,
},
{
Expand All @@ -147,7 +144,8 @@ func TestQPSimplifiedExpr(t *testing.T) {
"OrderBy": [
"intcol asc",
"textcol asc"
]
],
"Distinct": false
}`,
},
{
Expand All @@ -166,7 +164,33 @@ func TestQPSimplifiedExpr(t *testing.T) {
],
"OrderBy": [
"textcol desc"
]
],
"Distinct": false
}`,
},
{
query: "select distinct col1, col2 from user group by col1, col2",
expected: `
{
"Select": [
"col1",
"col2"
],
"Grouping": [],
"OrderBy": [],
"Distinct": true
}`,
},
{
query: "select distinct count(*) from user",
expected: `
{
"Select": [
"aggr: count(*)"
],
"Grouping": [],
"OrderBy": [],
"Distinct": false
}`,
},
}
Expand Down
6 changes: 1 addition & 5 deletions go/vt/vtgate/planbuilder/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ import (

var _ logicalPlan = (*distinct)(nil)

// limit is the logicalPlan for engine.Limit.
// This gets built if a limit needs to be applied
// after rows are returned from an underlying
// operation. Since a limit is the final operation
// of a SELECT, most pushes are not applicable.
// distinct is the logicalPlan for engine.Distinct.
// It declares no fields of its own; its only member is the embedded
// logicalPlanCommon.
type distinct struct {
logicalPlanCommon
}
Expand Down
Loading