From 042fe86dc2daa9684f746ab337971e84147e8856 Mon Sep 17 00:00:00 2001 From: chris erway Date: Mon, 29 Oct 2018 15:54:40 -0400 Subject: [PATCH] Fix ability to pass TracingEventReceiver to dbr (#154) * add failing tests that exercise TracingEventReceiver * pass EventReceiver explicitly between types rather than embedded parent struct * Code review comments for #154 * address code review comments for #154 --- dbr_test.go | 8 +++++++- delete.go | 8 ++++---- event_test.go | 18 ++++++++++++++++++ insert.go | 8 ++++---- insert_test.go | 7 +++++++ select.go | 8 ++++---- transaction.go | 2 +- transaction_test.go | 8 ++++++++ update.go | 8 ++++---- 9 files changed, 57 insertions(+), 18 deletions(-) create mode 100644 event_test.go diff --git a/dbr_test.go b/dbr_test.go index 78402443..4cf9dca5 100644 --- a/dbr_test.go +++ b/dbr_test.go @@ -25,7 +25,7 @@ var ( ) func createSession(driver, dsn string) *Session { - conn, err := Open(driver, dsn, nil) + conn, err := Open(driver, dsn, &testTraceReceiver{}) if err != nil { panic(err) } @@ -88,6 +88,8 @@ func reset(t *testing.T, sess *Session) { _, err := sess.Exec(v) require.NoError(t, err) } + // clear test data collected by testTraceReceiver + sess.EventReceiver = &testTraceReceiver{} } func TestBasicCRUD(t *testing.T) { @@ -174,15 +176,19 @@ func TestTimeout(t *testing.T) { var people []dbrPerson _, err := sess.Select("*").From("dbr_people").Load(&people) require.Equal(t, context.DeadlineExceeded, err) + require.Equal(t, 1, sess.EventReceiver.(*testTraceReceiver).errored) _, err = sess.InsertInto("dbr_people").Columns("name", "email").Values("test", "test@test.com").Exec() require.Equal(t, context.DeadlineExceeded, err) + require.Equal(t, 2, sess.EventReceiver.(*testTraceReceiver).errored) _, err = sess.Update("dbr_people").Set("name", "test1").Exec() require.Equal(t, context.DeadlineExceeded, err) + require.Equal(t, 3, sess.EventReceiver.(*testTraceReceiver).errored) _, err = sess.DeleteFrom("dbr_people").Exec() require.Equal(t, context.DeadlineExceeded, err) + require.Equal(t, 4, sess.EventReceiver.(*testTraceReceiver).errored) // tx op timeout sess.Timeout = 0 diff --git a/delete.go b/delete.go index 8e0757ed..ac059fb4 100644 --- a/delete.go +++ b/delete.go @@ -59,7 +59,7 @@ func DeleteFrom(table string) *DeleteStmt { func (sess *Session) DeleteFrom(table string) *DeleteStmt { b := DeleteFrom(table) b.runner = sess - b.EventReceiver = sess + b.EventReceiver = sess.EventReceiver b.Dialect = sess.Dialect return b } @@ -68,7 +68,7 @@ func (sess *Session) DeleteFrom(table string) *DeleteStmt { func (tx *Tx) DeleteFrom(table string) *DeleteStmt { b := DeleteFrom(table) b.runner = tx - b.EventReceiver = tx + b.EventReceiver = tx.EventReceiver b.Dialect = tx.Dialect return b } @@ -88,7 +88,7 @@ func DeleteBySql(query string, value ...interface{}) *DeleteStmt { func (sess *Session) DeleteBySql(query string, value ...interface{}) *DeleteStmt { b := DeleteBySql(query, value...) b.runner = sess - b.EventReceiver = sess + b.EventReceiver = sess.EventReceiver b.Dialect = sess.Dialect return b } @@ -97,7 +97,7 @@ func (sess *Session) DeleteBySql(query string, value ...interface{}) *DeleteStmt func (tx *Tx) DeleteBySql(query string, value ...interface{}) *DeleteStmt { b := DeleteBySql(query, value...) b.runner = tx - b.EventReceiver = tx + b.EventReceiver = tx.EventReceiver b.Dialect = tx.Dialect return b } diff --git a/event_test.go b/event_test.go new file mode 100644 index 00000000..8a99a867 --- /dev/null +++ b/event_test.go @@ -0,0 +1,18 @@ +package dbr + +import ( + "context" +) + +type testTraceReceiver struct { + NullEventReceiver + started []struct{ eventName, query string } + errored, finished int +} + +func (t *testTraceReceiver) SpanStart(ctx context.Context, eventName, query string) context.Context { + t.started = append(t.started, struct{ eventName, query string }{eventName, query}) + return ctx +} +func (t *testTraceReceiver) SpanError(ctx context.Context, err error) { t.errored++ } +func (t *testTraceReceiver) SpanFinish(ctx context.Context) { t.finished++ } diff --git a/insert.go b/insert.go index b5a0da82..e98a4e12 100644 --- a/insert.go +++ b/insert.go @@ -88,7 +88,7 @@ func InsertInto(table string) *InsertStmt { func (sess *Session) InsertInto(table string) *InsertStmt { b := InsertInto(table) b.runner = sess - b.EventReceiver = sess + b.EventReceiver = sess.EventReceiver b.Dialect = sess.Dialect return b } @@ -97,7 +97,7 @@ func (sess *Session) InsertInto(table string) *InsertStmt { func (tx *Tx) InsertInto(table string) *InsertStmt { b := InsertInto(table) b.runner = tx - b.EventReceiver = tx + b.EventReceiver = tx.EventReceiver b.Dialect = tx.Dialect return b } @@ -116,7 +116,7 @@ func InsertBySql(query string, value ...interface{}) *InsertStmt { func (sess *Session) InsertBySql(query string, value ...interface{}) *InsertStmt { b := InsertBySql(query, value...) b.runner = sess - b.EventReceiver = sess + b.EventReceiver = sess.EventReceiver b.Dialect = sess.Dialect return b } @@ -125,7 +125,7 @@ func (sess *Session) InsertBySql(query string, value ...interface{}) *InsertStmt func (tx *Tx) InsertBySql(query string, value ...interface{}) *InsertStmt { b := InsertBySql(query, value...) b.runner = tx - b.EventReceiver = tx + b.EventReceiver = tx.EventReceiver b.Dialect = tx.Dialect return b } diff --git a/insert_test.go b/insert_test.go index ac363e24..24590f0d 100644 --- a/insert_test.go +++ b/insert_test.go @@ -33,6 +33,13 @@ func TestPostgresReturning(t *testing.T) { Returning("id").Load(&person.Id) require.NoError(t, err) require.True(t, person.Id > 0) + require.Len(t, sess.EventReceiver.(*testTraceReceiver).started, 1) + require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].eventName, "dbr.select") + require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "INSERT") + require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "dbr_people") + require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "name") + require.Equal(t, 1, sess.EventReceiver.(*testTraceReceiver).finished) + require.Equal(t, 0, sess.EventReceiver.(*testTraceReceiver).errored) } func BenchmarkInsertValuesSQL(b *testing.B) { diff --git a/select.go b/select.go index 8e051642..da7b55e1 100644 --- a/select.go +++ b/select.go @@ -155,7 +155,7 @@ func prepareSelect(a []string) []interface{} { func (sess *Session) Select(column ...string) *SelectStmt { b := Select(prepareSelect(column)...) b.runner = sess - b.EventReceiver = sess + b.EventReceiver = sess.EventReceiver b.Dialect = sess.Dialect return b } @@ -164,7 +164,7 @@ func (sess *Session) Select(column ...string) *SelectStmt { func (tx *Tx) Select(column ...string) *SelectStmt { b := Select(prepareSelect(column)...) b.runner = tx - b.EventReceiver = tx + b.EventReceiver = tx.EventReceiver b.Dialect = tx.Dialect return b } @@ -185,7 +185,7 @@ func SelectBySql(query string, value ...interface{}) *SelectStmt { func (sess *Session) SelectBySql(query string, value ...interface{}) *SelectStmt { b := SelectBySql(query, value...) b.runner = sess - b.EventReceiver = sess + b.EventReceiver = sess.EventReceiver b.Dialect = sess.Dialect return b } @@ -194,7 +194,7 @@ func (sess *Session) SelectBySql(query string, value ...interface{}) *SelectStmt func (tx *Tx) SelectBySql(query string, value ...interface{}) *SelectStmt { b := SelectBySql(query, value...) b.runner = tx - b.EventReceiver = tx + b.EventReceiver = tx.EventReceiver b.Dialect = tx.Dialect return b } diff --git a/transaction.go b/transaction.go index e4f8f2ea..2c150b4a 100644 --- a/transaction.go +++ b/transaction.go @@ -28,7 +28,7 @@ func (sess *Session) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, err sess.Event("dbr.begin") return &Tx{ - EventReceiver: sess, + EventReceiver: sess.EventReceiver, Dialect: sess.Dialect, Tx: tx, Timeout: sess.GetTimeout(), diff --git a/transaction_test.go b/transaction_test.go index b2819681..9f38c1ac 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -18,6 +18,13 @@ func TestTransactionCommit(t *testing.T) { result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Exec() require.NoError(t, err) + require.Len(t, sess.EventReceiver.(*testTraceReceiver).started, 1) + require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].eventName, "dbr.exec") + require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "INSERT") + require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "dbr_people") + require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "name") + require.Equal(t, 1, sess.EventReceiver.(*testTraceReceiver).finished) + require.Equal(t, 0, sess.EventReceiver.(*testTraceReceiver).errored) rowsAffected, err := result.RowsAffected() require.NoError(t, err) @@ -29,6 +36,7 @@ func TestTransactionCommit(t *testing.T) { var person dbrPerson err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadOne(&person) require.Error(t, err) + require.Equal(t, 1, sess.EventReceiver.(*testTraceReceiver).errored) } } diff --git a/update.go b/update.go index 81dd21d5..1b983a25 100644 --- a/update.go +++ b/update.go @@ -81,7 +81,7 @@ func Update(table string) *UpdateStmt { func (sess *Session) Update(table string) *UpdateStmt { b := Update(table) b.runner = sess - b.EventReceiver = sess + b.EventReceiver = sess.EventReceiver b.Dialect = sess.Dialect return b } @@ -90,7 +90,7 @@ func (sess *Session) Update(table string) *UpdateStmt { func (tx *Tx) Update(table string) *UpdateStmt { b := Update(table) b.runner = tx - b.EventReceiver = tx + b.EventReceiver = tx.EventReceiver b.Dialect = tx.Dialect return b } @@ -111,7 +111,7 @@ func UpdateBySql(query string, value ...interface{}) *UpdateStmt { func (sess *Session) UpdateBySql(query string, value ...interface{}) *UpdateStmt { b := UpdateBySql(query, value...) b.runner = sess - b.EventReceiver = sess + b.EventReceiver = sess.EventReceiver b.Dialect = sess.Dialect return b } @@ -120,7 +120,7 @@ func (sess *Session) UpdateBySql(query string, value ...interface{}) *UpdateStmt func (tx *Tx) UpdateBySql(query string, value ...interface{}) *UpdateStmt { b := UpdateBySql(query, value...) b.runner = tx - b.EventReceiver = tx + b.EventReceiver = tx.EventReceiver b.Dialect = tx.Dialect return b }