Skip to content

Commit

Permalink
Fix join table filter condition generation (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
lafriks authored Sep 28, 2023
1 parent 5b67559 commit 59ff908
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 40 deletions.
10 changes: 8 additions & 2 deletions builder/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,16 @@ func (q Query) WriteJoin(buffer *Buffer, table string, joins []rel.JoinQuery) {
to = join.To
)

jtable := join.Table
// If join table has alias use that for filter conditions
if i := strings.Index(strings.ToLower(jtable), " as "); i > -1 {
jtable = jtable[i+4:]
}

// TODO: move this to core functionality, and infer join condition using assoc data.
if join.Arguments == nil && (join.From == "" || join.To == "") {
from = table + "." + strings.TrimSuffix(join.Table, "s") + "_id"
to = join.Table + ".id"
to = jtable + ".id"
}

buffer.WriteByte(' ')
Expand All @@ -127,7 +133,7 @@ func (q Query) WriteJoin(buffer *Buffer, table string, joins []rel.JoinQuery) {
buffer.WriteEscape(to)
if !join.Filter.None() {
buffer.WriteString(" AND ")
q.Filter.Write(buffer, join.Table, join.Filter, q)
q.Filter.Write(buffer, jtable, join.Filter, q)
}
}

Expand Down
58 changes: 20 additions & 38 deletions builder/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ import (
)

func BenchmarkQuery_Build(b *testing.B) {
var (
queryBuilder = Query{
BufferFactory: BufferFactory{ArgumentPlaceholder: "?", Quoter: Quote{IDPrefix: "`", IDSuffix: "`", IDSuffixEscapeChar: "`", ValueQuote: "'", ValueQuoteEscapeChar: "'"}},
Filter: Filter{},
}
)
queryBuilder := Query{
BufferFactory: BufferFactory{ArgumentPlaceholder: "?", Quoter: Quote{IDPrefix: "`", IDSuffix: "`", IDSuffixEscapeChar: "`", ValueQuote: "'", ValueQuoteEscapeChar: "'"}},
Filter: Filter{},
}

for n := 0; n < b.N; n++ {
query := rel.From("users").
Expand Down Expand Up @@ -83,9 +81,13 @@ func TestQuery_Build(t *testing.T) {
query: query.JoinWith("INNER JOIN", "transactions", "transactions.id", "users.transaction_id"),
},
{
result: "SELECT `users`.* FROM `users` INNER JOIN `transactions` ON `transactions`.`id`=`users`.`transaction_id` AND (`transactions`.`status`=? AND `users`.`type`=?) WHERE `users`.`id`=?;",
result: "SELECT `users`.* FROM `users` INNER JOIN `transactions` ON `transactions`.`id`=`users`.`transaction_id`;",
query: query.JoinWith("INNER JOIN", "transactions", "transactions.id", "users.transaction_id"),
},
{
result: "SELECT `users`.* FROM `users` INNER JOIN `transactions` AS `t` ON `transactions`.`id`=`users`.`transaction_id` AND (`t`.`status`=? AND `users`.`type`=?) WHERE `users`.`id`=?;",
args: []any{1, 2, 10},
query: query.JoinWith("INNER JOIN", "transactions", "transactions.id", "users.transaction_id", rel.Eq("status", 1), rel.Eq("users.type", 2)).Where(rel.Eq("id", 10)),
query: query.JoinWith("INNER JOIN", "transactions as t", "transactions.id", "users.transaction_id", rel.Eq("t.status", 1), rel.Eq("users.type", 2)).Where(rel.Eq("id", 10)),
},
{
result: "SELECT `users`.* FROM `users` ORDER BY `users`.`created_at` ASC;",
Expand All @@ -107,9 +109,7 @@ func TestQuery_Build(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
result, args = queryBuilder.Build(test.query)
)
result, args := queryBuilder.Build(test.query)

assert.Equal(t, test.result, result)
assert.Equal(t, test.args, args)
Expand Down Expand Up @@ -189,9 +189,7 @@ func TestQuery_Build_ordinal(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
qs, args = queryBuilder.Build(test.query)
)
qs, args := queryBuilder.Build(test.query)

assert.Equal(t, test.result, qs)
assert.Equal(t, test.args, args)
Expand Down Expand Up @@ -264,9 +262,7 @@ func TestQuery_WriteSelect(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
buffer = bufferFactory.Create()
)
buffer := bufferFactory.Create()

queryBuilder.WriteSelect(&buffer, test.table, test.selectQuery)
assert.Equal(t, test.result, buffer.String())
Expand Down Expand Up @@ -326,9 +322,7 @@ func TestQuery_WriteJoin(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
buffer = bufferFactory.Create()
)
buffer := bufferFactory.Create()

queryBuilder.WriteJoin(&buffer, "transactions", rel.Build("", test.query).JoinQuery)

Expand Down Expand Up @@ -364,9 +358,7 @@ func TestQuery_WriteWhere(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
buffer = bufferFactory.Create()
)
buffer := bufferFactory.Create()

queryBuilder.WriteWhere(&buffer, test.table, test.filter)

Expand Down Expand Up @@ -402,9 +394,7 @@ func TestQuery_WriteWhere_ordinal(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
buffer = bufferFactory.Create()
)
buffer := bufferFactory.Create()

queryBuilder.WriteWhere(&buffer, test.table, test.filter)

Expand Down Expand Up @@ -451,9 +441,7 @@ func TestQuery_WriteWhere_SubQuery(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
buffer = bufferFactory.Create()
)
buffer := bufferFactory.Create()

queryBuilder.WriteWhere(&buffer, test.table, test.filter)

Expand Down Expand Up @@ -502,9 +490,7 @@ func TestQuery_WriteWhere_SubQuery_ordinal(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
buffer = bufferFactory.Create()
)
buffer := bufferFactory.Create()

queryBuilder.WriteWhere(&buffer, test.table, test.filter)

Expand Down Expand Up @@ -569,9 +555,7 @@ func TestQuery_WriteHaving(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
buffer = bufferFactory.Create()
)
buffer := bufferFactory.Create()

queryBuilder.WriteHaving(&buffer, test.table, test.filter)

Expand Down Expand Up @@ -607,9 +591,7 @@ func TestQuery_WriteHaving_ordinal(t *testing.T) {

for _, test := range tests {
t.Run(test.result, func(t *testing.T) {
var (
buffer = bufferFactory.Create()
)
buffer := bufferFactory.Create()

queryBuilder.WriteHaving(&buffer, test.table, test.filter)

Expand Down

0 comments on commit 59ff908

Please sign in to comment.