Skip to content

Commit

Permalink
[-] make sure WithRewrittenSQL() call is optional, fixes #167
Browse files Browse the repository at this point in the history
  • Loading branch information
pashagolub committed Nov 2, 2023
1 parent 5e0d622 commit 3cbc4cc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
56 changes: 49 additions & 7 deletions expectations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -250,21 +251,37 @@ func TestMissingWithArgs(t *testing.T) {
}
}

type user struct {
ID int64
name string
email pgtype.Text
}

func (u *user) RewriteQuery(_ context.Context, _ *pgx.Conn, sql string, _ []any) (newSQL string, newArgs []any, err error) {
switch sql {
case "INSERT":
return `INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id`, []any{u.name, u.email}, nil
case "UPDATE":
return `UPDATE users SET username = $1, email = $2 WHERE id = $1`, []any{u.ID, u.name, u.email}, nil
case "DELETE":
return `DELETE FROM users WHERE id = $1`, []any{u.ID}, nil
}
return
}

func TestWithRewrittenSQL(t *testing.T) {
t.Parallel()
mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual))
a := assert.New(t)
a.NoError(err)

mock.ExpectQuery(`INSERT INTO users(username) VALUES (@user)`).
WithArgs(pgx.NamedArgs{"user": "John"}).
WithRewrittenSQL(`INSERT INTO users(username) VALUES ($1)`).
u := user{name: "John", email: pgtype.Text{String: "[email protected]", Valid: true}}
mock.ExpectQuery(`INSERT`).
WithArgs(&u).
WithRewrittenSQL(`INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id`).
WillReturnRows()

_, err = mock.Query(context.Background(),
"INSERT INTO users(username) VALUES (@user)",
pgx.NamedArgs{"user": "John"},
)
_, err = mock.Query(context.Background(), "INSERT", &u)
a.NoError(err)
a.NoError(mock.ExpectationsWereMet())

Expand All @@ -280,3 +297,28 @@ func TestWithRewrittenSQL(t *testing.T) {
a.Error(err)
a.Error(mock.ExpectationsWereMet())
}

func TestQueryRewriter(t *testing.T) {
t.Parallel()
mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual))
a := assert.New(t)
a.NoError(err)

update := `UPDATE "user" SET email = @email, password = @password, updated_utc = @updated_utc WHERE id = @id`

mock.ExpectExec(update).WithArgs(pgx.NamedArgs{
"id": "mockUser.ID",
"email": "mockUser.Email",
"password": "mockUser.Password",
"updated_utc": AnyArg(),
}).WillReturnError(errPanic)

_, err = mock.Exec(context.Background(), update, pgx.NamedArgs{
"id": "mockUser.ID",
"email": "mockUser.Email",
"password": "mockUser.Password",
"updated_utc": time.Now().UTC(),
})
a.Error(err)
a.NoError(mock.ExpectationsWereMet())
}
8 changes: 3 additions & 5 deletions pgxmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ func (c *pgxmock) ExpectationsWereMet() error {
func (c *pgxmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
e := &ExpectedQuery{}
e.expectSQL = expectedSQL
e.expectRewrittenSQL = expectedSQL
c.expectations = append(c.expectations, e)
return e
}
Expand Down Expand Up @@ -235,7 +234,6 @@ func (c *pgxmock) ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin {
func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec {
e := &ExpectedExec{}
e.expectSQL = expectedSQL
e.expectRewrittenSQL = expectedSQL
c.expectations = append(c.expectations, e)
return e
}
Expand Down Expand Up @@ -371,7 +369,7 @@ func (c *pgxmock) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx,
}

func (c *pgxmock) Prepare(ctx context.Context, name, query string) (*pgconn.StatementDescription, error) {
ex, err := findExpectationFunc[*ExpectedPrepare](c, "Exec()", func(prepareExp *ExpectedPrepare) error {
ex, err := findExpectationFunc[*ExpectedPrepare](c, "Prepare()", func(prepareExp *ExpectedPrepare) error {
if err := c.queryMatcher.Match(prepareExp.expectSQL, query); err != nil {
return err
}
Expand Down Expand Up @@ -434,7 +432,7 @@ func (c *pgxmock) Query(ctx context.Context, sql string, args ...interface{}) (p
}
if rewrittenSQL, err := queryExp.argsMatches(sql, args); err != nil {
return err
} else if rewrittenSQL != "" {
} else if rewrittenSQL != "" && queryExp.expectRewrittenSQL != "" {
if err := c.queryMatcher.Match(queryExp.expectRewrittenSQL, rewrittenSQL); err != nil {
return err
}
Expand Down Expand Up @@ -474,7 +472,7 @@ func (c *pgxmock) Exec(ctx context.Context, query string, args ...interface{}) (
}
if rewrittenSQL, err := execExp.argsMatches(query, args); err != nil {
return err
} else if rewrittenSQL != "" {
} else if rewrittenSQL != "" && execExp.expectRewrittenSQL != "" {
if err := c.queryMatcher.Match(execExp.expectRewrittenSQL, rewrittenSQL); err != nil {
//pgx support QueryRewriter for arguments, now we can check if the query was actually rewriten
return err
Expand Down

0 comments on commit 3cbc4cc

Please sign in to comment.