Skip to content

Commit

Permalink
Add WITH clause support (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
silas authored Jun 18, 2024
1 parent 13dcd67 commit 586bbf5
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 2 deletions.
4 changes: 2 additions & 2 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ func (b InsertBuilder) Options(options ...string) InsertBuilder {
}

// Into sets the INTO clause of the query.
func (b InsertBuilder) Into(from string) InsertBuilder {
return builder.Set(b, "Into", from).(InsertBuilder)
func (b InsertBuilder) Into(into string) InsertBuilder {
return builder.Set(b, "Into", into).(InsertBuilder)
}

// Columns adds insert columns to the query.
Expand Down
10 changes: 10 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ func (b StatementBuilderType) Delete(from string) DeleteBuilder {
return DeleteBuilder(b).From(from)
}

// With returns a WithBuilder for this StatementBuilderType.
func (b StatementBuilderType) With() WithBuilder {
return WithBuilder(b)
}

// PlaceholderFormat sets the PlaceholderFormat field for any child builders.
func (b StatementBuilderType) PlaceholderFormat(f PlaceholderFormat) StatementBuilderType {
return builder.Set(b, "PlaceholderFormat", f).(StatementBuilderType)
Expand Down Expand Up @@ -82,6 +87,11 @@ func Delete(from string) DeleteBuilder {
return StatementBuilder.Delete(from)
}

// With returns a new WithBuilder.
func With() WithBuilder {
return StatementBuilder.With()
}

// Case returns a new CaseBuilder
// "what" represents case value
func Case(what ...any) CaseBuilder {
Expand Down
124 changes: 124 additions & 0 deletions with.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package sq

import (
"github.com/userhubdev/sq/internal/builder"
)

func init() {
builder.Register(WithBuilder{}, withData{})
}

// withPart is a helper structure to describe the cte parts of a WITH clause.
type withPart struct {
alias string
cte Sqlizer
}

// newWithPart creates a new withPart for a WITH clause.
func newWithPart(alias string, cte Sqlizer) withPart {
return withPart{alias: alias, cte: cte}
}

// withData holds all the data required to build a WITH clause.
type withData struct {
PlaceholderFormat PlaceholderFormat
WithParts []withPart
}

// ToSql implements Sqlizer.
func (d *withData) ToSql() (sqlStr string, args []any, err error) {
if len(d.WithParts) == 0 {
return "", nil, nil
}

sql := Builder{}

sql.WriteString("WITH")

for i, p := range d.WithParts {
if i > 0 {
sql.WriteString(", ")
} else {
sql.WriteString(" ")
}
sql.WriteString(p.alias)
sql.WriteString(" AS (")
sql.WriteSql(p.cte)
sql.WriteString(")")
}

return sql.ToSql()
}

// WithBuilder builds a WITH clause.
type WithBuilder builder.Builder

// ToSql builds the query into a SQL string and bound args.
func (b WithBuilder) ToSql() (string, []any, error) {
data := builder.GetStruct(b).(withData)
return data.ToSql()
}

// As adds a "... AS (...)" part to the WITH clause.
func (b WithBuilder) As(alias string, sql Sqlizer) WithBuilder {
return builder.Append(b, "WithParts", newWithPart(alias, sql)).(WithBuilder)
}

// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// WITH clause.
func (b WithBuilder) PlaceholderFormat(f PlaceholderFormat) WithBuilder {
return builder.Set(b, "PlaceholderFormat", f).(WithBuilder)
}

// Select starts a primary SELECT statement for the WITH clause.
func (b WithBuilder) Select(columns ...string) SelectBuilder {
data := builder.GetStruct(b).(withData)

sql := StatementBuilder.Select(columns...)

if len(data.WithParts) > 0 {
sql = sql.PrefixExpr(&data)
}

return sql.PlaceholderFormat(data.PlaceholderFormat)
}

// Insert starts a primary INSERT statement for the WITH clause.
func (b WithBuilder) Insert(into string) InsertBuilder {
data := builder.GetStruct(b).(withData)

sql := StatementBuilder.Insert(into)

if len(data.WithParts) > 0 {
sql = sql.PrefixExpr(&data)
}

return sql.PlaceholderFormat(data.PlaceholderFormat)

}

// Update starts a primary UPDATE statement for the WITH clause.
func (b WithBuilder) Update(table string) UpdateBuilder {
data := builder.GetStruct(b).(withData)

sql := StatementBuilder.Update(table)

if len(data.WithParts) > 0 {
sql = sql.PrefixExpr(&data)
}

return sql.PlaceholderFormat(data.PlaceholderFormat)
}

// Delete starts a primary DELETE statement for the WITH clause.
func (b WithBuilder) Delete(from string) DeleteBuilder {
data := builder.GetStruct(b).(withData)

sql := StatementBuilder.Delete(from)

if len(data.WithParts) > 0 {
sql = sql.PrefixExpr(&data)
}

return sql.PlaceholderFormat(data.PlaceholderFormat)
}
108 changes: 108 additions & 0 deletions with_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package sq

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestWithBuilder_ToSql(t *testing.T) {
withClause := With().
As("c", Select("a").From("b").Where("a = ?", 1))

b := Select("a").PrefixExpr(withClause).From("c")
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = ?) SELECT a FROM c", sql)
require.Equal(t, []any{1}, args)
}

func TestWithBuilder_SelectNoCTE(t *testing.T) {
b := With().Select("a").From("b").Where("a = ?", 1)
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "SELECT a FROM b WHERE a = ?", sql)
require.Equal(t, []any{1}, args)
}

func TestWithBuilder_SelectOneCTE(t *testing.T) {
b := StatementBuilder.PlaceholderFormat(AtP).With().
As("c", Select("a").From("b").Where("a = ?", 1)).
Select("a").From("c")
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = @p1) SELECT a FROM c", sql)
require.Equal(t, []any{1}, args)
}

func TestWithBuilder_SelectMultipleCTE(t *testing.T) {
b := StatementBuilder.PlaceholderFormat(AtP).With().
As("c", Select("a").From("b").Where("a > ?", 1)).
As("d", Select("a").From("c").Where("a < ?", 100)).
Select("a").From("d")
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "WITH "+
"c AS ( SELECT a FROM b WHERE a > @p1), "+
"d AS ( SELECT a FROM c WHERE a < @p2) "+
"SELECT a FROM d", sql)
require.Equal(t, []any{1, 100}, args)
}

func TestWithBuilder_InsertNoCTE(t *testing.T) {
b := With().Insert("b").SetMap(map[string]any{"a": 5})
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "INSERT INTO b (a) VALUES (?)", sql)
require.Equal(t, []any{5}, args)
}

func TestWithBuilder_InsertOneCTE(t *testing.T) {
b := StatementBuilder.PlaceholderFormat(AtP).With().
As("c", Select("a").From("b").Where("a = ?", 1)).
Insert("d").Columns("v").
Select(Select("a").From("c"))
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = @p1) INSERT INTO d (v) SELECT a FROM c", sql)
require.Equal(t, []any{1}, args)
}

func TestWithBuilder_UpdateNoCTE(t *testing.T) {
b := With().Update("b").SetMap(map[string]any{"a": 5})
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "UPDATE b SET a = ?", sql)
require.Equal(t, []any{5}, args)
}

func TestWithBuilder_UpdateOneCTE(t *testing.T) {
b := StatementBuilder.PlaceholderFormat(AtP).With().
As("c", Select("a").From("b").Where("a = ?", 1)).
Update("d").
SetMap(map[string]any{"v": 5}).
Where(ConcatExpr("a IN (", Select("a").From("c"), ")"))
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = @p1) UPDATE d SET v = @p2 WHERE a IN (SELECT a FROM c)", sql)
require.Equal(t, []any{1, 5}, args)
}

func TestWithBuilder_DeleteNoCTE(t *testing.T) {
b := With().Delete("b").Where(Eq{"a": 5})
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "DELETE FROM b WHERE a = ?", sql)
require.Equal(t, []any{5}, args)
}

func TestWithBuilder_DeleteOneCTE(t *testing.T) {
b := StatementBuilder.PlaceholderFormat(AtP).With().
As("c", Select("a").From("b").Where("a = ?", 1)).
Delete("d").
Where(ConcatExpr("a IN (", Select("a").From("c"), ")"))
sql, args, err := b.ToSql()
require.NoError(t, err)
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = @p1) DELETE FROM d WHERE a IN (SELECT a FROM c)", sql)
require.Equal(t, []any{1}, args)
}

0 comments on commit 586bbf5

Please sign in to comment.