diff --git a/expectations_go18.go b/expectations_go18.go index 6b85ce1..767ebd4 100644 --- a/expectations_go18.go +++ b/expectations_go18.go @@ -30,6 +30,9 @@ func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error { if nil == e.args { + if len(args) > 0 { + return fmt.Errorf("expected 0, but got %d arguments", len(args)) + } return nil } if len(args) != len(e.args) { diff --git a/expectations_go18_test.go b/expectations_go18_test.go index 1974721..3e8821c 100644 --- a/expectations_go18_test.go +++ b/expectations_go18_test.go @@ -12,8 +12,8 @@ import ( func TestQueryExpectationArgComparison(t *testing.T) { e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} against := []driver.NamedValue{{Value: int64(5), Ordinal: 1}} - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + if err := e.argsMatches(against); err == nil { + t.Errorf("arguments should not match, since no expectation was set, but argument was passed") } e.args = []driver.Value{5, "str"} @@ -104,8 +104,8 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) { func TestQueryExpectationNamedArgComparison(t *testing.T) { e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} against := []driver.NamedValue{{Value: int64(5), Name: "id"}} - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + if err := e.argsMatches(against); err == nil { + t.Errorf("arguments should not match, since no expectation was set, but argument was passed") } e.args = []driver.Value{ diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index 223e076..cf56e67 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -435,9 +435,9 @@ func TestContextExecErrorDelay(t *testing.T) { defer db.Close() // test that return of error is delayed - var delay time.Duration - delay = 100 * time.Millisecond + var delay time.Duration = 100 * time.Millisecond mock.ExpectExec("^INSERT INTO articles"). + WithArgs("hello"). WillReturnError(errors.New("slow fail")). WillDelayFor(delay) diff --git a/sqlmock_test.go b/sqlmock_test.go index ee6b516..982a32a 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -959,7 +959,7 @@ func TestPrepareExec(t *testing.T) { mock.ExpectBegin() ep := mock.ExpectPrepare("INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)") for i := 0; i < 3; i++ { - ep.ExpectExec().WillReturnResult(NewResult(1, 1)) + ep.ExpectExec().WithArgs(i, "Hello"+strconv.Itoa(i)).WillReturnResult(NewResult(1, 1)) } mock.ExpectCommit() tx, _ := db.Begin() @@ -1073,7 +1073,7 @@ func TestPreparedStatementCloseExpectation(t *testing.T) { defer db.Close() ep := mock.ExpectPrepare("INSERT INTO ORDERS").WillBeClosed() - ep.ExpectExec().WillReturnResult(NewResult(1, 1)) + ep.ExpectExec().WithArgs(1, "Hello").WillReturnResult(NewResult(1, 1)) stmt, err := db.Prepare("INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") if err != nil { @@ -1102,9 +1102,9 @@ func TestExecExpectationErrorDelay(t *testing.T) { defer db.Close() // test that return of error is delayed - var delay time.Duration - delay = 100 * time.Millisecond + var delay time.Duration = 100 * time.Millisecond mock.ExpectExec("^INSERT INTO articles"). + WithArgs("hello"). WillReturnError(errors.New("slow fail")). WillDelayFor(delay) @@ -1230,10 +1230,10 @@ func Test_sqlmock_Prepare_and_Exec(t *testing.T) { mock.ExpectPrepare("SELECT (.+) FROM users WHERE (.+)") expected := NewResult(1, 1) - mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)"). + mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)").WithArgs("test"). WillReturnResult(expected) expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") - mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").WillReturnRows(expectedRows) got, err := mock.(*sqlmock).Prepare(query) if err != nil { @@ -1326,7 +1326,7 @@ func Test_sqlmock_Query(t *testing.T) { } defer db.Close() expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") - mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").WillReturnRows(expectedRows) query := "SELECT name, email FROM users WHERE name = ?" rows, err := mock.(*sqlmock).Query(query, []driver.Value{"test"}) if err != nil {