From 90da7fd6f23cdfe99092d488020bbfec99673961 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Fri, 11 Mar 2022 17:15:09 +0100 Subject: [PATCH] [+] improve test coverage --- argument_test.go | 23 +++++++++++++++++++++-- driver_test.go | 30 ++++++++++++++++++++++++++++++ pgxmock_test.go | 14 ++++++++++++++ rows.go | 44 +++++++++++++++++++++++++------------------- rows_test.go | 36 +++++++++++++++++++++++++++++++++--- 5 files changed, 123 insertions(+), 24 deletions(-) diff --git a/argument_test.go b/argument_test.go index d71e835..e760d93 100644 --- a/argument_test.go +++ b/argument_test.go @@ -20,7 +20,6 @@ func TestAnyTimeArgument(t *testing.T) { if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } - // defer db.Close() mock.ExpectExec("INSERT INTO users"). WithArgs("john", AnyTime{}). @@ -42,7 +41,6 @@ func TestByteSliceArgument(t *testing.T) { if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } - // defer db.Close() username := []byte("user") mock.ExpectExec("INSERT INTO users").WithArgs(username).WillReturnResult(NewResult("INSERT", 1)) @@ -56,3 +54,24 @@ func TestByteSliceArgument(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestAnyArgument(t *testing.T) { + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + mock.ExpectExec("INSERT INTO users"). + WithArgs("john", AnyArg()). + WillReturnResult(NewResult("INSERT", 1)) + + _, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now()) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/driver_test.go b/driver_test.go index 9913796..489c9f0 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1,6 +1,7 @@ package pgxmock import ( + "context" "testing" ) @@ -16,4 +17,33 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) { if mock == mock2 { t.Errorf("expected not the same mock instance, but it is the same") } + mock.Close(context.Background()) + mock2.Close(context.Background()) +} + +func TestPools(t *testing.T) { + mock, err := NewPool() + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + mock2, err := NewPool() + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + if mock == mock2 { + t.Errorf("expected not the same mock instance, but it is the same") + } + mock.Close() + mock2.Close() +} + +func TestAcquire(t *testing.T) { + mock, err := NewPool() + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + _, err = mock.Acquire(context.Background()) + if err == nil { + t.Error("expected error, but got nil") + } } diff --git a/pgxmock_test.go b/pgxmock_test.go index 746c46d..0be7e59 100644 --- a/pgxmock_test.go +++ b/pgxmock_test.go @@ -1184,3 +1184,17 @@ func queryWithTimeout(t time.Duration, db pgxIface, query string, args ...interf return nil, fmt.Errorf("query timed out after %v", t) } } + +func TestCon(t *testing.T) { + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer mock.Close(context.Background()) + defer func() { + if r := recover(); r == nil { + t.Errorf("The Conn() did not panic") + } + }() + _ = mock.Conn() +} diff --git a/rows.go b/rows.go index 80df0bd..6ba3719 100644 --- a/rows.go +++ b/rows.go @@ -3,7 +3,6 @@ package pgxmock import ( "encoding/csv" "fmt" - "io" "reflect" "strings" @@ -34,7 +33,7 @@ func (rs *rowSets) Err() error { } func (rs *rowSets) CommandTag() pgconn.CommandTag { - return pgconn.CommandTag("") + return rs.sets[rs.pos].commandTag } func (rs *rowSets) FieldDescriptions() []pgproto3.FieldDescription { @@ -168,11 +167,12 @@ func rawBytes(col interface{}) (_ []byte, ok bool) { // Rows is a mocked collection of rows to // return for Query result type Rows struct { - defs []pgproto3.FieldDescription - rows [][]interface{} - pos int - nextErr map[int]error - closeErr error + commandTag pgconn.CommandTag + defs []pgproto3.FieldDescription + rows [][]interface{} + pos int + nextErr map[int]error + closeErr error } // NewRows allows Rows to be created from a @@ -225,6 +225,12 @@ func (r *Rows) AddRow(values ...interface{}) *Rows { return r } +// AddCommandTag will add a command tag to the result set +func (r *Rows) AddCommandTag(tag pgconn.CommandTag) *Rows { + r.commandTag = tag + return r +} + // FromCSVString build rows from csv string. // return the same instance to perform subsequent actions. // Note that the number of values must match the number @@ -248,20 +254,20 @@ func (r *Rows) FromCSVString(s string) *Rows { return r } -// Implement the "RowsNextResultSet" interface -func (rs *rowSets) HasNextResultSet() bool { - return rs.pos+1 < len(rs.sets) -} +// // Implement the "RowsNextResultSet" interface +// func (rs *rowSets) HasNextResultSet() bool { +// return rs.pos+1 < len(rs.sets) +// } -// Implement the "RowsNextResultSet" interface -func (rs *rowSets) NextResultSet() error { - if !rs.HasNextResultSet() { - return io.EOF - } +// // Implement the "RowsNextResultSet" interface +// func (rs *rowSets) NextResultSet() error { +// if !rs.HasNextResultSet() { +// return io.EOF +// } - rs.pos++ - return nil -} +// rs.pos++ +// return nil +// } // type for rows with columns definition created with pgxmock.NewRowsWithColumnDefinition type rowSetsWithDefinition struct { diff --git a/rows_test.go b/rows_test.go index cfd0778..9e726ba 100644 --- a/rows_test.go +++ b/rows_test.go @@ -5,9 +5,31 @@ import ( "errors" "fmt" "testing" + + "github.com/jackc/pgconn" ) -// const invalid = `☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ` +func TestExplicitTypeCasting(t *testing.T) { + mock, err := NewPool() + if err != nil { + panic(err) + } + + mock.ExpectQuery("SELECT .+ FROM test WHERE .+"). + WithArgs(uint64(1)). + WillReturnRows(NewRows( + []string{"id"}). + AddRow(uint64(1)), + ) + + rows := mock.QueryRow(context.Background(), "SELECT id FROM test WHERE id = $1", uint64(1)) + + var id uint64 + err = rows.Scan(&id) + if err != nil { + t.Error(err) + } +} func ExampleRows() { mock, err := NewConn() @@ -18,13 +40,19 @@ func ExampleRows() { rows := NewRows([]string{"id", "title"}). AddRow(1, "one"). - AddRow(2, "two") + AddRow(2, "two"). + AddCommandTag(pgconn.CommandTag("SELECT 2")) mock.ExpectQuery("SELECT").WillReturnRows(rows) rs, _ := mock.Query(context.Background(), "SELECT") defer rs.Close() + fmt.Println("command tag:", rs.CommandTag()) + if len(rs.FieldDescriptions()) != 2 { + fmt.Println("got wrong number of fields") + } + for rs.Next() { var id int var title string @@ -35,7 +63,9 @@ func ExampleRows() { if rs.Err() != nil { fmt.Println("got rows error:", rs.Err()) } - // Output: scanned id: 1 and title: one + + // Output: command tag: SELECT 2 + // scanned id: 1 and title: one // scanned id: 2 and title: two }