diff --git a/CHANGELOG.md b/CHANGELOG.md index c276337..a102419 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,14 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased] +### Added + +- Add AllowRoot option to prevent backward incompatible. (#13) + ### Changed - Upgrade to v0.20.0 of `go.opentelemetry.io/otel`. (#8) +- otelsql will not create root spans in absence of existing spans by default. (#13) ## [0.2.1] - 2021-03-28 diff --git a/README.md b/README.md index 9c1a06f..2a87c9e 100644 --- a/README.md +++ b/README.md @@ -22,9 +22,9 @@ $ go get github.com/XSAM/otelsql | Ping | If set to true, will enable the creation of spans on Ping requests. | Implemented | Ping has context argument, but it might no needs to record. | | RowsNext | If set to true, will enable the creation of events on corresponding calls. This can result in many events. | Implemented | It provides more visibility. | | DisableErrSkip | If set to true, will suppress driver.ErrSkip errors in spans. | Implemented | ErrSkip error might annoying | +| AllowRoot | If set to true, will allow otelsql to create root spans in absence of existing spans or even context. | Implemented | It might helpful while debugging missing operations. | | RowsAffected, LastInsertID | If set to true, will enable the creation of spans on RowsAffected/LastInsertId calls. | Dropped | Don't know its use cases. We might add this later based on the users' feedback. | | QueryParams | If set to true, will enable recording of parameters used with parametrized queries. | Dropped | It will cause high cardinality values and security problems. | -| AllowRoot | If set to true, will allow ocsql to create root spans in absence of existing spans or even context. | Dropped | I don't think traces data have meaning without context. | ## Example diff --git a/config.go b/config.go index d07b4f8..20c42f9 100644 --- a/config.go +++ b/config.go @@ -63,6 +63,9 @@ type SpanOptions struct { // DisableErrSkip, if set to true, will suppress driver.ErrSkip errors in spans. DisableErrSkip bool + + // AllowRoot, if set to true, will create root spans in absence of existing spans or even context. + AllowRoot bool } type defaultSpanNameFormatter struct{} diff --git a/conn.go b/conn.go index 83da972..04ad4ab 100644 --- a/conn.go +++ b/conn.go @@ -53,7 +53,7 @@ func (c *otConn) Ping(ctx context.Context) (err error) { return driver.ErrSkip } - if c.cfg.SpanOptions.Ping { + if c.cfg.SpanOptions.Ping && (c.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid()) { var span trace.Span ctx, span = c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnPing, ""), trace.WithSpanKind(trace.SpanKindClient), @@ -85,14 +85,17 @@ func (c *otConn) ExecContext(ctx context.Context, query string, args []driver.Na return nil, driver.ErrSkip } - ctx, span := c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnExec, query), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes( - append(c.cfg.Attributes, - semconv.DBStatementKey.String(query), - )...), - ) - defer span.End() + var span trace.Span + if c.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid() { + ctx, span = c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnExec, query), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes( + append(c.cfg.Attributes, + semconv.DBStatementKey.String(query), + )...), + ) + defer span.End() + } res, err = execer.ExecContext(ctx, query, args) if err != nil { @@ -116,14 +119,18 @@ func (c *otConn) QueryContext(ctx context.Context, query string, args []driver.N return nil, driver.ErrSkip } - queryCtx, span := c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnQuery, query), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes( - append(c.cfg.Attributes, - semconv.DBStatementKey.String(query), - )...), - ) - defer span.End() + var span trace.Span + queryCtx := ctx + if c.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid() { + queryCtx, span = c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnQuery, query), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes( + append(c.cfg.Attributes, + semconv.DBStatementKey.String(query), + )...), + ) + defer span.End() + } rows, err = queryer.QueryContext(queryCtx, query, args) if err != nil { @@ -139,14 +146,17 @@ func (c *otConn) PrepareContext(ctx context.Context, query string) (stmt driver. return nil, driver.ErrSkip } - ctx, span := c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnPrepare, query), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes( - append(c.cfg.Attributes, - semconv.DBStatementKey.String(query), - )...), - ) - defer span.End() + var span trace.Span + if c.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid() { + ctx, span = c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnPrepare, query), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes( + append(c.cfg.Attributes, + semconv.DBStatementKey.String(query), + )...), + ) + defer span.End() + } stmt, err = preparer.PrepareContext(ctx, query) if err != nil { @@ -162,11 +172,15 @@ func (c *otConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver. return nil, driver.ErrSkip } - beginTxCtx, span := c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnBeginTx, ""), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes(c.cfg.Attributes...), - ) - defer span.End() + var span trace.Span + beginTxCtx := ctx + if c.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid() { + beginTxCtx, span = c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnBeginTx, ""), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(c.cfg.Attributes...), + ) + defer span.End() + } tx, err = connBeginTx.BeginTx(beginTxCtx, opts) if err != nil { @@ -182,11 +196,14 @@ func (c *otConn) ResetSession(ctx context.Context) (err error) { return driver.ErrSkip } - ctx, span := c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnResetSession, ""), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes(c.cfg.Attributes...), - ) - defer span.End() + var span trace.Span + if c.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid() { + ctx, span = c.cfg.Tracer.Start(ctx, c.cfg.SpanNameFormatter.Format(ctx, MethodConnResetSession, ""), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(c.cfg.Attributes...), + ) + defer span.End() + } err = sessionResetter.ResetSession(ctx) if err != nil { diff --git a/conn_test.go b/conn_test.go index 1bcd042..11da2fc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -24,7 +24,6 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/oteltest" "go.opentelemetry.io/otel/semconv" "go.opentelemetry.io/otel/trace" @@ -140,14 +139,28 @@ var ( func TestOtConn_Ping(t *testing.T) { testCases := []struct { - name string - error bool - pingOption bool + name string + error bool + pingOption bool + allowRootOption bool + noParentSpan bool }{ { name: "ping enabled", pingOption: true, }, + { + name: "ping enabled with no parent span, allow root span", + pingOption: true, + allowRootOption: true, + noParentSpan: true, + }, + { + name: "ping enabled with no parent span, disallow root span", + pingOption: true, + allowRootOption: false, + noParentSpan: true, + }, { name: "ping enabled with error", pingOption: true, @@ -161,13 +174,12 @@ func TestOtConn_Ping(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) // New conn cfg := newMockConfig(tracer) cfg.SpanOptions.Ping = tc.pingOption + cfg.SpanOptions.AllowRoot = tc.allowRootOption mc := newMockConn(tc.error) otelConn := newConn(mc, cfg) @@ -178,24 +190,28 @@ func TestOtConn_Ping(t *testing.T) { require.NoError(t, err) } + spanList := sr.Completed() if tc.pingOption { - spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) // One dummy span and one span created in Ping - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(cfg.Attributes), span.Attributes()) - assert.Equal(t, string(MethodConnPing), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) - - assert.Equal(t, 1, mc.pingCount) - assert.Equal(t, span.SpanContext(), trace.SpanContextFromContext(mc.pingCtx)) - if tc.error { - assert.Equal(t, codes.Error, span.StatusCode()) - } else { - assert.Equal(t, codes.Unset, span.StatusCode()) + require.Equal(t, expectedSpanCount, len(spanList)) + + if tc.pingOption { + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: cfg.Attributes, + expectedMethod: MethodConnPing, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + ctx: mc.pingCtx, + }) + + assert.Equal(t, 1, mc.pingCount) + } + } else { + if !tc.noParentSpan { + require.Equal(t, 1, len(spanList)) } } }) @@ -204,8 +220,10 @@ func TestOtConn_Ping(t *testing.T) { func TestOtConn_ExecContext(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -214,53 +232,63 @@ func TestOtConn_ExecContext(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) // New conn cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption mc := newMockConn(tc.error) otelConn := newConn(mc, cfg) _, err := otelConn.ExecContext(ctx, "query", nil) + if tc.error { + require.Error(t, err) + } else { + require.NoError(t, err) + } spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) // One dummy span and one span created in ExecContext - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, - cfg.Attributes...)), span.Attributes()) - assert.Equal(t, string(MethodConnExec), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, + cfg.Attributes...), + expectedMethod: MethodConnExec, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + ctx: mc.execContextCtx, + }) assert.Equal(t, 1, mc.execContextCount) - assert.Equal(t, span.SpanContext(), trace.SpanContextFromContext(mc.execContextCtx)) assert.Equal(t, "query", mc.execContextQuery) - - if tc.error { - require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) - } else { - require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) - } }) } } func TestOtConn_QueryContext(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -269,50 +297,62 @@ func TestOtConn_QueryContext(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) // New conn cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption mc := newMockConn(tc.error) otelConn := newConn(mc, cfg) rows, err := otelConn.QueryContext(ctx, "query", nil) + if tc.error { + require.Error(t, err) + } else { + require.NoError(t, err) + } spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) // One dummy span and one span created in QueryContext - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, - cfg.Attributes...)), span.Attributes()) - assert.Equal(t, string(MethodConnQuery), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, + cfg.Attributes...), + expectedMethod: MethodConnQuery, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + ctx: mc.queryContextCtx, + }) assert.Equal(t, 1, mc.queryContextCount) - assert.Equal(t, span.SpanContext(), trace.SpanContextFromContext(mc.queryContextCtx)) assert.Equal(t, "query", mc.queryContextQuery) - if tc.error { - require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) - } else { - require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) - + if !tc.error { otelRows, ok := rows.(*otRows) require.True(t, ok) - assert.Equal(t, dummySpan.SpanContext().TraceID(), otelRows.span.SpanContext().TraceID()) - // Span that creates in newRows() is the child of the dummySpan - assert.Equal(t, dummySpan.SpanContext().SpanID(), otelRows.span.(*oteltest.Span).ParentSpanID()) + if dummySpan != nil { + assert.Equal(t, dummySpan.SpanContext().TraceID(), otelRows.span.SpanContext().TraceID()) + // Span that creates in newRows() is the child of the dummySpan + assert.Equal(t, dummySpan.SpanContext().SpanID(), otelRows.span.(*oteltest.Span).ParentSpanID()) + } } }) } @@ -320,8 +360,10 @@ func TestOtConn_QueryContext(t *testing.T) { func TestOtConn_PrepareContext(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -330,45 +372,55 @@ func TestOtConn_PrepareContext(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) // New conn cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption mc := newMockConn(tc.error) otelConn := newConn(mc, cfg) stmt, err := otelConn.PrepareContext(ctx, "query") + if tc.error { + require.Error(t, err) + } else { + require.NoError(t, err) + } spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) // One dummy span and one span created in PrepareContext - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, - cfg.Attributes...)), span.Attributes()) - assert.Equal(t, string(MethodConnPrepare), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, + cfg.Attributes...), + expectedMethod: MethodConnPrepare, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + ctx: mc.prepareContextCtx, + }) assert.Equal(t, 1, mc.prepareContextCount) - assert.Equal(t, span.SpanContext(), trace.SpanContextFromContext(mc.prepareContextCtx)) assert.Equal(t, "query", mc.prepareContextQuery) - if tc.error { - require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) - } else { - require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) - + if !tc.error { otelStmt, ok := stmt.(*otStmt) require.True(t, ok) assert.Equal(t, "query", otelStmt.query) @@ -379,8 +431,10 @@ func TestOtConn_PrepareContext(t *testing.T) { func TestOtConn_BeginTx(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -389,46 +443,59 @@ func TestOtConn_BeginTx(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) // New conn cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption mc := newMockConn(tc.error) otelConn := newConn(mc, cfg) tx, err := otelConn.BeginTx(ctx, driver.TxOptions{}) + if tc.error { + require.Error(t, err) + } else { + require.NoError(t, err) + } spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) // One dummy span and one span created in BeginTx - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(cfg.Attributes), span.Attributes()) - assert.Equal(t, string(MethodConnBeginTx), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: cfg.Attributes, + expectedMethod: MethodConnBeginTx, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + ctx: mc.beginTxCtx, + }) assert.Equal(t, 1, mc.beginTxCount) - assert.Equal(t, span.SpanContext(), trace.SpanContextFromContext(mc.beginTxCtx)) - - if tc.error { - require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) - } else { - require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) + if !tc.error { otelTx, ok := tx.(*otTx) require.True(t, ok) - assert.Equal(t, dummySpan.SpanContext(), trace.SpanContextFromContext(otelTx.ctx)) + + if dummySpan != nil { + assert.Equal(t, dummySpan.SpanContext(), trace.SpanContextFromContext(otelTx.ctx)) + } } }) } @@ -436,8 +503,10 @@ func TestOtConn_BeginTx(t *testing.T) { func TestOtConn_ResetSession(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -446,43 +515,51 @@ func TestOtConn_ResetSession(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) // New conn cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption mc := newMockConn(tc.error) otelConn := newConn(mc, cfg) err := otelConn.ResetSession(ctx) - - spanList := sr.Completed() - // One dummy span and one span created in ResetSession - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(cfg.Attributes), span.Attributes()) - assert.Equal(t, string(MethodConnResetSession), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) - - assert.Equal(t, 1, mc.resetSessionCount) - assert.Equal(t, span.SpanContext(), trace.SpanContextFromContext(mc.resetSessionCtx)) - if tc.error { require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) } else { require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) } + + spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) + // One dummy span and one span created in ResetSession + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: cfg.Attributes, + expectedMethod: MethodConnResetSession, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + ctx: mc.resetSessionCtx, + }) + + assert.Equal(t, 1, mc.resetSessionCount) }) } } diff --git a/rows.go b/rows.go index d0f8dd1..e59490d 100644 --- a/rows.go +++ b/rows.go @@ -39,10 +39,13 @@ type otRows struct { } func newRows(ctx context.Context, rows driver.Rows, cfg config) *otRows { - _, span := cfg.Tracer.Start(ctx, cfg.SpanNameFormatter.Format(ctx, MethodRows, ""), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes(cfg.Attributes...), - ) + var span trace.Span + if cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid() { + _, span = cfg.Tracer.Start(ctx, cfg.SpanNameFormatter.Format(ctx, MethodRows, ""), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(cfg.Attributes...), + ) + } return &otRows{ Rows: rows, @@ -62,7 +65,7 @@ func (r otRows) HasNextResultSet() bool { return false } -// NextResultsSet calls the implements the driver.RowsNextResultSet for otRows. +// NextResultSet calls the implements the driver.RowsNextResultSet for otRows. // It returns the the underlying result of NextResultSet from the otRows.parent // if the parent implements driver.RowsNextResultSet. func (r otRows) NextResultSet() error { @@ -118,7 +121,11 @@ func (r otRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok } func (r otRows) Close() (err error) { - defer r.span.End() + defer func() { + if r.span != nil { + r.span.End() + } + }() err = r.Rows.Close() if err != nil { @@ -128,7 +135,7 @@ func (r otRows) Close() (err error) { } func (r otRows) Next(dest []driver.Value) (err error) { - if r.cfg.SpanOptions.RowsNext { + if r.cfg.SpanOptions.RowsNext && r.span != nil { r.span.AddEvent(string(EventRowsNext)) } diff --git a/rows_test.go b/rows_test.go index 5e82756..1e600fc 100644 --- a/rows_test.go +++ b/rows_test.go @@ -15,7 +15,6 @@ package otelsql import ( - "context" "database/sql/driver" "errors" "testing" @@ -24,7 +23,6 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" ) type mockRows struct { @@ -80,21 +78,20 @@ func TestOtRows_Close(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") + ctx, sr, tracer, _ := prepareTraces(false) mr := newMockRows(tc.error) cfg := newMockConfig(tracer) // New rows - rows := newRows(context.Background(), mr, cfg) + rows := newRows(ctx, mr, cfg) // Close err := rows.Close() spanList := sr.Completed() // A span created in newRows() - require.Equal(t, 1, len(spanList)) - span := spanList[0] + require.Equal(t, 2, len(spanList)) + span := spanList[1] assert.True(t, span.Ended()) assert.Equal(t, 1, mr.closeCount) @@ -132,22 +129,21 @@ func TestOtRows_Next(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") + ctx, sr, tracer, _ := prepareTraces(false) mr := newMockRows(tc.error) cfg := newMockConfig(tracer) cfg.SpanOptions.RowsNext = tc.rowsNextOption // New rows - rows := newRows(context.Background(), mr, cfg) + rows := newRows(ctx, mr, cfg) // Next err := rows.Next([]driver.Value{"test"}) spanList := sr.Started() // A span created in newRows() - require.Equal(t, 1, len(spanList)) - span := spanList[0] + require.Equal(t, 2, len(spanList)) + span := spanList[1] assert.False(t, span.Ended()) assert.Equal(t, 1, mr.nextCount) @@ -171,26 +167,53 @@ func TestOtRows_Next(t *testing.T) { } func TestNewRows(t *testing.T) { - // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) - - mr := newMockRows(false) - cfg := newMockConfig(tracer) - - // New rows - rows := newRows(ctx, mr, cfg) - - spanList := sr.Started() - // One dummy span and one span created in newRows() - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.False(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(cfg.Attributes), span.Attributes()) - assert.Equal(t, string(MethodRows), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) - assert.Equal(t, mr, rows.Rows) + testCases := []struct { + name string + allowRootOption bool + noParentSpan bool + }{ + { + name: "default config", + }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Prepare traces + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) + + mr := newMockRows(false) + cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption + + // New rows + rows := newRows(ctx, mr, cfg) + + spanList := sr.Started() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) + // One dummy span and one span created in newRows() + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: false, + expectedAttributes: cfg.Attributes, + expectedMethod: MethodRows, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + spanNotEnded: true, + }) + + assert.Equal(t, mr, rows.Rows) + }) + } } diff --git a/stmt.go b/stmt.go index 01ef193..3327c73 100644 --- a/stmt.go +++ b/stmt.go @@ -50,14 +50,17 @@ func (s *otStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res return nil, driver.ErrSkip } - ctx, span := s.cfg.Tracer.Start(ctx, s.cfg.SpanNameFormatter.Format(ctx, MethodStmtExec, s.query), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes( - append(s.cfg.Attributes, - semconv.DBStatementKey.String(s.query), - )...), - ) - defer span.End() + var span trace.Span + if s.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid() { + ctx, span = s.cfg.Tracer.Start(ctx, s.cfg.SpanNameFormatter.Format(ctx, MethodStmtExec, s.query), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes( + append(s.cfg.Attributes, + semconv.DBStatementKey.String(s.query), + )...), + ) + defer span.End() + } result, err = execer.ExecContext(ctx, args) if err != nil { @@ -73,14 +76,18 @@ func (s *otStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (ro return nil, driver.ErrSkip } - queryCtx, span := s.cfg.Tracer.Start(ctx, s.cfg.SpanNameFormatter.Format(ctx, MethodStmtQuery, s.query), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes( - append(s.cfg.Attributes, - semconv.DBStatementKey.String(s.query), - )...), - ) - defer span.End() + var span trace.Span + queryCtx := ctx + if s.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(ctx).IsValid() { + queryCtx, span = s.cfg.Tracer.Start(ctx, s.cfg.SpanNameFormatter.Format(ctx, MethodStmtQuery, s.query), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes( + append(s.cfg.Attributes, + semconv.DBStatementKey.String(s.query), + )...), + ) + defer span.End() + } rows, err = query.QueryContext(queryCtx, args) if err != nil { diff --git a/stmt_test.go b/stmt_test.go index 5962c37..5e054f9 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -24,9 +24,7 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/semconv" - "go.opentelemetry.io/otel/trace" ) type mockStmt struct { @@ -78,8 +76,10 @@ var ( func TestOtStmt_ExecContext(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -88,51 +88,62 @@ func TestOtStmt_ExecContext(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) ms := newMockStmt(tc.error) // New stmt cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption stmt := newStmt(ms, cfg, "query") // Exec _, err := stmt.ExecContext(ctx, []driver.NamedValue{{Name: "test"}}) + if tc.error { + require.Error(t, err) + } else { + require.NoError(t, err) + } spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) // One dummy span and a span created in tx - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, - cfg.Attributes...)), span.Attributes()) - assert.Equal(t, string(MethodStmtExec), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, + cfg.Attributes...), + expectedMethod: MethodStmtExec, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + }) assert.Equal(t, 1, ms.execCount) assert.Equal(t, []driver.NamedValue{{Name: "test"}}, ms.ExecContextArgs) - if tc.error { - require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) - } else { - require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) - } }) } } func TestOtStmt_QueryContext(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -141,42 +152,53 @@ func TestOtStmt_QueryContext(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) ms := newMockStmt(tc.error) // New stmt cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption stmt := newStmt(ms, cfg, "query") // Query rows, err := stmt.QueryContext(ctx, []driver.NamedValue{{Name: "test"}}) + if tc.error { + require.Error(t, err) + } else { + require.NoError(t, err) + } spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) // One dummy span and a span created in tx - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, - cfg.Attributes...)), span.Attributes()) - assert.Equal(t, string(MethodStmtQuery), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: append([]attribute.KeyValue{semconv.DBStatementKey.String("query")}, + cfg.Attributes...), + expectedMethod: MethodStmtQuery, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + }) assert.Equal(t, 1, ms.queryCount) assert.Equal(t, []driver.NamedValue{{Name: "test"}}, ms.queryContextArgs) - if tc.error { - require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) - } else { - require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) + if !tc.error { assert.IsType(t, &otRows{}, rows) } }) diff --git a/tx.go b/tx.go index 0504d4f..f48876b 100644 --- a/tx.go +++ b/tx.go @@ -38,11 +38,14 @@ func newTx(ctx context.Context, tx driver.Tx, cfg config) *otTx { } func (t *otTx) Commit() (err error) { - _, span := t.cfg.Tracer.Start(t.ctx, t.cfg.SpanNameFormatter.Format(t.ctx, MethodTxCommit, ""), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes(t.cfg.Attributes...), - ) - defer span.End() + var span trace.Span + if t.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(t.ctx).IsValid() { + _, span = t.cfg.Tracer.Start(t.ctx, t.cfg.SpanNameFormatter.Format(t.ctx, MethodTxCommit, ""), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(t.cfg.Attributes...), + ) + defer span.End() + } err = t.tx.Commit() if err != nil { @@ -53,11 +56,14 @@ func (t *otTx) Commit() (err error) { } func (t *otTx) Rollback() (err error) { - _, span := t.cfg.Tracer.Start(t.ctx, t.cfg.SpanNameFormatter.Format(t.ctx, MethodTxRollback, ""), - trace.WithSpanKind(trace.SpanKindClient), - trace.WithAttributes(t.cfg.Attributes...), - ) - defer span.End() + var span trace.Span + if t.cfg.SpanOptions.AllowRoot || trace.SpanContextFromContext(t.ctx).IsValid() { + _, span = t.cfg.Tracer.Start(t.ctx, t.cfg.SpanNameFormatter.Format(t.ctx, MethodTxRollback, ""), + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(t.cfg.Attributes...), + ) + defer span.End() + } err = t.tx.Rollback() if err != nil { diff --git a/tx_test.go b/tx_test.go index f1364ac..84f8204 100644 --- a/tx_test.go +++ b/tx_test.go @@ -15,7 +15,6 @@ package otelsql import ( - "context" "database/sql/driver" "errors" "testing" @@ -24,8 +23,6 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" ) type mockTx struct { @@ -61,8 +58,10 @@ var defaultattribute = attribute.Key("test").String("foo") func TestOtTx_Commit(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -71,49 +70,60 @@ func TestOtTx_Commit(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) mt := newMockTx(tc.error) // New tx cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption tx := newTx(ctx, mt, cfg) // Commit err := tx.Commit() - - spanList := sr.Completed() - // One dummy span and one span created in tx - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(cfg.Attributes), span.Attributes()) - assert.Equal(t, string(MethodTxCommit), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) - - assert.Equal(t, 1, mt.commitCount) if tc.error { require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) } else { require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) } + + spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) + // One dummy span and one span created in tx + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: cfg.Attributes, + expectedMethod: MethodTxCommit, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + }) + + assert.Equal(t, 1, mt.commitCount) }) } } func TestOtTx_Rollback(t *testing.T) { testCases := []struct { - name string - error bool + name string + error bool + allowRootOption bool + noParentSpan bool }{ { name: "no error", @@ -122,41 +132,51 @@ func TestOtTx_Rollback(t *testing.T) { name: "with error", error: true, }, + { + name: "no parent span, disallow root span", + noParentSpan: true, + }, + { + name: "no parent span, allow root span", + noParentSpan: true, + allowRootOption: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Prepare traces - sr, provider := newTracerProvider() - tracer := provider.Tracer("test") - ctx, dummySpan := createDummySpan(context.Background(), tracer) + ctx, sr, tracer, dummySpan := prepareTraces(tc.noParentSpan) mt := newMockTx(tc.error) // New tx cfg := newMockConfig(tracer) + cfg.SpanOptions.AllowRoot = tc.allowRootOption tx := newTx(ctx, mt, cfg) + // Rollback err := tx.Rollback() - - spanList := sr.Completed() - // One dummy span and a span created in tx - require.Equal(t, 2, len(spanList)) - span := spanList[1] - assert.True(t, span.Ended()) - assert.Equal(t, trace.SpanKindClient, span.SpanKind()) - assert.Equal(t, attributesListToMap(cfg.Attributes), span.Attributes()) - assert.Equal(t, string(MethodTxRollback), span.Name()) - assert.Equal(t, dummySpan.SpanContext().TraceID(), span.SpanContext().TraceID()) - assert.Equal(t, dummySpan.SpanContext().SpanID(), span.ParentSpanID()) - assert.Equal(t, 1, mt.rollbackCount) - if tc.error { require.Error(t, err) - assert.Equal(t, codes.Error, span.StatusCode()) } else { require.NoError(t, err) - assert.Equal(t, codes.Unset, span.StatusCode()) } + + spanList := sr.Completed() + expectedSpanCount := getExpectedSpanCount(tc.allowRootOption, tc.noParentSpan) + // One dummy span and a span created in tx + require.Equal(t, expectedSpanCount, len(spanList)) + + assertSpanList(t, spanList, spanAssertionParameter{ + parentSpan: dummySpan, + error: tc.error, + expectedAttributes: cfg.Attributes, + expectedMethod: MethodTxRollback, + allowRootOption: tc.allowRootOption, + noParentSpan: tc.noParentSpan, + }) + + assert.Equal(t, 1, mt.rollbackCount) }) } } diff --git a/utils.go b/utils.go index 6f01e4e..dea329d 100644 --- a/utils.go +++ b/utils.go @@ -22,6 +22,10 @@ import ( ) func recordSpanError(span trace.Span, opts SpanOptions, err error) { + if span == nil { + return + } + switch err { case nil: return diff --git a/utils_test.go b/utils_test.go index 1c21706..8055ffc 100644 --- a/utils_test.go +++ b/utils_test.go @@ -34,6 +34,7 @@ func TestRecordSpanError(t *testing.T) { opts SpanOptions err error expectedError bool + nilSpan bool }{ { name: "no error", @@ -62,17 +63,27 @@ func TestRecordSpanError(t *testing.T) { opts: SpanOptions{DisableErrSkip: true}, expectedError: false, }, + { + name: "nil span", + err: nil, + nilSpan: true, + expectedError: false, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var span oteltest.Span - recordSpanError(&span, tc.opts, tc.err) + if !tc.nilSpan { + var span oteltest.Span + recordSpanError(&span, tc.opts, tc.err) - if tc.expectedError { - assert.Equal(t, codes.Error, span.StatusCode()) + if tc.expectedError { + assert.Equal(t, codes.Error, span.StatusCode()) + } else { + assert.Equal(t, codes.Unset, span.StatusCode()) + } } else { - assert.Equal(t, codes.Unset, span.StatusCode()) + recordSpanError(nil, tc.opts, tc.err) } }) } @@ -108,3 +119,71 @@ func attributesListToMap(attributes []attribute.KeyValue) map[attribute.Key]attr } return attributesMap } + +type spanAssertionParameter struct { + parentSpan trace.Span + error bool + expectedAttributes []attribute.KeyValue + expectedMethod Method + allowRootOption bool + noParentSpan bool + ctx context.Context + spanNotEnded bool +} + +func assertSpanList(t *testing.T, spanList []*oteltest.Span, parameter spanAssertionParameter) { + var span *oteltest.Span + if !parameter.noParentSpan { + span = spanList[1] + } else if parameter.allowRootOption { + span = spanList[0] + } + + if span != nil { + if parameter.spanNotEnded { + assert.False(t, span.Ended()) + } else { + assert.True(t, span.Ended()) + } + assert.Equal(t, trace.SpanKindClient, span.SpanKind()) + assert.Equal(t, attributesListToMap(parameter.expectedAttributes), span.Attributes()) + assert.Equal(t, string(parameter.expectedMethod), span.Name()) + if parameter.parentSpan != nil { + assert.Equal(t, parameter.parentSpan.SpanContext().TraceID(), span.SpanContext().TraceID()) + assert.Equal(t, parameter.parentSpan.SpanContext().SpanID(), span.ParentSpanID()) + } + + if parameter.error { + assert.Equal(t, codes.Error, span.StatusCode()) + } else { + assert.Equal(t, codes.Unset, span.StatusCode()) + } + + if parameter.ctx != nil { + assert.Equal(t, span.SpanContext(), trace.SpanContextFromContext(parameter.ctx)) + } + } +} + +func getExpectedSpanCount(allowRootOption bool, noParentSpan bool) int { + var expectedSpanCount int + if allowRootOption { + expectedSpanCount++ + } + if !noParentSpan { + expectedSpanCount = 2 + } + return expectedSpanCount +} + +func prepareTraces(noParentSpan bool) (context.Context, *oteltest.SpanRecorder, trace.Tracer, trace.Span) { + sr, provider := newTracerProvider() + tracer := provider.Tracer("test") + + var dummySpan trace.Span + ctx := context.Background() + if !noParentSpan { + ctx, dummySpan = createDummySpan(context.Background(), tracer) + } + return ctx, sr, tracer, dummySpan +}