From 1d5d6ea2708e5b1339751d8a4c996279694d43d6 Mon Sep 17 00:00:00 2001 From: Andrew Wilkins Date: Tue, 7 Aug 2018 11:20:43 +0800 Subject: [PATCH 1/3] module/apmsql: fix Stmt.NameValueChecker Fix Stmt.NameValueChecker so that it defers to Conn.NameValueChecker if the underlying statement does not implement the interface. The database/sql code will only call the conn method if the stmt doesn't implement the interface. Also, expose QuerySignature. --- module/apmsql/conn.go | 6 ++++++ module/apmsql/driver.go | 2 +- module/apmsql/gofuzz_signature.go | 10 +++++++--- module/apmsql/signature.go | 4 ++-- module/apmsql/signature_test.go | 8 +++++--- module/apmsql/stmt.go | 17 ++++++++++++----- module/apmsql/stmt_go19.go | 20 -------------------- module/apmsql/utils.go | 13 +++++++++++++ 8 files changed, 46 insertions(+), 34 deletions(-) delete mode 100644 module/apmsql/stmt_go19.go diff --git a/module/apmsql/conn.go b/module/apmsql/conn.go index 0dc8a97a9..c0d79fffa 100644 --- a/module/apmsql/conn.go +++ b/module/apmsql/conn.go @@ -11,6 +11,7 @@ import ( func newConn(in driver.Conn, d *tracingDriver, dsnInfo DSNInfo) driver.Conn { conn := &conn{Conn: in, driver: d} conn.dsnInfo = dsnInfo + conn.namedValueChecker, _ = in.(namedValueChecker) conn.pinger, _ = in.(driver.Pinger) conn.queryer, _ = in.(driver.Queryer) conn.queryerContext, _ = in.(driver.QueryerContext) @@ -31,6 +32,7 @@ type conn struct { driver *tracingDriver dsnInfo DSNInfo + namedValueChecker namedValueChecker pinger driver.Pinger queryer driver.Queryer queryerContext driver.QueryerContext @@ -157,6 +159,10 @@ func (*conn) Exec(query string, args []driver.Value) (driver.Result, error) { return nil, errors.New("Exec should never be called") } +func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { + return checkNamedValue(nv, c.namedValueChecker) +} + type connBeginTx struct { *conn connBeginTx driver.ConnBeginTx diff --git a/module/apmsql/driver.go b/module/apmsql/driver.go index 4d9fb58fb..18dcff626 100644 --- a/module/apmsql/driver.go +++ b/module/apmsql/driver.go @@ -97,7 +97,7 @@ func (d *tracingDriver) formatSpanType(suffix string) string { // querySignature returns the value to use in Span.Name for // a database query. func (d *tracingDriver) querySignature(query string) string { - return genericQuerySignature(query) + return QuerySignature(query) } func (d *tracingDriver) Open(name string) (driver.Conn, error) { diff --git a/module/apmsql/gofuzz_signature.go b/module/apmsql/gofuzz_signature.go index e5a602569..6fdd2afcf 100644 --- a/module/apmsql/gofuzz_signature.go +++ b/module/apmsql/gofuzz_signature.go @@ -1,12 +1,16 @@ // +build gofuzz -package apmsql +package apmsql_test -import "strings" +import ( + "strings" + + "github.com/elastic/apm-agent-go/module/apmsql" +) func Fuzz(data []byte) int { sql := string(data) - sig := genericQuerySignature(sql) + sig := apmsql.QuerySignature(sql) if sig == "" { return -1 } diff --git a/module/apmsql/signature.go b/module/apmsql/signature.go index f3889671e..da4ca171f 100644 --- a/module/apmsql/signature.go +++ b/module/apmsql/signature.go @@ -6,7 +6,7 @@ import ( "github.com/elastic/apm-agent-go/internal/sqlscanner" ) -// genericQuerySignature returns the "signature" for a query: +// QuerySignature returns the "signature" for a query: // a high level description of the operation. // // For DDL statements (CREATE, DROP, ALTER, etc.), we we only @@ -15,7 +15,7 @@ import ( // an application. For SELECT, INSERT, and UPDATE, and DELETE, // we attempt to extract the first table name. If we are unable // to identify the table name, we simply omit it. -func genericQuerySignature(query string) string { +func QuerySignature(query string) string { s := sqlscanner.NewScanner(query) for s.Scan() { if s.Token() != sqlscanner.COMMENT { diff --git a/module/apmsql/signature_test.go b/module/apmsql/signature_test.go index 87e58cc71..9bc0918b2 100644 --- a/module/apmsql/signature_test.go +++ b/module/apmsql/signature_test.go @@ -1,14 +1,16 @@ -package apmsql +package apmsql_test import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/elastic/apm-agent-go/module/apmsql" ) func TestQuerySignature(t *testing.T) { assertSignatureEqual := func(expect, stmt string) { - out := genericQuerySignature(stmt) + out := apmsql.QuerySignature(stmt) assert.Equal(t, expect, out, "%s", stmt) } @@ -55,7 +57,7 @@ func TestQuerySignature(t *testing.T) { func BenchmarkQuerySignature(b *testing.B) { sql := "SELECT *,(SELECT COUNT(*) FROM table2 WHERE table2.field1 = table1.id) AS count FROM table1 WHERE table1.field1 = 'value'" for i := 0; i < b.N; i++ { - signature := genericQuerySignature(sql) + signature := apmsql.QuerySignature(sql) if signature != "SELECT FROM table1" { panic("unexpected result: " + signature) } diff --git a/module/apmsql/stmt.go b/module/apmsql/stmt.go index 5b91ff7e6..48875fe95 100644 --- a/module/apmsql/stmt.go +++ b/module/apmsql/stmt.go @@ -17,20 +17,23 @@ func newStmt(in driver.Stmt, conn *conn, query string) driver.Stmt { stmt.columnConverter, _ = in.(driver.ColumnConverter) stmt.stmtExecContext, _ = in.(driver.StmtExecContext) stmt.stmtQueryContext, _ = in.(driver.StmtQueryContext) - stmt.stmtGo19.init(in) + stmt.namedValueChecker, _ = in.(namedValueChecker) + if stmt.namedValueChecker == nil { + stmt.namedValueChecker = conn.namedValueChecker + } return stmt } type stmt struct { driver.Stmt - stmtGo19 conn *conn signature string query string - columnConverter driver.ColumnConverter - stmtExecContext driver.StmtExecContext - stmtQueryContext driver.StmtQueryContext + columnConverter driver.ColumnConverter + namedValueChecker namedValueChecker + stmtExecContext driver.StmtExecContext + stmtQueryContext driver.StmtQueryContext } func (s *stmt) startSpan(ctx context.Context, spanType string) (*elasticapm.Span, context.Context) { @@ -79,3 +82,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (_ dr } return s.Query(dargs) } + +func (s *stmt) CheckNamedValue(nv *driver.NamedValue) error { + return checkNamedValue(nv, s.namedValueChecker) +} diff --git a/module/apmsql/stmt_go19.go b/module/apmsql/stmt_go19.go deleted file mode 100644 index a9189ba11..000000000 --- a/module/apmsql/stmt_go19.go +++ /dev/null @@ -1,20 +0,0 @@ -// +build go1.9 - -package apmsql - -import "database/sql/driver" - -type stmtGo19 struct { - namedValueChecker driver.NamedValueChecker -} - -func (s *stmtGo19) init(in driver.Stmt) { - s.namedValueChecker, _ = in.(driver.NamedValueChecker) -} - -func (s *stmt) CheckNamedValue(nv *driver.NamedValue) error { - if s.namedValueChecker != nil { - return s.namedValueChecker.CheckNamedValue(nv) - } - return driver.ErrSkip -} diff --git a/module/apmsql/utils.go b/module/apmsql/utils.go index 87bbe3f1f..4c430bc7a 100644 --- a/module/apmsql/utils.go +++ b/module/apmsql/utils.go @@ -16,3 +16,16 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { } return dargs, nil } + +// namedValueChecker is identical to driver.NamedValueChecker, existing +// for compatibility with Go 1.8. +type namedValueChecker interface { + CheckNamedValue(*driver.NamedValue) error +} + +func checkNamedValue(nv *driver.NamedValue, next namedValueChecker) error { + if next != nil { + return next.CheckNamedValue(nv) + } + return driver.ErrSkip +} From 15f0117d15e9e051a385480654830008aed21374 Mon Sep 17 00:00:00 2001 From: Andrew Wilkins Date: Tue, 7 Aug 2018 14:27:16 +0800 Subject: [PATCH 2/3] module/apmsql: expose utilities for apmgorm Move some internal utilities into internal/sqlutil. Expose a new function, apmsql.DriverDSNParser, which returns the DSNParserFunc recorded for a previously registered driver, or the "generic" parser. These bits will be used by a GORM module. --- internal/sqlutil/doc.go | 2 + internal/sqlutil/drivername.go | 32 ++++++++++ .../apmsql => internal/sqlutil}/signature.go | 2 +- .../sqlutil}/signature_test.go | 8 +-- module/apmsql/driver.go | 60 ++++++++++--------- 5 files changed, 72 insertions(+), 32 deletions(-) create mode 100644 internal/sqlutil/doc.go create mode 100644 internal/sqlutil/drivername.go rename {module/apmsql => internal/sqlutil}/signature.go (99%) rename {module/apmsql => internal/sqlutil}/signature_test.go (94%) diff --git a/internal/sqlutil/doc.go b/internal/sqlutil/doc.go new file mode 100644 index 000000000..a9d76b061 --- /dev/null +++ b/internal/sqlutil/doc.go @@ -0,0 +1,2 @@ +// Package sqlutil provides utilities to SQL-related instrumentation modules. +package sqlutil diff --git a/internal/sqlutil/drivername.go b/internal/sqlutil/drivername.go new file mode 100644 index 000000000..753286e2e --- /dev/null +++ b/internal/sqlutil/drivername.go @@ -0,0 +1,32 @@ +package sqlutil + +import ( + "database/sql/driver" + "reflect" + "strings" +) + +// DriverName returns the name of the driver, based on its type. +// If the driver name cannot be deduced, DriverName will return +// "generic". +func DriverName(d driver.Driver) string { + t := reflect.TypeOf(d) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch t.Name() { + case "SQLiteDriver": + return "sqlite3" + case "MySQLDriver": + return "mysql" + case "Driver": + // Check suffix in case of vendoring. + if strings.HasSuffix(t.PkgPath(), "github.com/lib/pq") { + return "postgresql" + } + } + // TODO include the package path of the driver in context + // so we can easily determine how the rules above should + // be updated. + return "generic" +} diff --git a/module/apmsql/signature.go b/internal/sqlutil/signature.go similarity index 99% rename from module/apmsql/signature.go rename to internal/sqlutil/signature.go index da4ca171f..4a2e33e53 100644 --- a/module/apmsql/signature.go +++ b/internal/sqlutil/signature.go @@ -1,4 +1,4 @@ -package apmsql +package sqlutil import ( "strings" diff --git a/module/apmsql/signature_test.go b/internal/sqlutil/signature_test.go similarity index 94% rename from module/apmsql/signature_test.go rename to internal/sqlutil/signature_test.go index 9bc0918b2..fd3ce83fb 100644 --- a/module/apmsql/signature_test.go +++ b/internal/sqlutil/signature_test.go @@ -1,16 +1,16 @@ -package apmsql_test +package sqlutil_test import ( "testing" "github.com/stretchr/testify/assert" - "github.com/elastic/apm-agent-go/module/apmsql" + "github.com/elastic/apm-agent-go/internal/sqlutil" ) func TestQuerySignature(t *testing.T) { assertSignatureEqual := func(expect, stmt string) { - out := apmsql.QuerySignature(stmt) + out := sqlutil.QuerySignature(stmt) assert.Equal(t, expect, out, "%s", stmt) } @@ -57,7 +57,7 @@ func TestQuerySignature(t *testing.T) { func BenchmarkQuerySignature(b *testing.B) { sql := "SELECT *,(SELECT COUNT(*) FROM table2 WHERE table2.field1 = table1.id) AS count FROM table1 WHERE table1.field1 = 'value'" for i := 0; i < b.N; i++ { - signature := apmsql.QuerySignature(sql) + signature := sqlutil.QuerySignature(sql) if signature != "SELECT FROM table1" { panic("unexpected result: " + signature) } diff --git a/module/apmsql/driver.go b/module/apmsql/driver.go index 18dcff626..0839b1bae 100644 --- a/module/apmsql/driver.go +++ b/module/apmsql/driver.go @@ -4,22 +4,32 @@ import ( "database/sql" "database/sql/driver" "fmt" - "reflect" - "strings" + "sync" + + "github.com/elastic/apm-agent-go/internal/sqlutil" ) // DriverPrefix should be used as a driver name prefix when // registering via sql.Register. const DriverPrefix = "elasticapm/" +var ( + driversMu sync.RWMutex + drivers = make(map[string]*tracingDriver) +) + // Register registers a traced version of the given driver. // // The name and driver values should be the same as given to // sql.Register: the name of the driver (e.g. "postgres"), and // the driver (e.g. &github.com/lib/pq.Driver{}). func Register(name string, driver driver.Driver, opts ...WrapOption) { - wrapped := Wrap(driver, opts...) + driversMu.Lock() + defer driversMu.Unlock() + + wrapped := newTracingDriver(driver, opts...) sql.Register(DriverPrefix+name, wrapped) + drivers[name] = wrapped } // Open opens a database with the given driver and data source names, @@ -34,6 +44,10 @@ func Open(driverName, dataSourceName string) (*sql.DB, error) { // will be obtained from the context supplied to methods // that accept it. func Wrap(driver driver.Driver, opts ...WrapOption) driver.Driver { + return newTracingDriver(driver, opts...) +} + +func newTracingDriver(driver driver.Driver, opts ...WrapOption) *tracingDriver { d := &tracingDriver{ Driver: driver, } @@ -41,7 +55,7 @@ func Wrap(driver driver.Driver, opts ...WrapOption) driver.Driver { opt(d) } if d.driverName == "" { - d.driverName = driverName(driver) + d.driverName = sqlutil.DriverName(driver) } if d.dsnParser == nil { d.dsnParser = genericDSNParser @@ -55,6 +69,20 @@ func Wrap(driver driver.Driver, opts ...WrapOption) driver.Driver { return d } +// DriverDSNParser returns the DSNParserFunc for the registered driver. +// If there is no such registered driver, the parser function that is +// returned will return empty DSNInfo structures. +func DriverDSNParser(driverName string) DSNParserFunc { + driversMu.RLock() + driver := drivers[driverName] + defer driversMu.RUnlock() + + if driver == nil { + return genericDSNParser + } + return driver.dsnParser +} + // WrapOption is an option that can be supplied to Wrap. type WrapOption func(*tracingDriver) @@ -97,7 +125,7 @@ func (d *tracingDriver) formatSpanType(suffix string) string { // querySignature returns the value to use in Span.Name for // a database query. func (d *tracingDriver) querySignature(query string) string { - return QuerySignature(query) + return sqlutil.QuerySignature(query) } func (d *tracingDriver) Open(name string) (driver.Conn, error) { @@ -107,25 +135,3 @@ func (d *tracingDriver) Open(name string) (driver.Conn, error) { } return newConn(conn, d, d.dsnParser(name)), nil } - -func driverName(d driver.Driver) string { - t := reflect.TypeOf(d) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - switch t.Name() { - case "SQLiteDriver": - return "sqlite3" - case "MySQLDriver": - return "mysql" - case "Driver": - // Check suffix in case of vendoring. - if strings.HasSuffix(t.PkgPath(), "github.com/lib/pq") { - return "postgresql" - } - } - // TODO include the package path of the driver in context - // so we can easily determine how the rules above should - // be updated. - return "generic" -} From 4469614cabd16b419d6cc89fd21de9d3a4a34292 Mon Sep 17 00:00:00 2001 From: Andrew Wilkins Date: Tue, 7 Aug 2018 14:49:05 +0800 Subject: [PATCH 3/3] module/apmgorm: introduce GORM instrumentation Package apmgorm provides a means of instrumenting [gorm](http://gorm.io) database operations. By using apmgorm.Open instead of gorm.Open, we obtain a *gorm.DB which can be used with apmgorm.WithContext. The WithContext function can be used to propagate context (i.e. containing a transaction) to callbacks which report spans. We provide "dialects" packages which import the gorm/dialects namesake packages, as well as register the apmsql drivers. The latter is required in order to parse DSNs. --- CHANGELOG.md | 1 + docs/instrumenting.asciidoc | 25 ++++ module/apmgorm/apmgorm_test.go | 182 +++++++++++++++++++++++ module/apmgorm/context.go | 130 ++++++++++++++++ module/apmgorm/dialects/mysql/init.go | 9 ++ module/apmgorm/dialects/postgres/init.go | 9 ++ module/apmgorm/dialects/sqlite/init.go | 9 ++ module/apmgorm/doc.go | 2 + module/apmgorm/open.go | 43 ++++++ scripts/Dockerfile-testing | 4 + 10 files changed, 414 insertions(+) create mode 100644 module/apmgorm/apmgorm_test.go create mode 100644 module/apmgorm/context.go create mode 100644 module/apmgorm/dialects/mysql/init.go create mode 100644 module/apmgorm/dialects/postgres/init.go create mode 100644 module/apmgorm/dialects/sqlite/init.go create mode 100644 module/apmgorm/doc.go create mode 100644 module/apmgorm/open.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 04fec7144..b7e40c409 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - Add `ELASTIC_APM_IGNORE_URLS` config (#158) - module/apmsql: fix a bug preventing errors from being captured (#160) - Introduce `Tracer.StartTransactionOptions`, drop variadic args from `Tracer.StartTransaction` (#165) + - module/apmgorm: introduce GORM instrumentation module (#169, #170) ## [v0.4.0](https://github.com/elastic/apm-agent-go/releases/tag/v0.4.0) diff --git a/docs/instrumenting.asciidoc b/docs/instrumenting.asciidoc index 9843c5cd4..c05f4ba90 100644 --- a/docs/instrumenting.asciidoc +++ b/docs/instrumenting.asciidoc @@ -264,6 +264,31 @@ func main() { Spans will be created for queries and other statement executions if the context methods are used, and the context includes a transaction. +===== module/apmgorm +Package apmgorm provides a means of instrumenting [gorm](http://gorm.io) database operations. + +To trace GORM operations, import the appropriate `apmgorm/dialects` package (instead of the +`gorm/dialects` package), and use `apmgorm.Open` (instead of `gorm.Open`). The parameters are +exactly the same. + +Once you have a `*gorm.DB` from `apmgorm.Open`, you can call `apmgorm.WithContext` to +propagate a context containing a transaction to the operations: + +[source,go] +---- +import ( + "github.com/elastic/apm-agent-go/module/apmgorm" + _ "github.com/elastic/apm-agent-go/module/apmgorm/dialects/postgres" +) + +func main() { + db, err := apmgorm.Open("postgres", "") + ... + db = apmgorm.WithContext(ctx, db) + db.Find(...) // creates a "SELECT FROM " span +} +---- + ===== module/apmgocql Package apmgocql provides a means of instrumenting https://github.com/gocql/gocql[gocql] so that queries are reported as spans within the current transaction. diff --git a/module/apmgorm/apmgorm_test.go b/module/apmgorm/apmgorm_test.go new file mode 100644 index 000000000..99413f5d1 --- /dev/null +++ b/module/apmgorm/apmgorm_test.go @@ -0,0 +1,182 @@ +package apmgorm_test + +import ( + "context" + "database/sql" + "os" + "testing" + + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/apm-agent-go/apmtest" + "github.com/elastic/apm-agent-go/module/apmgorm" + _ "github.com/elastic/apm-agent-go/module/apmgorm/dialects/mysql" + _ "github.com/elastic/apm-agent-go/module/apmgorm/dialects/postgres" + _ "github.com/elastic/apm-agent-go/module/apmgorm/dialects/sqlite" + "github.com/elastic/apm-agent-go/module/apmsql" +) + +type Product struct { + gorm.Model + Code string + Price uint +} + +func TestWithContext(t *testing.T) { + t.Run("sqlite3", func(t *testing.T) { + testWithContext(t, + apmsql.DSNInfo{Database: ":memory:"}, + "sqlite3", ":memory:", + ) + }) + + if os.Getenv("PGHOST") == "" { + t.Logf("PGHOST not specified, skipping") + } else { + t.Run("postgres", func(t *testing.T) { + testWithContext(t, + apmsql.DSNInfo{Database: "test_db", User: "postgres"}, + "postgres", "user=postgres password=hunter2 dbname=test_db sslmode=disable", + ) + }) + } + + if mysqlHost := os.Getenv("MYSQL_HOST"); mysqlHost == "" { + t.Logf("MYSQL_HOST not specified, skipping") + } else { + t.Run("mysql", func(t *testing.T) { + testWithContext(t, + apmsql.DSNInfo{Database: "test_db", User: "root"}, + "mysql", "root:hunter2@tcp("+mysqlHost+")/test_db?parseTime=true", + ) + }) + } +} + +func testWithContext(t *testing.T, dsnInfo apmsql.DSNInfo, dialect string, args ...interface{}) { + tx, errors := apmtest.WithTransaction(func(ctx context.Context) { + db, err := apmgorm.Open(dialect, args...) + require.NoError(t, err) + defer db.Close() + db = apmgorm.WithContext(ctx, db) + + db.AutoMigrate(&Product{}) + db.Create(&Product{Code: "L1212", Price: 1000}) + + var product Product + assert.NoError(t, db.First(&product, "code = ?", "L1212").Error) + assert.NoError(t, db.Model(&product).Update("Price", 2000).Error) + assert.NoError(t, db.Delete(&product).Error) // soft + assert.NoError(t, db.Unscoped().Delete(&product).Error) // hard + }) + require.NotEmpty(t, tx.Spans) + assert.Empty(t, errors) + + spanNames := make([]string, len(tx.Spans)) + for i, span := range tx.Spans { + spanNames[i] = span.Name + require.NotNil(t, span.Context) + require.NotNil(t, span.Context.Database) + assert.Equal(t, dsnInfo.Database, span.Context.Database.Instance) + assert.NotEmpty(t, span.Context.Database.Statement) + assert.Equal(t, "sql", span.Context.Database.Type) + assert.Equal(t, dsnInfo.User, span.Context.Database.User) + } + assert.Equal(t, []string{ + "INSERT INTO products", + "SELECT FROM products", + "UPDATE products", + "UPDATE products", // soft delete + "DELETE FROM products", + }, spanNames) +} + +// TestWithContextNoTransaction checks that using WithContext without +// a transaction won't cause any issues. +func TestWithContextNoTransaction(t *testing.T) { + db, err := apmgorm.Open("sqlite3", ":memory:") + require.NoError(t, err) + defer db.Close() + db = apmgorm.WithContext(context.Background(), db) + + db.AutoMigrate(&Product{}) + db.Create(&Product{Code: "L1212", Price: 1000}) + + var product Product + assert.NoError(t, db.Where("code=?", "L1212").First(&product).Error) +} + +func TestWithContextNonSampled(t *testing.T) { + os.Setenv("ELASTIC_APM_TRANSACTION_SAMPLE_RATE", "0") + defer os.Unsetenv("ELASTIC_APM_TRANSACTION_SAMPLE_RATE") + + db, err := apmgorm.Open("sqlite3", ":memory:") + require.NoError(t, err) + defer db.Close() + db.AutoMigrate(&Product{}) + + tx, _ := apmtest.WithTransaction(func(ctx context.Context) { + db = apmgorm.WithContext(ctx, db) + db.Create(&Product{Code: "L1212", Price: 1000}) + }) + require.Empty(t, tx.Spans) +} + +func TestCaptureErrors(t *testing.T) { + db, err := apmgorm.Open("sqlite3", ":memory:") + require.NoError(t, err) + defer db.Close() + db.SetLogger(nopLogger{}) + db.AutoMigrate(&Product{}) + + tx, errors := apmtest.WithTransaction(func(ctx context.Context) { + db = apmgorm.WithContext(ctx, db) + + // record not found should not cause an error + db.Where("code=?", "L1212").First(&Product{}) + + // invalid SQL should + db.Where("bananas").First(&Product{}) + }) + assert.Len(t, tx.Spans, 2) + require.Len(t, errors, 1) + assert.Regexp(t, "no such column: bananas", errors[0].Exception.Message) +} + +func TestOpenWithDriver(t *testing.T) { + db, err := apmgorm.Open("sqlite3", "sqlite3", ":memory:") + require.NoError(t, err) + defer db.Close() + db.AutoMigrate(&Product{}) + + tx, _ := apmtest.WithTransaction(func(ctx context.Context) { + db = apmgorm.WithContext(ctx, db) + db.Create(&Product{Code: "L1212", Price: 1000}) + }) + require.Len(t, tx.Spans, 1) + assert.Equal(t, ":memory:", tx.Spans[0].Context.Database.Instance) +} + +func TestOpenWithDB(t *testing.T) { + sqldb, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + defer sqldb.Close() + + db, err := apmgorm.Open("sqlite3", sqldb) + require.NoError(t, err) + defer db.Close() + db.AutoMigrate(&Product{}) + + tx, _ := apmtest.WithTransaction(func(ctx context.Context) { + db = apmgorm.WithContext(ctx, db) + db.Create(&Product{Code: "L1212", Price: 1000}) + }) + require.Len(t, tx.Spans, 1) + assert.Empty(t, tx.Spans[0].Context.Database.Instance) // no DSN info +} + +type nopLogger struct{} + +func (nopLogger) Print(v ...interface{}) {} diff --git a/module/apmgorm/context.go b/module/apmgorm/context.go new file mode 100644 index 000000000..8dee7af66 --- /dev/null +++ b/module/apmgorm/context.go @@ -0,0 +1,130 @@ +package apmgorm + +import ( + "context" + "fmt" + + "github.com/jinzhu/gorm" + + "github.com/elastic/apm-agent-go" + "github.com/elastic/apm-agent-go/internal/sqlutil" + "github.com/elastic/apm-agent-go/module/apmsql" +) + +const ( + apmContextKey = "elasticapm:context" +) + +// WithContext returns a copy of db with ctx recorded for use by +// the callbacks registered via RegisterCallbacks. +func WithContext(ctx context.Context, db *gorm.DB) *gorm.DB { + return db.Set(apmContextKey, ctx) +} + +func scopeContext(scope *gorm.Scope) (context.Context, bool) { + value, ok := scope.Get(apmContextKey) + if !ok { + return nil, false + } + ctx, _ := value.(context.Context) + return ctx, ctx != nil +} + +// RegisterCallbacks registers callbacks on db for reporting spans +// to Elastic APM. This is called automatically by apmgorm.Open; +// it is provided for cases where a *gorm.DB is acquired by other +// means. +func RegisterCallbacks(db *gorm.DB) { + registerCallbacks(db, apmsql.DSNInfo{}) +} + +func registerCallbacks(db *gorm.DB, dsnInfo apmsql.DSNInfo) { + driverName := db.Dialect().GetName() + switch driverName { + case "postgres": + driverName = "postgresql" + } + spanTypePrefix := fmt.Sprintf("db.%s.", driverName) + querySpanType := spanTypePrefix + "query" + execSpanType := spanTypePrefix + "exec" + + type params struct { + spanType string + processor func() *gorm.CallbackProcessor + } + callbacks := map[string]params{ + "gorm:create": { + spanType: execSpanType, + processor: func() *gorm.CallbackProcessor { return db.Callback().Create() }, + }, + "gorm:delete": { + spanType: execSpanType, + processor: func() *gorm.CallbackProcessor { return db.Callback().Delete() }, + }, + "gorm:query": { + spanType: querySpanType, + processor: func() *gorm.CallbackProcessor { return db.Callback().Query() }, + }, + "gorm:update": { + spanType: execSpanType, + processor: func() *gorm.CallbackProcessor { return db.Callback().Update() }, + }, + } + for name, params := range callbacks { + const callbackPrefix = "elasticapm" + params.processor().Before(name).Register( + fmt.Sprintf("%s:before:%s", callbackPrefix, name), + newBeforeCallback(params.spanType), + ) + params.processor().After(name).Register( + fmt.Sprintf("%s:after:%s", callbackPrefix, name), + newAfterCallback(dsnInfo), + ) + } +} + +func newBeforeCallback(spanType string) func(*gorm.Scope) { + return func(scope *gorm.Scope) { + ctx, ok := scopeContext(scope) + if !ok { + return + } + span, ctx := elasticapm.StartSpan(ctx, "", spanType) + if span.Dropped() { + span.End() + ctx = nil + } + scope.Set(apmContextKey, ctx) + } +} + +func newAfterCallback(dsnInfo apmsql.DSNInfo) func(*gorm.Scope) { + return func(scope *gorm.Scope) { + ctx, ok := scopeContext(scope) + if !ok { + return + } + span := elasticapm.SpanFromContext(ctx) + if span == nil { + return + } + span.Name = sqlutil.QuerySignature(scope.SQL) + span.Context.SetDatabase(elasticapm.DatabaseSpanContext{ + Instance: dsnInfo.Database, + Statement: scope.SQL, + Type: "sql", + User: dsnInfo.User, + }) + span.End() + + // Capture errors, except for "record not found", which may be expected. + for _, err := range scope.DB().GetErrors() { + if gorm.IsRecordNotFoundError(err) { + continue + } + if e := elasticapm.CaptureError(ctx, err); e != nil { + e.Send() + } + } + } +} diff --git a/module/apmgorm/dialects/mysql/init.go b/module/apmgorm/dialects/mysql/init.go new file mode 100644 index 000000000..b9279eb38 --- /dev/null +++ b/module/apmgorm/dialects/mysql/init.go @@ -0,0 +1,9 @@ +// Package apmgormmysql imports the gorm mysql dialect package, +// and also registers the mysql driver with apmsql. +package apmgormmysql + +import ( + _ "github.com/jinzhu/gorm/dialects/mysql" // import the mysql dialect + + _ "github.com/elastic/apm-agent-go/module/apmsql/mysql" // register mysql with apmsql +) diff --git a/module/apmgorm/dialects/postgres/init.go b/module/apmgorm/dialects/postgres/init.go new file mode 100644 index 000000000..f596a9557 --- /dev/null +++ b/module/apmgorm/dialects/postgres/init.go @@ -0,0 +1,9 @@ +// Package apmgormpostgres imports the gorm postgres dialect package, +// and also registers the lib/pq driver with apmsql. +package apmgormpostgres + +import ( + _ "github.com/jinzhu/gorm/dialects/postgres" // import the postgres dialect + + _ "github.com/elastic/apm-agent-go/module/apmsql/pq" // register lib/pq with apmsql +) diff --git a/module/apmgorm/dialects/sqlite/init.go b/module/apmgorm/dialects/sqlite/init.go new file mode 100644 index 000000000..310861d4c --- /dev/null +++ b/module/apmgorm/dialects/sqlite/init.go @@ -0,0 +1,9 @@ +// Package apmgormsqlite imports the gorm sqlite dialect package, +// and also registers the sqlite3 driver with apmsql. +package apmgormsqlite + +import ( + _ "github.com/jinzhu/gorm/dialects/sqlite" // import the sqlite dialect + + _ "github.com/elastic/apm-agent-go/module/apmsql/sqlite3" // register sqlite3 with apmsql +) diff --git a/module/apmgorm/doc.go b/module/apmgorm/doc.go new file mode 100644 index 000000000..9203be4ee --- /dev/null +++ b/module/apmgorm/doc.go @@ -0,0 +1,2 @@ +// Package apmgorm provides wrappers for tracing GORM operations. +package apmgorm diff --git a/module/apmgorm/open.go b/module/apmgorm/open.go new file mode 100644 index 000000000..90fe7c06e --- /dev/null +++ b/module/apmgorm/open.go @@ -0,0 +1,43 @@ +package apmgorm + +import ( + "github.com/jinzhu/gorm" + "github.com/pkg/errors" + + "github.com/elastic/apm-agent-go/module/apmsql" +) + +// Open returns a *gorm.DB for the given dialect and arguments. +// The returned *gorm.DB will have callbacks registered with +// RegisterCallbacks, such that CRUD operations will be reported +// as spans. +// +// Open accepts the following signatures: +// - a datasource name (i.e. the second argument to sql.Open) +// - a driver name and a datasource name +// - a *sql.DB, or some other type with the same interface +// +// If a driver and datasource name are supplied, and the appropriate +// apmgorm/dialects package has been imported (or the driver has +// otherwise been registered with apmsql), then the datasource name +// will be parsed for inclusion in the span context. +func Open(dialect string, args ...interface{}) (*gorm.DB, error) { + var driverName, dsn string + switch len(args) { + case 1: + switch arg0 := args[0].(type) { + case string: + driverName = dialect + dsn = arg0 + } + case 2: + driverName, _ = args[0].(string) + dsn, _ = args[1].(string) + } + db, err := gorm.Open(dialect, args...) + if err != nil { + return nil, errors.WithStack(err) + } + registerCallbacks(db, apmsql.DriverDSNParser(driverName)(dsn)) + return db, nil +} diff --git a/scripts/Dockerfile-testing b/scripts/Dockerfile-testing index 28711a13d..73b6f78c8 100644 --- a/scripts/Dockerfile-testing +++ b/scripts/Dockerfile-testing @@ -10,6 +10,10 @@ RUN go get -v github.com/google/go-cmp/cmp RUN go get -v github.com/gorilla/mux RUN go get -v github.com/grpc-ecosystem/go-grpc-middleware RUN go get -v github.com/grpc-ecosystem/go-grpc-middleware/recovery +RUN go get -v github.com/jinzhu/gorm +RUN go get -v github.com/jinzhu/gorm/dialects/mysql +RUN go get -v github.com/jinzhu/gorm/dialects/postgres +RUN go get -v github.com/jinzhu/gorm/dialects/sqlite RUN go get -v github.com/julienschmidt/httprouter RUN go get -v github.com/labstack/echo RUN go get -v github.com/lib/pq