Skip to content

Commit

Permalink
Issue #193
Browse files Browse the repository at this point in the history
* [ADDED] Support for CASE statements #193
  • Loading branch information
doug-martin committed Mar 20, 2020
1 parent 6b579ea commit a8f0c36
Show file tree
Hide file tree
Showing 9 changed files with 414 additions and 0 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# v9.8.0

* [ADDED] Support for ANY and ALL operators. [#196](https://github.com/doug-martin/goqu/issues/196)
* [ADDED] Support for CASE statements [#193](https://github.com/doug-martin/goqu/issues/193)

# v9.7.1

Expand Down
75 changes: 75 additions & 0 deletions exp/case.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package exp

type (
caseElse struct {
result interface{}
}
caseWhen struct {
caseElse
condition interface{}
}
caseExpression struct {
value interface{}
whens []CaseWhen
elseCondition CaseElse
}
)

func NewCaseElse(result interface{}) CaseElse {
return caseElse{result: result}
}

func (ce caseElse) Result() interface{} {
return ce.result
}

func NewCaseWhen(condition, result interface{}) CaseWhen {
return caseWhen{caseElse: caseElse{result: result}, condition: condition}
}

func (cw caseWhen) Condition() interface{} {
return cw.condition
}

func NewCaseExpression() CaseExpression {
return caseExpression{value: nil, whens: []CaseWhen{}, elseCondition: nil}
}

func (c caseExpression) Expression() Expression {
return c
}

func (c caseExpression) Clone() Expression {
return caseExpression{value: c.value, whens: c.whens, elseCondition: c.elseCondition}
}

func (c caseExpression) As(alias interface{}) AliasedExpression {
return aliased(c, alias)
}

func (c caseExpression) GetValue() interface{} {
return c.value
}

func (c caseExpression) GetWhens() []CaseWhen {
return c.whens
}

func (c caseExpression) GetElse() CaseElse {
return c.elseCondition
}

func (c caseExpression) Value(value interface{}) CaseExpression {
c.value = value
return c
}

func (c caseExpression) When(condition, result interface{}) CaseExpression {
c.whens = append(c.whens, NewCaseWhen(condition, result))
return c
}

func (c caseExpression) Else(result interface{}) CaseExpression {
c.elseCondition = NewCaseElse(result)
return c
}
88 changes: 88 additions & 0 deletions exp/case_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package exp

import (
"testing"

"github.com/stretchr/testify/suite"
)

type caseExpressionSuite struct {
suite.Suite
}

func TestCaseExpressionSuite(t *testing.T) {
suite.Run(t, &caseExpressionSuite{})
}

func (ces *caseExpressionSuite) TestClone() {
ce := NewCaseExpression()
ces.Equal(ce, ce.Clone())
}

func (ces *caseExpressionSuite) TestExpression() {
ce := NewCaseExpression()
ces.Equal(ce, ce.Expression())
}

func (ces *caseExpressionSuite) TestAs() {
ce := NewCaseExpression()
ces.Equal(aliased(ce, "a"), ce.As("a"))
}

func (ces *caseExpressionSuite) TestValue() {
ce := NewCaseExpression()
ces.Nil(ce.GetValue())

ce = NewCaseExpression().Value(NewIdentifierExpression("", "", "a"))
ces.Equal(NewIdentifierExpression("", "", "a"), ce.GetValue())
}

func (ces *caseExpressionSuite) TestWhen() {
condition1 := NewIdentifierExpression("", "", "a").Eq(10)
condition2 := NewIdentifierExpression("", "", "b").Eq(20)
ce := NewCaseExpression()
ces.Equal([]CaseWhen{
NewCaseWhen(condition1, "a"),
NewCaseWhen(condition2, "b"),
}, ce.When(condition1, "a").When(condition2, "b").GetWhens())

ces.Empty(ce.GetWhens())
}

func (ces *caseExpressionSuite) TestElse() {
ce := NewCaseExpression()
ces.Equal(NewCaseElse("a"), ce.Else("a").GetElse())

ces.Nil(ce.GetElse())
}

type caseWhenSuite struct {
suite.Suite
}

func TestCaseWhenSuite(t *testing.T) {
suite.Run(t, &caseWhenSuite{})
}

func (cws *caseWhenSuite) TestCondition() {
ce := NewCaseWhen(true, false)
cws.Equal(true, ce.Condition())
}

func (cws *caseWhenSuite) TestResult() {
ce := NewCaseWhen(true, false)
cws.Equal(false, ce.Result())
}

type caseElseSuite struct {
suite.Suite
}

func TestCaseElseSuite(t *testing.T) {
suite.Run(t, &caseElseSuite{})
}

func (ces *caseElseSuite) TestResult() {
ce := NewCaseElse(false)
ces.Equal(false, ce.Result())
}
17 changes: 17 additions & 0 deletions exp/exp.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,23 @@ type (
PartitionBy(cols ...interface{}) WindowExpression
OrderBy(cols ...interface{}) WindowExpression
}
CaseElse interface {
Result() interface{}
}
CaseWhen interface {
Condition() interface{}
Result() interface{}
}
CaseExpression interface {
Expression
Aliaseable
GetValue() interface{}
GetWhens() []CaseWhen
GetElse() CaseElse
Value(val interface{}) CaseExpression
When(condition, result interface{}) CaseExpression
Else(result interface{}) CaseExpression
}
)

const (
Expand Down
4 changes: 4 additions & 0 deletions expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,7 @@ func Any(val interface{}) exp.SQLFunctionExpression {
func All(val interface{}) exp.SQLFunctionExpression {
return Func("ALL ", val)
}

func Case() exp.CaseExpression {
return exp.NewCaseExpression()
}
80 changes: 80 additions & 0 deletions expressions_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1821,3 +1821,83 @@ func ExampleAll() {
// SELECT * FROM "test" WHERE ("id" = ALL ((SELECT "test_id" FROM "other"))) []
// SELECT * FROM "test" WHERE ("id" = ALL ((SELECT "test_id" FROM "other"))) []
}

func ExampleCase_search() {
ds := goqu.From("test").
Select(
goqu.C("col"),
goqu.Case().
When(goqu.C("col").Gt(0), true).
When(goqu.C("col").Lte(0), false).
As("is_gt_zero"),
)
sql, args, _ := ds.ToSQL()
fmt.Println(sql, args)

sql, args, _ = ds.Prepared(true).ToSQL()
fmt.Println(sql, args)
// Output:
// SELECT "col", CASE WHEN ("col" > 0) THEN TRUE WHEN ("col" <= 0) THEN FALSE END AS "is_gt_zero" FROM "test" []
// SELECT "col", CASE WHEN ("col" > ?) THEN ? WHEN ("col" <= ?) THEN ? END AS "is_gt_zero" FROM "test" [0 true 0 false]
}

func ExampleCase_searchElse() {
ds := goqu.From("test").
Select(
goqu.C("col"),
goqu.Case().
When(goqu.C("col").Gt(10), "Gt 10").
When(goqu.C("col").Gt(20), "Gt 20").
Else("Bad Val").
As("str_val"),
)
sql, args, _ := ds.ToSQL()
fmt.Println(sql, args)

sql, args, _ = ds.Prepared(true).ToSQL()
fmt.Println(sql, args)
// Output:
// SELECT "col", CASE WHEN ("col" > 10) THEN 'Gt 10' WHEN ("col" > 20) THEN 'Gt 20' ELSE 'Bad Val' END AS "str_val" FROM "test" []
// SELECT "col", CASE WHEN ("col" > ?) THEN ? WHEN ("col" > ?) THEN ? ELSE ? END AS "str_val" FROM "test" [10 Gt 10 20 Gt 20 Bad Val]
}

func ExampleCase_value() {
ds := goqu.From("test").
Select(
goqu.C("col"),
goqu.Case().
Value(goqu.C("str")).
When("foo", "FOO").
When("bar", "BAR").
As("foo_bar_upper"),
)
sql, args, _ := ds.ToSQL()
fmt.Println(sql, args)

sql, args, _ = ds.Prepared(true).ToSQL()
fmt.Println(sql, args)
// Output:
// SELECT "col", CASE "str" WHEN 'foo' THEN 'FOO' WHEN 'bar' THEN 'BAR' END AS "foo_bar_upper" FROM "test" []
// SELECT "col", CASE "str" WHEN ? THEN ? WHEN ? THEN ? END AS "foo_bar_upper" FROM "test" [foo FOO bar BAR]
}

func ExampleCase_valueElse() {
ds := goqu.From("test").
Select(
goqu.C("col"),
goqu.Case().
Value(goqu.C("str")).
When("foo", "FOO").
When("bar", "BAR").
Else("Baz").
As("foo_bar_upper"),
)
sql, args, _ := ds.ToSQL()
fmt.Println(sql, args)

sql, args, _ = ds.Prepared(true).ToSQL()
fmt.Println(sql, args)
// Output:
// SELECT "col", CASE "str" WHEN 'foo' THEN 'FOO' WHEN 'bar' THEN 'BAR' ELSE 'Baz' END AS "foo_bar_upper" FROM "test" []
// SELECT "col", CASE "str" WHEN ? THEN ? WHEN ? THEN ? ELSE ? END AS "foo_bar_upper" FROM "test" [foo FOO bar BAR Baz]
}
30 changes: 30 additions & 0 deletions sqlgen/expression_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var (
`a empty identifier was encountered, please specify a "schema", "table" or "column"`,
)
errUnexpectedNamedWindow = errors.New(`unexpected named window function`)
errEmptyCaseWhens = errors.New(`when conditions not found for case statement`)
)

func errUnsupportedExpressionType(e exp.Expression) error {
Expand Down Expand Up @@ -180,6 +181,8 @@ func (esg *expressionSQLGenerator) expressionSQL(b sb.SQLBuilder, expression exp
esg.commonTableExpressionSQL(b, e)
case exp.CompoundExpression:
esg.compoundExpressionSQL(b, e)
case exp.CaseExpression:
esg.caseExpressionSQL(b, e)
case exp.Ex:
esg.expressionMapSQL(b, e)
case exp.ExOr:
Expand Down Expand Up @@ -632,6 +635,33 @@ func (esg *expressionSQLGenerator) compoundExpressionSQL(b sb.SQLBuilder, compou
}
}

// Generates SQL for a CaseExpression
func (esg *expressionSQLGenerator) caseExpressionSQL(b sb.SQLBuilder, caseExpression exp.CaseExpression) {
caseVal := caseExpression.GetValue()
whens := caseExpression.GetWhens()
elseResult := caseExpression.GetElse()

if len(whens) == 0 {
b.SetError(errEmptyCaseWhens)
return
}
b.Write(esg.dialectOptions.CaseFragment)
if caseVal != nil {
esg.Generate(b, caseVal)
}
for _, when := range whens {
b.Write(esg.dialectOptions.WhenFragment)
esg.Generate(b, when.Condition())
b.Write(esg.dialectOptions.ThenFragment)
esg.Generate(b, when.Result())
}
if elseResult != nil {
b.Write(esg.dialectOptions.ElseFragment)
esg.Generate(b, elseResult.Result())
}
b.Write(esg.dialectOptions.EndFragment)
}

func (esg *expressionSQLGenerator) expressionMapSQL(b sb.SQLBuilder, ex exp.Ex) {
expressionList, err := ex.ToExpressions()
if err != nil {
Expand Down
Loading

0 comments on commit a8f0c36

Please sign in to comment.