Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner/spansql): add support for aggregate function calls #8498

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions spanner/spansql/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,16 @@ var keywords = map[string]bool{
// https://cloud.google.com/spanner/docs/functions-and-operators
var funcs = make(map[string]bool)
var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError))
var aggregateFuncs = make(map[string]bool)

func init() {
for _, f := range allFuncs {
for _, f := range funcNames {
funcs[f] = true
}
for _, f := range aggregateFuncNames {
funcs[f] = true
aggregateFuncs[f] = true
}
// Special case for CAST, SAFE_CAST and EXTRACT
funcArgParsers["CAST"] = typedArgParser
funcArgParsers["SAFE_CAST"] = typedArgParser
Expand All @@ -150,19 +155,9 @@ func init() {
funcArgParsers["GET_INTERNAL_SEQUENCE_STATE"] = sequenceArgParser
}

var allFuncs = []string{
var funcNames = []string{
// TODO: many more

// Aggregate functions.
"ANY_VALUE",
"ARRAY_AGG",
"AVG",
"BIT_XOR",
"COUNT",
"MAX",
"MIN",
"SUM",

// Cast functions.
"CAST",
"SAFE_CAST",
Expand Down Expand Up @@ -295,3 +290,28 @@ var allFuncs = []string{
// Utility functions.
"GENERATE_UUID",
}

var aggregateFuncNames = []string{
// Aggregate functions.
"ANY_VALUE",
"ARRAY_AGG",
"ARRAY_CONCAT_AGG",
"AVG",
"BIT_AND",
"BIT_OR",
"BIT_XOR",
"COUNT",
"COUNTIF",
"LOGICAL_AND",
"LOGICAL_OR",
"MAX",
"MIN",
"STRING_AGG",
"SUM",

// Statistical aggregate functions.
"STDDEV",
"STDDEV_SAMP",
"VAR_SAMP",
"VARIANCE",
}
63 changes: 63 additions & 0 deletions spanner/spansql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -3566,6 +3566,65 @@ var sequenceArgParser = func(p *parser) (Expr, *parseError) {
return p.parseExpr()
}

func (p *parser) parseAggregateFunc() (Func, *parseError) {
tok := p.next()
if tok.err != nil {
return Func{}, tok.err
}
name := strings.ToUpper(tok.value)
if err := p.expect("("); err != nil {
return Func{}, err
}
var distinct bool
if p.eat("DISTINCT") {
distinct = true
}
args, err := p.parseExprList()
if err != nil {
return Func{}, err
}
var nullsHandling NullsHandling
if p.eat("IGNORE", "NULLS") {
nullsHandling = IgnoreNulls
} else if p.eat("RESPECT", "NULLS") {
nullsHandling = RespectNulls
}
var having *AggregateHaving
if p.eat("HAVING") {
tok := p.next()
if tok.err != nil {
return Func{}, tok.err
}
var cond AggregateHavingCondition
switch tok.value {
case "MAX":
cond = HavingMax
case "MIN":
cond = HavingMin
default:
return Func{}, p.errorf("got %q, want MAX or MIN", tok.value)
}
expr, err := p.parseExpr()
if err != nil {
return Func{}, err
}
having = &AggregateHaving{
Condition: cond,
Expr: expr,
}
}
if err := p.expect(")"); err != nil {
return Func{}, err
}
return Func{
Name: name,
Args: args,
Distinct: distinct,
NullsHandling: nullsHandling,
Having: having,
}, nil
}

/*
Expressions

Expand Down Expand Up @@ -3918,6 +3977,10 @@ func (p *parser) parseLit() (Expr, *parseError) {
// this is a function invocation.
// The `funcs` map is keyed by upper case strings.
if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") {
if aggregateFuncs[name] {
p.back()
return p.parseAggregateFunc()
}
var list []Expr
var err *parseError
if f, ok := funcArgParsers[name]; ok {
Expand Down
7 changes: 7 additions & 0 deletions spanner/spansql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,13 @@ func TestParseExpr(t *testing.T) {
{`GET_NEXT_SEQUENCE_VALUE(SEQUENCE MySequence)`, Func{Name: "GET_NEXT_SEQUENCE_VALUE", Args: []Expr{SequenceExpr{Name: ID("MySequence")}}}},
{`GET_INTERNAL_SEQUENCE_STATE(SEQUENCE MySequence)`, Func{Name: "GET_INTERNAL_SEQUENCE_STATE", Args: []Expr{SequenceExpr{Name: ID("MySequence")}}}},

// Aggregate Functions
{`COUNT(*)`, Func{Name: "COUNT", Args: []Expr{Star}}},
{`COUNTIF(DISTINCT cname)`, Func{Name: "COUNTIF", Args: []Expr{ID("cname")}, Distinct: true}},
{`ARRAY_AGG(Foo IGNORE NULLS)`, Func{Name: "ARRAY_AGG", Args: []Expr{ID("Foo")}, NullsHandling: IgnoreNulls}},
{`ANY_VALUE(Foo HAVING MAX Bar)`, Func{Name: "ANY_VALUE", Args: []Expr{ID("Foo")}, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}},
{`STRING_AGG(DISTINCT Foo, "," IGNORE NULLS HAVING MAX Bar)`, Func{Name: "STRING_AGG", Args: []Expr{ID("Foo"), StringLiteral(",")}, Distinct: true, NullsHandling: IgnoreNulls, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}},

// Conditional expressions
{
`CASE X WHEN 1 THEN "X" WHEN 2 THEN "Y" ELSE NULL END`,
Expand Down
20 changes: 20 additions & 0 deletions spanner/spansql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,27 @@ func (f Func) SQL() string { return buildSQL(f) }
func (f Func) addSQL(sb *strings.Builder) {
sb.WriteString(f.Name)
sb.WriteString("(")
if f.Distinct {
sb.WriteString("DISTINCT ")
}
addExprList(sb, f.Args, ", ")
switch f.NullsHandling {
case RespectNulls:
sb.WriteString(" RESPECT NULLS")
case IgnoreNulls:
sb.WriteString(" IGNORE NULLS")
}
if ah := f.Having; ah != nil {
sb.WriteString(" HAVING")
switch ah.Condition {
case HavingMax:
sb.WriteString(" MAX")
case HavingMin:
sb.WriteString(" MIN")
}
sb.WriteString(" ")
sb.WriteString(ah.Expr.SQL())
}
sb.WriteString(")")
}

Expand Down
25 changes: 25 additions & 0 deletions spanner/spansql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,31 @@ func TestSQL(t *testing.T) {
`SELECT SAFE_CAST(7 AS DATE)`,
reparseQuery,
},
{
Func{Name: "COUNT", Args: []Expr{Star}},
`COUNT(*)`,
reparseExpr,
},
{
Func{Name: "COUNTIF", Args: []Expr{ID("cname")}, Distinct: true},
`COUNTIF(DISTINCT cname)`,
reparseExpr,
},
{
Func{Name: "ARRAY_AGG", Args: []Expr{ID("Foo")}, NullsHandling: IgnoreNulls},
`ARRAY_AGG(Foo IGNORE NULLS)`,
reparseExpr,
},
{
Func{Name: "ANY_VALUE", Args: []Expr{ID("Foo")}, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}},
`ANY_VALUE(Foo HAVING MAX Bar)`,
reparseExpr,
},
{
Func{Name: "STRING_AGG", Args: []Expr{ID("Foo"), StringLiteral(",")}, Distinct: true, NullsHandling: IgnoreNulls, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}},
`STRING_AGG(DISTINCT Foo, "," IGNORE NULLS HAVING MAX Bar)`,
reparseExpr,
},
{
ComparisonOp{LHS: ID("X"), Op: NotBetween, RHS: ID("Y"), RHS2: ID("Z")},
`X NOT BETWEEN Y AND Z`,
Expand Down
27 changes: 26 additions & 1 deletion spanner/spansql/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,9 @@ type Func struct {
Name string // not ID
Args []Expr

// TODO: various functions permit as-expressions, which might warrant different types in here.
Distinct bool
NullsHandling NullsHandling
Having *AggregateHaving
}

func (Func) isBoolExpr() {} // possibly bool
Expand Down Expand Up @@ -804,6 +806,29 @@ type SequenceExpr struct {

func (SequenceExpr) isExpr() {}

// NullsHandling represents the method of dealing with NULL values in aggregate functions.
type NullsHandling int

const (
NullsHandlingUnspecified NullsHandling = iota
RespectNulls
IgnoreNulls
)

// AggregateHaving represents the HAVING clause specific to aggregate functions, restricting rows based on a maximal or minimal value.
type AggregateHaving struct {
Condition AggregateHavingCondition
Expr Expr
}

// AggregateHavingCondition represents the condition (MAX or MIN) for the AggregateHaving clause.
type AggregateHavingCondition int

const (
HavingMax AggregateHavingCondition = iota
HavingMin
)

// Paren represents a parenthesised expression.
type Paren struct {
Expr Expr
Expand Down