From 66206cdb8984379808a99bf69634853444ec0b7c Mon Sep 17 00:00:00 2001 From: toga4 <81744248+toga4@users.noreply.github.com> Date: Sat, 26 Aug 2023 21:23:24 +0900 Subject: [PATCH] feat(spanner/spansql): add support for aggregate functions --- spanner/spansql/keywords.go | 44 +++++++++++++++++------- spanner/spansql/parser.go | 63 ++++++++++++++++++++++++++++++++++ spanner/spansql/parser_test.go | 7 ++++ spanner/spansql/sql.go | 20 +++++++++++ spanner/spansql/sql_test.go | 25 ++++++++++++++ spanner/spansql/types.go | 27 ++++++++++++++- 6 files changed, 173 insertions(+), 13 deletions(-) diff --git a/spanner/spansql/keywords.go b/spanner/spansql/keywords.go index a66d24e21825..c81606b4d6cf 100644 --- a/spanner/spansql/keywords.go +++ b/spanner/spansql/keywords.go @@ -129,11 +129,16 @@ var keywords = map[string]bool{ // https://cloud.google.com/spanner/docs/functions-and-operators var funcs = make(map[string]bool) var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError)) +var aggregateFuncs = make(map[string]bool) func init() { - for _, f := range allFuncs { + for _, f := range funcNames { funcs[f] = true } + for _, f := range aggregateFuncNames { + funcs[f] = true + aggregateFuncs[f] = true + } // Special case for CAST, SAFE_CAST and EXTRACT funcArgParsers["CAST"] = typedArgParser funcArgParsers["SAFE_CAST"] = typedArgParser @@ -150,19 +155,9 @@ func init() { funcArgParsers["GET_INTERNAL_SEQUENCE_STATE"] = sequenceArgParser } -var allFuncs = []string{ +var funcNames = []string{ // TODO: many more - // Aggregate functions. - "ANY_VALUE", - "ARRAY_AGG", - "AVG", - "BIT_XOR", - "COUNT", - "MAX", - "MIN", - "SUM", - // Cast functions. "CAST", "SAFE_CAST", @@ -295,3 +290,28 @@ var allFuncs = []string{ // Utility functions. "GENERATE_UUID", } + +var aggregateFuncNames = []string{ + // Aggregate functions. + "ANY_VALUE", + "ARRAY_AGG", + "ARRAY_CONCAT_AGG", + "AVG", + "BIT_AND", + "BIT_OR", + "BIT_XOR", + "COUNT", + "COUNTIF", + "LOGICAL_AND", + "LOGICAL_OR", + "MAX", + "MIN", + "STRING_AGG", + "SUM", + + // Statistical aggregate functions. + "STDDEV", + "STDDEV_SAMP", + "VAR_SAMP", + "VARIANCE", +} diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index eebf39d644a2..81eebb3aeae3 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -3566,6 +3566,65 @@ var sequenceArgParser = func(p *parser) (Expr, *parseError) { return p.parseExpr() } +func (p *parser) parseAggregateFunc() (Func, *parseError) { + tok := p.next() + if tok.err != nil { + return Func{}, tok.err + } + name := strings.ToUpper(tok.value) + if err := p.expect("("); err != nil { + return Func{}, err + } + var distinct bool + if p.eat("DISTINCT") { + distinct = true + } + args, err := p.parseExprList() + if err != nil { + return Func{}, err + } + var nullsHandling NullsHandling + if p.eat("IGNORE", "NULLS") { + nullsHandling = IgnoreNulls + } else if p.eat("RESPECT", "NULLS") { + nullsHandling = RespectNulls + } + var having *AggregateHaving + if p.eat("HAVING") { + tok := p.next() + if tok.err != nil { + return Func{}, tok.err + } + var cond AggregateHavingCondition + switch tok.value { + case "MAX": + cond = HavingMax + case "MIN": + cond = HavingMin + default: + return Func{}, p.errorf("got %q, want MAX or MIN", tok.value) + } + expr, err := p.parseExpr() + if err != nil { + return Func{}, err + } + having = &AggregateHaving{ + Condition: cond, + Expr: expr, + } + } + if err := p.expect(")"); err != nil { + return Func{}, err + } + return Func{ + Name: name, + Args: args, + Distinct: distinct, + NullsHandling: nullsHandling, + Having: having, + }, nil +} + /* Expressions @@ -3918,6 +3977,10 @@ func (p *parser) parseLit() (Expr, *parseError) { // this is a function invocation. // The `funcs` map is keyed by upper case strings. if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") { + if aggregateFuncs[name] { + p.back() + return p.parseAggregateFunc() + } var list []Expr var err *parseError if f, ok := funcArgParsers[name]; ok { diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index 81147369cd00..c185f4e86965 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -419,6 +419,13 @@ func TestParseExpr(t *testing.T) { {`GET_NEXT_SEQUENCE_VALUE(SEQUENCE MySequence)`, Func{Name: "GET_NEXT_SEQUENCE_VALUE", Args: []Expr{SequenceExpr{Name: ID("MySequence")}}}}, {`GET_INTERNAL_SEQUENCE_STATE(SEQUENCE MySequence)`, Func{Name: "GET_INTERNAL_SEQUENCE_STATE", Args: []Expr{SequenceExpr{Name: ID("MySequence")}}}}, + // Aggregate Functions + {`COUNT(*)`, Func{Name: "COUNT", Args: []Expr{Star}}}, + {`COUNTIF(DISTINCT cname)`, Func{Name: "COUNTIF", Args: []Expr{ID("cname")}, Distinct: true}}, + {`ARRAY_AGG(Foo IGNORE NULLS)`, Func{Name: "ARRAY_AGG", Args: []Expr{ID("Foo")}, NullsHandling: IgnoreNulls}}, + {`ANY_VALUE(Foo HAVING MAX Bar)`, Func{Name: "ANY_VALUE", Args: []Expr{ID("Foo")}, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}}, + {`STRING_AGG(DISTINCT Foo, "," IGNORE NULLS HAVING MAX Bar)`, Func{Name: "STRING_AGG", Args: []Expr{ID("Foo"), StringLiteral(",")}, Distinct: true, NullsHandling: IgnoreNulls, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}}, + // Conditional expressions { `CASE X WHEN 1 THEN "X" WHEN 2 THEN "Y" ELSE NULL END`, diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index fb0f9c8ce45c..2b853f1bd650 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -913,7 +913,27 @@ func (f Func) SQL() string { return buildSQL(f) } func (f Func) addSQL(sb *strings.Builder) { sb.WriteString(f.Name) sb.WriteString("(") + if f.Distinct { + sb.WriteString("DISTINCT ") + } addExprList(sb, f.Args, ", ") + switch f.NullsHandling { + case RespectNulls: + sb.WriteString(" RESPECT NULLS") + case IgnoreNulls: + sb.WriteString(" IGNORE NULLS") + } + if ah := f.Having; ah != nil { + sb.WriteString(" HAVING") + switch ah.Condition { + case HavingMax: + sb.WriteString(" MAX") + case HavingMin: + sb.WriteString(" MIN") + } + sb.WriteString(" ") + sb.WriteString(ah.Expr.SQL()) + } sb.WriteString(")") } diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index ea31ae0fd169..06a74947672c 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -970,6 +970,31 @@ func TestSQL(t *testing.T) { `SELECT SAFE_CAST(7 AS DATE)`, reparseQuery, }, + { + Func{Name: "COUNT", Args: []Expr{Star}}, + `COUNT(*)`, + reparseExpr, + }, + { + Func{Name: "COUNTIF", Args: []Expr{ID("cname")}, Distinct: true}, + `COUNTIF(DISTINCT cname)`, + reparseExpr, + }, + { + Func{Name: "ARRAY_AGG", Args: []Expr{ID("Foo")}, NullsHandling: IgnoreNulls}, + `ARRAY_AGG(Foo IGNORE NULLS)`, + reparseExpr, + }, + { + Func{Name: "ANY_VALUE", Args: []Expr{ID("Foo")}, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}, + `ANY_VALUE(Foo HAVING MAX Bar)`, + reparseExpr, + }, + { + Func{Name: "STRING_AGG", Args: []Expr{ID("Foo"), StringLiteral(",")}, Distinct: true, NullsHandling: IgnoreNulls, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}, + `STRING_AGG(DISTINCT Foo, "," IGNORE NULLS HAVING MAX Bar)`, + reparseExpr, + }, { ComparisonOp{LHS: ID("X"), Op: NotBetween, RHS: ID("Y"), RHS2: ID("Z")}, `X NOT BETWEEN Y AND Z`, diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 91afef941602..ad089a82f83b 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -757,7 +757,9 @@ type Func struct { Name string // not ID Args []Expr - // TODO: various functions permit as-expressions, which might warrant different types in here. + Distinct bool + NullsHandling NullsHandling + Having *AggregateHaving } func (Func) isBoolExpr() {} // possibly bool @@ -804,6 +806,29 @@ type SequenceExpr struct { func (SequenceExpr) isExpr() {} +// NullsHandling represents the method of dealing with NULL values in aggregate functions. +type NullsHandling int + +const ( + NullsHandlingUnspecified NullsHandling = iota + RespectNulls + IgnoreNulls +) + +// AggregateHaving represents the HAVING clause specific to aggregate functions, restricting rows based on a maximal or minimal value. +type AggregateHaving struct { + Condition AggregateHavingCondition + Expr Expr +} + +// AggregateHavingCondition represents the condition (MAX or MIN) for the AggregateHaving clause. +type AggregateHavingCondition int + +const ( + HavingMax AggregateHavingCondition = iota + HavingMin +) + // Paren represents a parenthesised expression. type Paren struct { Expr Expr