diff --git a/contrib/database/sql/internal/mockdriver.go b/contrib/database/sql/internal/mockdriver.go index 1b0fe6b14e..fb7d43d39b 100644 --- a/contrib/database/sql/internal/mockdriver.go +++ b/contrib/database/sql/internal/mockdriver.go @@ -15,15 +15,10 @@ import ( type MockDriver struct { Prepared []string Executed []string - // Hook is an optional function to run during a DB operation - Hook func() } // Open implements the Conn interface func (d *MockDriver) Open(_ string) (driver.Conn, error) { - if d.Hook != nil { - d.Hook() - } return &mockConn{driver: d}, nil } @@ -34,27 +29,18 @@ type mockConn struct { // Prepare implements the driver.Conn interface func (m *mockConn) Prepare(query string) (driver.Stmt, error) { m.driver.Prepared = append(m.driver.Prepared, query) - if m.driver.Hook != nil { - m.driver.Hook() - } return &mockStmt{stmt: query, driver: m.driver}, nil } // QueryContext implements the QueryerContext interface func (m *mockConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) { m.driver.Executed = append(m.driver.Executed, query) - if m.driver.Hook != nil { - m.driver.Hook() - } return &rows{}, nil } // ExecContext implements the ExecerContext interface func (m *mockConn) ExecContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Result, error) { m.driver.Executed = append(m.driver.Executed, query) - if m.driver.Hook != nil { - m.driver.Hook() - } return &mockResult{}, nil } @@ -65,9 +51,6 @@ func (m *mockConn) Close() (err error) { // Begin implements the Conn interface func (m *mockConn) Begin() (driver.Tx, error) { - if m.driver.Hook != nil { - m.driver.Hook() - } return &mockTx{driver: m.driver}, nil } @@ -94,9 +77,6 @@ type mockTx struct { // Commit implements the Tx interface func (t *mockTx) Commit() error { - if t.driver.Hook != nil { - t.driver.Hook() - } return nil } @@ -135,18 +115,12 @@ func (s *mockStmt) Query(_ []driver.Value) (driver.Rows, error) { // ExecContext implements the StmtExecContext interface func (s *mockStmt) ExecContext(_ context.Context, _ []driver.NamedValue) (driver.Result, error) { s.driver.Executed = append(s.driver.Executed, s.stmt) - if s.driver.Hook != nil { - s.driver.Hook() - } return &mockResult{}, nil } // QueryContext implements the StmtQueryContext interface func (s *mockStmt) QueryContext(_ context.Context, _ []driver.NamedValue) (driver.Rows, error) { s.driver.Executed = append(s.driver.Executed, s.stmt) - if s.driver.Hook != nil { - s.driver.Hook() - } return &rows{}, nil } diff --git a/internal/exectracetest/go.mod b/internal/exectracetest/go.mod index 4abf8ad591..e5573375e9 100644 --- a/internal/exectracetest/go.mod +++ b/internal/exectracetest/go.mod @@ -6,6 +6,7 @@ toolchain go1.21.0 require ( github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b + github.com/mattn/go-sqlite3 v1.14.18 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 gopkg.in/DataDog/dd-trace-go.v1 v1.64.0 ) diff --git a/internal/exectracetest/go.sum b/internal/exectracetest/go.sum index 77f544a618..1f817550c2 100644 --- a/internal/exectracetest/go.sum +++ b/internal/exectracetest/go.sum @@ -24,6 +24,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.11.0 h1:9rHa233rhdOyrz2GcP9NM+gi2psgJZ4GWDpL/7ND8HI= +github.com/denisenkom/go-mssqldb v0.11.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -36,6 +38,10 @@ github.com/ebitengine/purego v0.6.0-alpha.5 h1:EYID3JOAdmQ4SNZYJHu9V6IqOeRQDBYxq github.com/ebitengine/purego v0.6.0-alpha.5/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -55,8 +61,12 @@ github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= +github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= diff --git a/internal/exectracetest/sql_test.go b/internal/exectracetest/sql_test.go new file mode 100644 index 0000000000..3a9f2a3f8a --- /dev/null +++ b/internal/exectracetest/sql_test.go @@ -0,0 +1,158 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2024 Datadog, Inc. + +package exectracetest + +import ( + "context" + "net/http" + "runtime/trace" + "slices" + "testing" + "time" + + "github.com/mattn/go-sqlite3" + exptrace "golang.org/x/exp/trace" + + sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/httpmem" +) + +func must[T any](val T, err error) T { + if err != nil { + panic(err) + } + return val +} + +func TestExecutionTraceAnnotations(t *testing.T) { + if trace.IsEnabled() { + t.Skip("execution tracing is already enabled") + } + + // In-memory server & client which discards everything, to avoid + // slowness from unnecessary network I/O + s, c := httpmem.ServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer s.Close() + tracer.Start(tracer.WithHTTPClient(c), tracer.WithLogStartup(false)) + defer tracer.Stop() + + // sleepDuration is the amount of time our mock DB operations should + // take. We are going to assert that the execution trace tasks + // corresponding with the duration are at least as long as this. In + // reality they could be longer than this due to slow CI, scheduling + // jitter, etc., but we know that they should be at least this long. + const sleepDuration = 10 * time.Millisecond + sleep := func() int { + time.Sleep(sleepDuration) + return 0 + } + sqltrace.Register("sqlite3_extended", + &sqlite3.SQLiteDriver{ + ConnectHook: func(conn *sqlite3.SQLiteConn) error { + time.Sleep(sleepDuration) + return conn.RegisterFunc("sleep", sleep, true) + }, + }, + ) + + db := must(sqltrace.Open("sqlite3_extended", ":memory:")) + + _, events := collectTestData(t, func() { + span, ctx := tracer.StartSpanFromContext(context.Background(), "parent") + must(db.ExecContext(ctx, "select sleep()")) + conn := must(db.Conn(ctx)) + rows := must(conn.QueryContext(ctx, "select 1")) + rows.Close() + stmt := must(conn.PrepareContext(ctx, "select sleep()")) + must(stmt.Exec()) + rows = must(stmt.Query()) + // NB: the sleep() is only actually evaluated when + // we iterate over the rows, and that part isn't traced, + // so we won't make assertions about the duration of the + // "Query" task later + rows.Close() + tx := must(conn.BeginTx(ctx, nil)) + must(tx.ExecContext(ctx, "select sleep()")) + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + + span.Finish() + }) + + tasks := getTasks(events) + + expectedParentChildTasks := []string{"Connect", "Exec", "Query", "Prepare", "Begin", "Exec"} + var foundParent, foundPrepared bool + for id, task := range tasks { + t.Logf("task %d: %+v", id, task) + switch task.name { + case "parent": + foundParent = true + var got []string + for _, child := range tasks[id].children { + got = append(got, tasks[child].name) + } + if !slices.Equal(expectedParentChildTasks, got) { + t.Errorf( + "did not find expected child tasks of parent: want %s, got %s", + expectedParentChildTasks, got, + ) + } + case "Prepare": + foundPrepared = true + case "Connect", "Exec": + if d := task.Duration(); d < sleepDuration { + t.Errorf("task %s: duration %v less than minimum %v", task.name, d, sleepDuration) + } + } + } + if !foundParent { + t.Error("did not find parent task") + } + if !foundPrepared { + t.Error("did not find prepared statement task") + } +} + +type traceTask struct { + name string + start, end exptrace.Time + parent exptrace.TaskID + children []exptrace.TaskID +} + +func (t *traceTask) Duration() time.Duration { return time.Duration(t.end - t.start) } + +func getTasks(events []exptrace.Event) map[exptrace.TaskID]*traceTask { + tasks := make(map[exptrace.TaskID]*traceTask) + for _, ev := range events { + switch ev.Kind() { + case exptrace.EventTaskBegin: + task := ev.Task() + parent := task.Parent + if t, ok := tasks[parent]; ok { + t.children = append(t.children, task.ID) + } + tasks[task.ID] = &traceTask{ + name: task.Type, + parent: parent, + start: ev.Time(), + } + case exptrace.EventTaskEnd: + task := ev.Task() + if t, ok := tasks[task.ID]; ok { + t.end = ev.Time() + } + default: + } + } + return tasks +}