Skip to content

Commit

Permalink
Add iff/iif function
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen committed Feb 6, 2024
1 parent 2f81325 commit b3973f6
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 99 deletions.
268 changes: 169 additions & 99 deletions pql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package pql
import (
"fmt"
"strings"
"sync"

"github.com/runreveal/pql/parser"
)
Expand Down Expand Up @@ -388,107 +389,11 @@ func writeExpression(sb *strings.Builder, source string, x parser.Expr) error {
}
}
case *parser.CallExpr:
switch x.Func.Name {
case "not":
if len(x.Args) != 1 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("not(x) takes a single argument (got %d)", len(x.Args)),
}
}
sb.WriteString("NOT (")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(")")
case "isnull":
if len(x.Args) != 1 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("isnull(x) takes a single argument (got %d)", len(x.Args)),
}
}
sb.WriteString("(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(") IS NULL")
case "isnotnull":
if len(x.Args) != 1 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("isnotnull(x) takes a single argument (got %d)", len(x.Args)),
}
}
sb.WriteString("(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(") IS NOT NULL")
case "strcat":
if len(x.Args) == 0 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("strcat(x) takes least one argument"),
}
}
sb.WriteString("(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(")")
for _, arg := range x.Args[1:] {
sb.WriteString(" || (")
if err := writeExpression(sb, source, arg); err != nil {
return err
}
sb.WriteString(")")
}
case "count":
if len(x.Args) != 0 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("count() takes no arguments (got %d)", len(x.Args)),
}
}
sb.WriteString("count(*)")
case "countif":
if len(x.Args) != 1 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("countif(x) takes a single argument (got %d)", len(x.Args)),
}
}
sb.WriteString("sum(CASE WHEN coalesce(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
if f := initKnownFunctions()[x.Func.Name]; f != nil {
if err := f(sb, source, x); err != nil {
return err
}
sb.WriteString(", FALSE) THEN 1 ELSE 0 END)")
default:
} else {
sb.WriteString(x.Func.Name)
sb.WriteString("(")
for i, arg := range x.Args {
Expand All @@ -507,6 +412,171 @@ func writeExpression(sb *strings.Builder, source string, x parser.Expr) error {
return nil
}

var knownFunctions struct {
init sync.Once
m map[string]func(sb *strings.Builder, source string, x *parser.CallExpr) error
}

func initKnownFunctions() map[string]func(sb *strings.Builder, source string, x *parser.CallExpr) error {
knownFunctions.init.Do(func() {
knownFunctions.m = map[string]func(sb *strings.Builder, source string, x *parser.CallExpr) error{
"not": writeNotFunction,
"isnull": writeIsNullFunction,
"isnotnull": writeIsNotNullFunction,
"strcat": writeStrcatFunction,
"count": writeCountFunction,
"countif": writeCountIfFunction,
"iff": writeIfFunction,
"iif": writeIfFunction,
}
})
return knownFunctions.m
}

func writeNotFunction(sb *strings.Builder, source string, x *parser.CallExpr) error {
if len(x.Args) != 1 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("not(x) takes a single argument (got %d)", len(x.Args)),
}
}
sb.WriteString("NOT (")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(")")
return nil
}

func writeIsNullFunction(sb *strings.Builder, source string, x *parser.CallExpr) error {
if len(x.Args) != 1 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("isnull(x) takes a single argument (got %d)", len(x.Args)),
}
}
sb.WriteString("(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(") IS NULL")
return nil
}

func writeIsNotNullFunction(sb *strings.Builder, source string, x *parser.CallExpr) error {
if len(x.Args) != 1 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("isnotnull(x) takes a single argument (got %d)", len(x.Args)),
}
}
sb.WriteString("(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(") IS NOT NULL")
return nil
}

func writeStrcatFunction(sb *strings.Builder, source string, x *parser.CallExpr) error {
if len(x.Args) == 0 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("strcat(x) takes least one argument"),
}
}
sb.WriteString("(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(")")
for _, arg := range x.Args[1:] {
sb.WriteString(" || (")
if err := writeExpression(sb, source, arg); err != nil {
return err
}
sb.WriteString(")")
}
return nil
}

func writeCountFunction(sb *strings.Builder, source string, x *parser.CallExpr) error {
if len(x.Args) != 0 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("count() takes no arguments (got %d)", len(x.Args)),
}
}
sb.WriteString("count(*)")
return nil
}

func writeCountIfFunction(sb *strings.Builder, source string, x *parser.CallExpr) error {
if len(x.Args) != 1 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("countif(x) takes a single argument (got %d)", len(x.Args)),
}
}
sb.WriteString("sum(CASE WHEN coalesce(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(", FALSE) THEN 1 ELSE 0 END)")
return nil
}

func writeIfFunction(sb *strings.Builder, source string, x *parser.CallExpr) error {
if len(x.Args) != 3 {
return &compileError{
source: source,
span: parser.Span{
Start: x.Lparen.End,
End: x.Rparen.Start,
},
err: fmt.Errorf("%s(if, then, else) takes 3 arguments (got %d)", x.Func.Name, len(x.Args)),
}
}
sb.WriteString("CASE WHEN coalesce(")
if err := writeExpression(sb, source, x.Args[0]); err != nil {
return err
}
sb.WriteString(", FALSE) THEN ")
if err := writeExpression(sb, source, x.Args[1]); err != nil {
return err
}
sb.WriteString(" ELSE ")
if err := writeExpression(sb, source, x.Args[2]); err != nil {
return err
}
sb.WriteString(" END")
return nil
}

func quoteSQLString(sb *strings.Builder, s string) {
sb.WriteString("'")
for _, b := range []byte(s) {
Expand Down
3 changes: 3 additions & 0 deletions testdata/Goldens/If/input.pql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SourceFiles
| sort by LineCount desc, FileName asc
| project FileName, Size = iff(LineCount >= 1000, "Large", "Smol")
10 changes: 10 additions & 0 deletions testdata/Goldens/If/output.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
parser_test.go,Large
parser.go,Smol
lex.go,Smol
lex_test.go,Smol
pql.go,Smol
ast.go,Smol
clickhouse_test.go,Smol
golden_test.go,Smol
tokenkind_string.go,Smol
pql_test.go,Smol
2 changes: 2 additions & 0 deletions testdata/Goldens/If/output.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
WITH "subquery0" AS (SELECT * FROM "SourceFiles" ORDER BY "LineCount" DESC NULLS LAST, "FileName" ASC NULLS FIRST)
SELECT "FileName" AS "FileName", CASE WHEN coalesce(("LineCount") >= (1000), FALSE) THEN 'Large' ELSE 'Smol' END AS "Size" FROM "subquery0";

0 comments on commit b3973f6

Please sign in to comment.