-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
244 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
package sq | ||
|
||
import ( | ||
"github.com/userhubdev/sq/internal/builder" | ||
) | ||
|
||
func init() { | ||
builder.Register(WithBuilder{}, withData{}) | ||
} | ||
|
||
// withPart is a helper structure to describe the cte parts of a WITH clause. | ||
type withPart struct { | ||
alias string | ||
cte Sqlizer | ||
} | ||
|
||
// newWithPart creates a new withPart for a WITH clause. | ||
func newWithPart(alias string, cte Sqlizer) withPart { | ||
return withPart{alias: alias, cte: cte} | ||
} | ||
|
||
// withData holds all the data required to build a WITH clause. | ||
type withData struct { | ||
PlaceholderFormat PlaceholderFormat | ||
WithParts []withPart | ||
} | ||
|
||
// ToSql implements Sqlizer. | ||
func (d *withData) ToSql() (sqlStr string, args []any, err error) { | ||
if len(d.WithParts) == 0 { | ||
return "", nil, nil | ||
} | ||
|
||
sql := Builder{} | ||
|
||
sql.WriteString("WITH") | ||
|
||
for i, p := range d.WithParts { | ||
if i > 0 { | ||
sql.WriteString(", ") | ||
} else { | ||
sql.WriteString(" ") | ||
} | ||
sql.WriteString(p.alias) | ||
sql.WriteString(" AS (") | ||
sql.WriteSql(p.cte) | ||
sql.WriteString(")") | ||
} | ||
|
||
return sql.ToSql() | ||
} | ||
|
||
// WithBuilder builds a WITH clause. | ||
type WithBuilder builder.Builder | ||
|
||
// ToSql builds the query into a SQL string and bound args. | ||
func (b WithBuilder) ToSql() (string, []any, error) { | ||
data := builder.GetStruct(b).(withData) | ||
return data.ToSql() | ||
} | ||
|
||
// As adds a "... AS (...)" part to the WITH clause. | ||
func (b WithBuilder) As(alias string, sql Sqlizer) WithBuilder { | ||
return builder.Append(b, "WithParts", newWithPart(alias, sql)).(WithBuilder) | ||
} | ||
|
||
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the | ||
// WITH clause. | ||
func (b WithBuilder) PlaceholderFormat(f PlaceholderFormat) WithBuilder { | ||
return builder.Set(b, "PlaceholderFormat", f).(WithBuilder) | ||
} | ||
|
||
// Select starts a primary SELECT statement for the WITH clause. | ||
func (b WithBuilder) Select(columns ...string) SelectBuilder { | ||
data := builder.GetStruct(b).(withData) | ||
|
||
sql := StatementBuilder.Select(columns...) | ||
|
||
if len(data.WithParts) > 0 { | ||
sql = sql.PrefixExpr(&data) | ||
} | ||
|
||
return sql.PlaceholderFormat(data.PlaceholderFormat) | ||
} | ||
|
||
// Insert starts a primary INSERT statement for the WITH clause. | ||
func (b WithBuilder) Insert(into string) InsertBuilder { | ||
data := builder.GetStruct(b).(withData) | ||
|
||
sql := StatementBuilder.Insert(into) | ||
|
||
if len(data.WithParts) > 0 { | ||
sql = sql.PrefixExpr(&data) | ||
} | ||
|
||
return sql.PlaceholderFormat(data.PlaceholderFormat) | ||
|
||
} | ||
|
||
// Update starts a primary UPDATE statement for the WITH clause. | ||
func (b WithBuilder) Update(table string) UpdateBuilder { | ||
data := builder.GetStruct(b).(withData) | ||
|
||
sql := StatementBuilder.Update(table) | ||
|
||
if len(data.WithParts) > 0 { | ||
sql = sql.PrefixExpr(&data) | ||
} | ||
|
||
return sql.PlaceholderFormat(data.PlaceholderFormat) | ||
} | ||
|
||
// Delete starts a primary DELETE statement for the WITH clause. | ||
func (b WithBuilder) Delete(from string) DeleteBuilder { | ||
data := builder.GetStruct(b).(withData) | ||
|
||
sql := StatementBuilder.Delete(from) | ||
|
||
if len(data.WithParts) > 0 { | ||
sql = sql.PrefixExpr(&data) | ||
} | ||
|
||
return sql.PlaceholderFormat(data.PlaceholderFormat) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
package sq | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestWithBuilder_ToSql(t *testing.T) { | ||
withClause := With(). | ||
As("c", Select("a").From("b").Where("a = ?", 1)) | ||
|
||
b := Select("a").PrefixExpr(withClause).From("c") | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = ?) SELECT a FROM c", sql) | ||
require.Equal(t, []any{1}, args) | ||
} | ||
|
||
func TestWithBuilder_SelectNoCTE(t *testing.T) { | ||
b := With().Select("a").From("b").Where("a = ?", 1) | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "SELECT a FROM b WHERE a = ?", sql) | ||
require.Equal(t, []any{1}, args) | ||
} | ||
|
||
func TestWithBuilder_SelectOneCTE(t *testing.T) { | ||
b := StatementBuilder.PlaceholderFormat(AtP).With(). | ||
As("c", Select("a").From("b").Where("a = ?", 1)). | ||
Select("a").From("c") | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = @p1) SELECT a FROM c", sql) | ||
require.Equal(t, []any{1}, args) | ||
} | ||
|
||
func TestWithBuilder_SelectMultipleCTE(t *testing.T) { | ||
b := StatementBuilder.PlaceholderFormat(AtP).With(). | ||
As("c", Select("a").From("b").Where("a > ?", 1)). | ||
As("d", Select("a").From("c").Where("a < ?", 100)). | ||
Select("a").From("d") | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "WITH "+ | ||
"c AS ( SELECT a FROM b WHERE a > @p1), "+ | ||
"d AS ( SELECT a FROM c WHERE a < @p2) "+ | ||
"SELECT a FROM d", sql) | ||
require.Equal(t, []any{1, 100}, args) | ||
} | ||
|
||
func TestWithBuilder_InsertNoCTE(t *testing.T) { | ||
b := With().Insert("b").SetMap(map[string]any{"a": 5}) | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "INSERT INTO b (a) VALUES (?)", sql) | ||
require.Equal(t, []any{5}, args) | ||
} | ||
|
||
func TestWithBuilder_InsertOneCTE(t *testing.T) { | ||
b := StatementBuilder.PlaceholderFormat(AtP).With(). | ||
As("c", Select("a").From("b").Where("a = ?", 1)). | ||
Insert("d").Columns("v"). | ||
Select(Select("a").From("c")) | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = @p1) INSERT INTO d (v) SELECT a FROM c", sql) | ||
require.Equal(t, []any{1}, args) | ||
} | ||
|
||
func TestWithBuilder_UpdateNoCTE(t *testing.T) { | ||
b := With().Update("b").SetMap(map[string]any{"a": 5}) | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "UPDATE b SET a = ?", sql) | ||
require.Equal(t, []any{5}, args) | ||
} | ||
|
||
func TestWithBuilder_UpdateOneCTE(t *testing.T) { | ||
b := StatementBuilder.PlaceholderFormat(AtP).With(). | ||
As("c", Select("a").From("b").Where("a = ?", 1)). | ||
Update("d"). | ||
SetMap(map[string]any{"v": 5}). | ||
Where(ConcatExpr("a IN (", Select("a").From("c"), ")")) | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = @p1) UPDATE d SET v = @p2 WHERE a IN (SELECT a FROM c)", sql) | ||
require.Equal(t, []any{1, 5}, args) | ||
} | ||
|
||
func TestWithBuilder_DeleteNoCTE(t *testing.T) { | ||
b := With().Delete("b").Where(Eq{"a": 5}) | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "DELETE FROM b WHERE a = ?", sql) | ||
require.Equal(t, []any{5}, args) | ||
} | ||
|
||
func TestWithBuilder_DeleteOneCTE(t *testing.T) { | ||
b := StatementBuilder.PlaceholderFormat(AtP).With(). | ||
As("c", Select("a").From("b").Where("a = ?", 1)). | ||
Delete("d"). | ||
Where(ConcatExpr("a IN (", Select("a").From("c"), ")")) | ||
sql, args, err := b.ToSql() | ||
require.NoError(t, err) | ||
require.Equal(t, "WITH c AS ( SELECT a FROM b WHERE a = @p1) DELETE FROM d WHERE a IN (SELECT a FROM c)", sql) | ||
require.Equal(t, []any{1}, args) | ||
} |