Skip to content

Commit

Permalink
Break apart expressions into subqueries for transpilation
Browse files Browse the repository at this point in the history
Fixes #16
  • Loading branch information
zombiezen committed Feb 5, 2024
1 parent 0156667 commit 9d7b880
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 35 deletions.
140 changes: 107 additions & 33 deletions pql.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,132 @@ func Compile(source string) (string, error) {
return "", err
}

dataSource, err := dataSourceSQL(expr.Source)
subqueries, err := splitQueries(expr)
if err != nil {
return "", err
}

switch {
case len(expr.Operators) == 0:
return "SELECT * FROM " + dataSource + ";", nil
case len(expr.Operators) == 1:
switch op := expr.Operators[0].(type) {
case *parser.CountOperator:
return "SELECT COUNT(*) FROM " + dataSource + ";", nil
sb := new(strings.Builder)
ctes := subqueries[:len(subqueries)-1]
query := subqueries[len(subqueries)-1]
if len(ctes) > 0 {
sb.WriteString("WITH ")
for i, sub := range ctes {
quoteIdentifier(sb, sub.name)
sb.WriteString(" AS (")
sub.write(sb)
sb.WriteString(")")
if i < len(ctes)-1 {
sb.WriteString(",")
}
sb.WriteString("\n")
}
}
query.write(sb)
sb.WriteString(";")
return sb.String(), nil
}

type subquery struct {
name string
sourceSQL string

op parser.TabularOperator
sort *parser.SortOperator
take *parser.TakeOperator
}

func splitQueries(expr *parser.TabularExpr) ([]*subquery, error) {
var subqueries []*subquery
var lastSubquery *subquery
for i := 0; i < len(expr.Operators); i++ {
switch op := expr.Operators[i].(type) {
case *parser.SortOperator:
if lastSubquery != nil && lastSubquery.sort == nil && lastSubquery.take == nil {
lastSubquery.sort = op
} else {
lastSubquery = &subquery{
sort: op,
}
subqueries = append(subqueries, lastSubquery)
}
case *parser.TakeOperator:
if lastSubquery != nil && lastSubquery.take == nil {
lastSubquery.take = op
} else {
lastSubquery = &subquery{
take: op,
}
subqueries = append(subqueries, lastSubquery)
}
default:
return "", fmt.Errorf("unsupported operator %T", op)
lastSubquery = &subquery{
op: op,
}
subqueries = append(subqueries, lastSubquery)
}
}

if len(subqueries) == 0 {
subqueries = append(subqueries, new(subquery))
}
buf := new(strings.Builder)
for i, sub := range subqueries {
if i == 0 {
var err error
sub.sourceSQL, err = dataSourceSQL(expr.Source)
if err != nil {
return nil, err
}
} else {
buf.Reset()
quoteIdentifier(buf, subqueries[i-1].name)
sub.sourceSQL = buf.String()
}

if i < len(subqueries)-1 {
sub.name = fmt.Sprintf("subquery%d", i)
}
}

return subqueries, nil
}

func (sub *subquery) write(sb *strings.Builder) {
switch op := sub.op.(type) {
case nil:
sb.WriteString("SELECT * FROM ")
sb.WriteString(sub.sourceSQL)
case *parser.CountOperator:
sb.WriteString("SELECT COUNT(*) FROM ")
sb.WriteString(sub.sourceSQL)
default:
return "", fmt.Errorf("only one operator implemented")
fmt.Fprintf(sb, "/* unsupported operator %T */", op)
}
}

func dataSourceSQL(src parser.TabularDataSource) (string, error) {
switch src := src.(type) {
case *parser.TableRef:
return quoteIdentifier(src.Table.Name), nil
sb := new(strings.Builder)
quoteIdentifier(sb, src.Table.Name)
return sb.String(), nil
default:
return "", fmt.Errorf("unhandled data source %T", src)
}
}

func quoteIdentifier(name string) string {
if sqlIdentifierNeedsQuote(name) {
return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
return name
}
func quoteIdentifier(sb *strings.Builder, name string) {
const quoteEscape = `""`
sb.Grow(len(name) + strings.Count(name, `"`)*(len(quoteEscape)-1) + len(`""`))

func sqlIdentifierNeedsQuote(name string) bool {
if name == "" || !isAlpha(rune(name[0])) && name[0] != '_' {
return true
}
for i := 1; i < len(name); i++ {
if !isAlpha(rune(name[i])) && !isDigit(rune(name[i])) && name[i] != '_' {
return true
sb.WriteString(`"`)
for _, b := range []byte(name) {
if b == '"' {
sb.WriteString(quoteEscape)
} else {
sb.WriteByte(b)
}
}
return false
}

func isAlpha(c rune) bool {
return 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z'
}

func isDigit(c rune) bool {
return '0' <= c && c <= '9'
sb.WriteString(`"`)
}
2 changes: 1 addition & 1 deletion testdata/Goldens/Count/output.sql
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT COUNT(*) FROM StormEvents;
SELECT COUNT(*) FROM "StormEvents";
6 changes: 6 additions & 0 deletions testdata/Goldens/DoubleCount/StormEvents.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
State,EventType,DamageProperty
ATLANTIC SOUTH,Waterspout,0
FLORIDA,Heavy Rain,0
FLORIDA,Tornado,6200000
GEORGIA,Thunderstorm Wind,2000
MISSISSIPPI,Thunderstorm Wind,20000
3 changes: 3 additions & 0 deletions testdata/Goldens/DoubleCount/input.pql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
StormEvents
| count
| count
1 change: 1 addition & 0 deletions testdata/Goldens/DoubleCount/output.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1
2 changes: 2 additions & 0 deletions testdata/Goldens/DoubleCount/output.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
WITH "subquery0" AS (SELECT COUNT(*) FROM "StormEvents")
SELECT COUNT(*) FROM "subquery0";
2 changes: 1 addition & 1 deletion testdata/Goldens/Table/output.sql
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT * FROM StormEvents;
SELECT * FROM "StormEvents";

0 comments on commit 9d7b880

Please sign in to comment.