Skip to content

Commit

Permalink
Switch to strings.Builder and expose SQL Builder (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
silas authored Mar 21, 2024
1 parent fbb7bf8 commit 13dcd67
Show file tree
Hide file tree
Showing 14 changed files with 274 additions and 221 deletions.
39 changes: 39 additions & 0 deletions builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package sq

import "strings"

// Builder is a helper that allows to write many Sqlizers one by one
// without constant checks for errors that may come from Sqlizer
type Builder struct {
strings.Builder
args []any
err error
}

// WriteSql converts Sqlizer to SQL strings and writes it to strings.Builder
func (b *Builder) WriteSql(item Sqlizer) {
if b.err != nil {
return
}

var str string
var args []any
str, args, b.err = nestedToSql(item)

if b.err != nil {
return
}

if b.Len() > 0 {
b.WriteByte(' ')
}
b.WriteString(str)
b.args = append(b.args, args...)
}

func (b *Builder) ToSql() (string, []any, error) {
if b.err != nil {
return "", nil, b.err
}
return b.String(), b.args, nil
}
66 changes: 66 additions & 0 deletions builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package sq

import (
"errors"
"testing"

"github.com/stretchr/testify/require"
)

func TestBuilder(t *testing.T) {
b := Builder{}
b.WriteSql(Raw("test"))
sql, args, err := b.ToSql()
require.Equal(t, "test", sql)
require.Empty(t, args)
require.NoError(t, err)

b = Builder{}
b.WriteSql(Raw("one"))
b.WriteSql(Raw("two"))
sql, args, err = b.ToSql()
require.Equal(t, "one two", sql)
require.Empty(t, args)
require.NoError(t, err)

b = Builder{}
b.WriteString("one")
b.WriteSql(Raw("two"))
sql, args, err = b.ToSql()
require.Equal(t, "one two", sql)
require.Empty(t, args)
require.NoError(t, err)

b = Builder{}
b.WriteSql(Expr("one = ?", 1))
sql, args, err = b.ToSql()
require.Equal(t, "one = ?", sql)
require.Len(t, args, 1)
require.Equal(t, 1, args[0])
require.NoError(t, err)

b = Builder{}
b.WriteSql(Expr("one = ?", 1))
b.WriteSql(Expr("AND two = ?", 2))
sql, args, err = b.ToSql()
require.Equal(t, "one = ? AND two = ?", sql)
require.Len(t, args, 2)
require.Equal(t, 1, args[0])
require.Equal(t, 2, args[1])
require.NoError(t, err)

b = Builder{}
b.WriteSql(Expr("one = ?", 1))
b.WriteSql(sqlizeError{err: errors.New("fail")})
b.WriteSql(Expr("two = ?", 2))
sql, args, err = b.ToSql()
require.Empty(t, sql)
require.Empty(t, args)
require.EqualError(t, err, "fail")
}

type sqlizeError struct {
err error
}

func (se sqlizeError) ToSql() (string, []any, error) { return "", nil, se.err }
44 changes: 6 additions & 38 deletions case.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sq

import (
"bytes"
"errors"

"github.com/userhubdev/sq/internal/builder"
Expand All @@ -11,37 +10,6 @@ func init() {
builder.Register(CaseBuilder{}, caseData{})
}

// sqlizerBuffer is a helper that allows to write many Sqlizers one by one
// without constant checks for errors that may come from Sqlizer
type sqlizerBuffer struct {
bytes.Buffer
args []any
err error
}

// WriteSql converts Sqlizer to SQL strings and writes it to buffer
func (b *sqlizerBuffer) WriteSql(item Sqlizer) {
if b.err != nil {
return
}

var str string
var args []any
str, args, b.err = nestedToSql(item)

if b.err != nil {
return
}

b.WriteString(str)
b.WriteByte(' ')
b.args = append(b.args, args...)
}

func (b *sqlizerBuffer) ToSql() (string, []any, error) {
return b.String(), b.args, b.err
}

// whenPart is a helper structure to describe SQLs "WHEN ... THEN ..." expression
type whenPart struct {
when Sqlizer
Expand All @@ -67,26 +35,26 @@ func (d *caseData) ToSql() (sqlStr string, args []any, err error) {
return
}

sql := sqlizerBuffer{}
sql := Builder{}

sql.WriteString("CASE ")
sql.WriteString("CASE")
if d.What != nil {
sql.WriteSql(d.What)
}

for _, p := range d.WhenParts {
sql.WriteString("WHEN ")
sql.WriteString(" WHEN")
sql.WriteSql(p.when)
sql.WriteString("THEN ")
sql.WriteString(" THEN")
sql.WriteSql(p.then)
}

if d.Else != nil {
sql.WriteString("ELSE ")
sql.WriteString(" ELSE")
sql.WriteSql(d.Else)
}

sql.WriteString("END")
sql.WriteString(" END")

return sql.ToSql()
}
Expand Down
3 changes: 1 addition & 2 deletions delete.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sq

import (
"bytes"
"fmt"
"strings"

Expand All @@ -25,7 +24,7 @@ func (d *deleteData) ToSql() (sqlStr string, args []any, err error) {
return
}

sql := &bytes.Buffer{}
sql := &strings.Builder{}

if len(d.Prefixes) > 0 {
args, err = appendToSql(d.Prefixes, sql, " ", args)
Expand Down
15 changes: 7 additions & 8 deletions expr.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sq

import (
"bytes"
"database/sql/driver"
"fmt"
"reflect"
Expand Down Expand Up @@ -40,7 +39,7 @@ func (e expr) ToSql() (sql string, args []any, err error) {
return e.sql, e.args, nil
}

buf := &bytes.Buffer{}
b := &strings.Builder{}
ap := e.args
sp := e.sql

Expand All @@ -55,20 +54,20 @@ func (e expr) ToSql() (sql string, args []any, err error) {
}
if len(sp) > i+1 && sp[i+1:i+2] == "?" {
// escaped "??"; append it and step past
buf.WriteString(sp[:i+2])
b.WriteString(sp[:i+2])
sp = sp[i+2:]
continue
}

if as, ok := ap[0].(Sqlizer); ok {
// sqlizer argument; expand it and append the result
isql, iargs, err = nestedToSql(as)
buf.WriteString(sp[:i])
buf.WriteString(isql)
b.WriteString(sp[:i])
b.WriteString(isql)
args = append(args, iargs...)
} else {
// normal argument; append it and the placeholder
buf.WriteString(sp[:i+1])
b.WriteString(sp[:i+1])
args = append(args, ap[0])
}

Expand All @@ -78,8 +77,8 @@ func (e expr) ToSql() (sql string, args []any, err error) {
}

// append the remaining sql and arguments
buf.WriteString(sp)
return buf.String(), append(args, ap...), err
b.WriteString(sp)
return b.String(), append(args, ap...), err
}

type concatExpr []any
Expand Down
3 changes: 1 addition & 2 deletions insert.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sq

import (
"bytes"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -33,7 +32,7 @@ func (d *insertData) ToSql() (sqlStr string, args []any, err error) {
return
}

sql := &bytes.Buffer{}
sql := &strings.Builder{}

if len(d.Prefixes) > 0 {
args, err = appendToSql(d.Prefixes, sql, " ", args)
Expand Down
11 changes: 6 additions & 5 deletions internal/ps/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
package ps

import (
"bytes"
"fmt"
"strings"
)

// Any is a shorthand for Go's verbose interface{} type.
Expand Down Expand Up @@ -279,11 +279,12 @@ func (m *tree) Keys() []string {
// make it easier to display maps for debugging
func (m *tree) String() string {
keys := m.Keys()
buf := bytes.NewBufferString("{")
var b strings.Builder
b.WriteString("{")
for _, key := range keys {
val, _ := m.Lookup(key)
fmt.Fprintf(buf, "%s: %s, ", key, val)
fmt.Fprintf(&b, "%s: %s, ", key, val)
}
fmt.Fprintf(buf, "}\n")
return buf.String()
fmt.Fprintf(&b, "}\n")
return b.String()
}
15 changes: 7 additions & 8 deletions placeholder.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sq

import (
"bytes"
"fmt"
"strings"
)
Expand Down Expand Up @@ -86,7 +85,7 @@ func Placeholders(count int) string {
}

func replacePositionalPlaceholders(sql, prefix string) (string, error) {
buf := &bytes.Buffer{}
b := &strings.Builder{}
i := 0
for {
p := strings.Index(sql, "?")
Expand All @@ -95,20 +94,20 @@ func replacePositionalPlaceholders(sql, prefix string) (string, error) {
}

if len(sql[p:]) > 1 && sql[p:p+2] == "??" { // escape ?? => ?
buf.WriteString(sql[:p])
buf.WriteString("?")
b.WriteString(sql[:p])
b.WriteString("?")
if len(sql[p:]) == 1 {
break
}
sql = sql[p+2:]
} else {
i++
buf.WriteString(sql[:p])
fmt.Fprintf(buf, "%s%d", prefix, i)
b.WriteString(sql[:p])
fmt.Fprintf(b, "%s%d", prefix, i)
sql = sql[p+1:]
}
}

buf.WriteString(sql)
return buf.String(), nil
b.WriteString(sql)
return b.String(), nil
}
3 changes: 1 addition & 2 deletions select.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sq

import (
"bytes"
"fmt"
"strings"

Expand Down Expand Up @@ -40,7 +39,7 @@ func (d *selectData) ToSqlRaw() (sqlStr string, args []any, err error) {
return
}

sql := &bytes.Buffer{}
sql := &strings.Builder{}

if len(d.Prefixes) > 0 {
args, err = appendToSql(d.Prefixes, sql, " ", args)
Expand Down
Loading

0 comments on commit 13dcd67

Please sign in to comment.