This repository has been archived by the owner on Mar 16, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
database_test.go
93 lines (74 loc) · 1.75 KB
/
database_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package abutil
import (
"database/sql"
"errors"
"fmt"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
// mockDBContext opens a sqlmock-backed *sql.DB, passes the database and its
// mock controller to fn, and closes the database once fn returns.
//
// A failure to create the mock aborts the calling test immediately: the
// returned *sql.DB would be nil, and continuing would panic. After fn
// returns, any expectation that fn registered on mock but never triggered is
// reported as a test error.
func mockDBContext(t *testing.T, fn func(*sql.DB, sqlmock.Sqlmock)) {
	t.Helper()

	db, mock, err := sqlmock.New()
	if err != nil {
		// db is nil here; neither the deferred Close nor fn could use it.
		t.Fatalf("creating sqlmock database: %v", err)
	}
	defer db.Close()

	fn(db, mock)

	// Catch code paths that skipped an expected call (e.g. a missing Rollback).
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled sqlmock expectations: %v", err)
	}
}
// TestRollbackErr checks which error RollbackErr hands back:
//   - when tx.Rollback succeeds, the caller-supplied error is returned;
//   - when tx.Rollback fails, the rollback error is returned instead and the
//     caller-supplied error is dropped.
//
// Each case runs as its own subtest so that a fatal setup failure in one case
// does not prevent the other from running.
func TestRollbackErr(t *testing.T) {
	t.Run("rollback succeeds", func(t *testing.T) {
		mockDBContext(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
			mock.ExpectBegin()
			mock.ExpectRollback()

			tx, err := db.Begin()
			if err != nil {
				// tx is nil; passing it to RollbackErr would panic.
				t.Fatalf("beginning transaction: %v", err)
			}

			alt := errors.New("Some alternative error")
			if err := RollbackErr(tx, alt); err != alt {
				t.Errorf("Expected RollbackErr to return %v, but got %v", alt, err)
			}
		})
	})

	t.Run("rollback fails", func(t *testing.T) {
		mockDBContext(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
			rberr := errors.New("Some rollback error")
			mock.ExpectBegin()
			mock.ExpectRollback().
				WillReturnError(rberr)

			tx, err := db.Begin()
			if err != nil {
				// tx is nil; passing it to RollbackErr would panic.
				t.Fatalf("beginning transaction: %v", err)
			}

			if err := RollbackErr(tx, errors.New("This should not be used")); err != rberr {
				t.Errorf("Expected RollbackErr to return %v, but got %v", rberr, err)
			}
		})
	})
}
// exampleRollbackDBContext runs fn against a sqlmock-backed *sql.DB for use
// in runnable examples, where no *testing.T is available.
//
// Only a single BEGIN is registered as an expectation. Any further statement
// fn issues has no matching expectation, so sqlmock should fail it with an
// error; the examples rely on that to exercise their error paths. The
// database is closed when fn returns, including when fn panics.
func exampleRollbackDBContext(fn func(*sql.DB)) {
	db, mock, err := sqlmock.New()
	if err != nil {
		// Examples have no *testing.T to report through; fail loudly rather
		// than handing fn a nil *sql.DB.
		panic(fmt.Sprintf("creating sqlmock database: %v", err))
	}
	defer db.Close()

	mock.ExpectBegin() // At least expect the begin statement
	fn(db)
}
// ExampleRollbackErr contrasts rolling back a failed transaction by hand with
// doing the same thing through RollbackErr.
//
// In both styles the rollback error, if any, takes precedence over the error
// that triggered the rollback. The example has no Output comment, so it is
// compiled but not executed by `go test`.
func ExampleRollbackErr() {
	insertSomething := func(db *sql.DB) error {
		tx, err := db.Begin()
		if err != nil {
			// No transaction was started, so there is nothing to roll back.
			return err
		}

		_, err = tx.Exec("INSERT INTO some_table (some_column) VALUES (?)",
			"foobar")
		if err != nil {
			// The old way, imagine doing this 10 times in a method:
			// roll back by hand and prefer the rollback error if there is one.
			if err := tx.Rollback(); err != nil {
				return err
			}
			return err
		}

		_, err = tx.Exec("DROP DATABASE foobar")
		if err != nil {
			// With RollbackErr
			return RollbackErr(tx, err)
		}

		// Every statement succeeded; finish the transaction instead of
		// leaving it open.
		return tx.Commit()
	}

	exampleRollbackDBContext(func(db *sql.DB) {
		fmt.Println(insertSomething(db))
	})
}