diff --git a/contrib/go-chi/chi.v4/appsec.go b/contrib/go-chi/chi.v4/appsec.go new file mode 100644 index 0000000000..5ada4d6299 --- /dev/null +++ b/contrib/go-chi/chi.v4/appsec.go @@ -0,0 +1,32 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package chi + +import ( + "net/http" + + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec" + + "github.com/go-chi/chi/v4" +) + +func withAppsec(next http.Handler, r *http.Request, span tracer.Span) http.Handler { + rctx := chi.RouteContext(r.Context()) + if rctx == nil { + return httpsec.WrapHandler(next, span, nil) + } + var pathParams map[string]string + keys := rctx.URLParams.Keys + values := rctx.URLParams.Values + if len(keys) > 0 && len(keys) == len(values) { + pathParams = make(map[string]string, len(keys)) + for i, key := range keys { + pathParams[key] = values[i] + } + } + return httpsec.WrapHandler(next, span, pathParams) +} diff --git a/contrib/go-chi/chi.v4/chi.go b/contrib/go-chi/chi.v4/chi.go new file mode 100644 index 0000000000..af661d8bd4 --- /dev/null +++ b/contrib/go-chi/chi.v4/chi.go @@ -0,0 +1,89 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +// Package chi provides tracing functions for tracing the go-chi/chi/v4 package (https://github.com/go-chi/chi). +package chi // import "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-chi/chi.v4" + +import ( + "fmt" + "math" + "net/http" + "strconv" + + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/log" + + "github.com/go-chi/chi/v4" + "github.com/go-chi/chi/v4/middleware" +) + +// Middleware returns middleware that will trace incoming requests. +func Middleware(opts ...Option) func(next http.Handler) http.Handler { + cfg := new(config) + defaults(cfg) + for _, fn := range opts { + fn(cfg) + } + log.Debug("contrib/go-chi/chi.v4: Configuring Middleware: %#v", cfg) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if cfg.ignoreRequest(r) { + next.ServeHTTP(w, r) + return + } + opts := []ddtrace.StartSpanOption{ + tracer.SpanType(ext.SpanTypeWeb), + tracer.ServiceName(cfg.serviceName), + tracer.Tag(ext.HTTPMethod, r.Method), + tracer.Tag(ext.HTTPURL, r.URL.Path), + tracer.Measured(), + } + if !math.IsNaN(cfg.analyticsRate) { + opts = append(opts, tracer.Tag(ext.EventSampleRate, cfg.analyticsRate)) + } + if spanctx, err := tracer.Extract(tracer.HTTPHeadersCarrier(r.Header)); err == nil { + opts = append(opts, tracer.ChildOf(spanctx)) + } + opts = append(opts, cfg.spanOpts...) + span, ctx := tracer.StartSpanFromContext(r.Context(), "http.request", opts...) + defer span.Finish() + + if appsec.Enabled() { + next = withAppsec(next, r, span) + // Note that the following response writer passed to the handler + // implements the `interface { Status() int }` expected by httpsec. + } + + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) + + // pass the span through the request context and serve the request to the next middleware + next.ServeHTTP(ww, r.WithContext(ctx)) + + // set the resource name as we get it only once the handler is executed + resourceName := chi.RouteContext(r.Context()).RoutePattern() + if resourceName == "" { + resourceName = "unknown" + } + resourceName = r.Method + " " + resourceName + span.SetTag(ext.ResourceName, resourceName) + + // set the status code + status := ww.Status() + // 0 status means one has not yet been sent in which case net/http library will write StatusOK + if ww.Status() == 0 { + status = http.StatusOK + } + span.SetTag(ext.HTTPCode, strconv.Itoa(status)) + + if cfg.isStatusError(status) { + // mark 5xx server error + span.SetTag(ext.Error, fmt.Errorf("%d: %s", status, http.StatusText(status))) + } + }) + } +} diff --git a/contrib/go-chi/chi.v4/chi_test.go b/contrib/go-chi/chi.v4/chi_test.go new file mode 100644 index 0000000000..8e2448e5f3 --- /dev/null +++ b/contrib/go-chi/chi.v4/chi_test.go @@ -0,0 +1,382 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package chi + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" + + "github.com/go-chi/chi/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestChildSpan(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + router := chi.NewRouter() + router.Use(Middleware(WithServiceName("foobar"))) + router.Get("/user/{id}", func(w http.ResponseWriter, r *http.Request) { + _, ok := tracer.SpanFromContext(r.Context()) + assert.True(ok) + }) + + r := httptest.NewRequest("GET", "/user/123", nil) + w := httptest.NewRecorder() + + router.ServeHTTP(w, r) +} + +func TestTrace200(t *testing.T) { + assertDoRequest := func(assert *assert.Assertions, mt mocktracer.Tracer, router *chi.Mux) { + r := httptest.NewRequest("GET", "/user/123", nil) + w := httptest.NewRecorder() + + // do and verify the request + router.ServeHTTP(w, r) + response := w.Result() + assert.Equal(response.StatusCode, 200) + + // verify traces look good + spans := mt.FinishedSpans() + assert.Len(spans, 1) + if len(spans) < 1 { + t.Fatalf("no spans") + } + span := spans[0] + assert.Equal("http.request", span.OperationName()) + assert.Equal(ext.SpanTypeWeb, span.Tag(ext.SpanType)) + assert.Equal("foobar", span.Tag(ext.ServiceName)) + assert.Equal("GET /user/{id}", span.Tag(ext.ResourceName)) + assert.Equal("200", span.Tag(ext.HTTPCode)) + assert.Equal("GET", span.Tag(ext.HTTPMethod)) + assert.Equal("/user/123", span.Tag(ext.HTTPURL)) + } + + t.Run("response written", func(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + router := chi.NewRouter() + router.Use(Middleware(WithServiceName("foobar"))) + router.Get("/user/{id}", func(w http.ResponseWriter, r *http.Request) { + span, ok := tracer.SpanFromContext(r.Context()) + assert.True(ok) + assert.Equal(span.(mocktracer.Span).Tag(ext.ServiceName), "foobar") + id := chi.URLParam(r, "id") + _, err := w.Write([]byte(id)) + assert.NoError(err) + }) + assertDoRequest(assert, mt, router) + }) + + t.Run("no response written", func(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + router := chi.NewRouter() + router.Use(Middleware(WithServiceName("foobar"))) + router.Get("/user/{id}", func(w http.ResponseWriter, r *http.Request) { + span, ok := tracer.SpanFromContext(r.Context()) + assert.True(ok) + assert.Equal(span.(mocktracer.Span).Tag(ext.ServiceName), "foobar") + }) + assertDoRequest(assert, mt, router) + }) +} + +func TestError(t *testing.T) { + assertSpan := func(assert *assert.Assertions, spans []mocktracer.Span, code int) { + assert.Len(spans, 1) + if len(spans) < 1 { + t.Fatalf("no spans") + } + span := spans[0] + assert.Equal("http.request", span.OperationName()) + assert.Equal("foobar", span.Tag(ext.ServiceName)) + + assert.Equal(strconv.Itoa(code), span.Tag(ext.HTTPCode)) + + wantErr := fmt.Sprintf("%d: %s", code, http.StatusText(code)) + assert.Equal(wantErr, span.Tag(ext.Error).(error).Error()) + } + + t.Run("default", func(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + // setup + router := chi.NewRouter() + router.Use(Middleware(WithServiceName("foobar"))) + code := 500 + + // a handler with an error and make the requests + router.Get("/err", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, fmt.Sprintf("%d!", code), code) + }) + r := httptest.NewRequest("GET", "/err", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + response := w.Result() + assert.Equal(response.StatusCode, code) + + // verify the errors and status are correct + spans := mt.FinishedSpans() + assertSpan(assert, spans, code) + }) + + t.Run("custom", func(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + // setup + router := chi.NewRouter() + router.Use(Middleware( + WithServiceName("foobar"), + WithStatusCheck(func(statusCode int) bool { + return statusCode >= 400 + }), + )) + code := 404 + // a handler with an error and make the requests + router.Get("/err", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, fmt.Sprintf("%d!", code), code) + }) + r := httptest.NewRequest("GET", "/err", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + response := w.Result() + assert.Equal(response.StatusCode, code) + + // verify the errors and status are correct + spans := mt.FinishedSpans() + assertSpan(assert, spans, code) + }) +} + +func TestGetSpanNotInstrumented(t *testing.T) { + assert := assert.New(t) + router := chi.NewRouter() + router.Get("/ping", func(w http.ResponseWriter, r *http.Request) { + // Assert we don't have a span on the context. + _, ok := tracer.SpanFromContext(r.Context()) + assert.False(ok) + w.Write([]byte("ok")) + }) + r := httptest.NewRequest("GET", "/ping", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + response := w.Result() + assert.Equal(response.StatusCode, 200) +} + +func TestPropagation(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + r := httptest.NewRequest("GET", "/user/123", nil) + w := httptest.NewRecorder() + + pspan := tracer.StartSpan("test") + tracer.Inject(pspan.Context(), tracer.HTTPHeadersCarrier(r.Header)) + + router := chi.NewRouter() + router.Use(Middleware(WithServiceName("foobar"))) + router.Get("/user/{id}", func(w http.ResponseWriter, r *http.Request) { + span, ok := tracer.SpanFromContext(r.Context()) + assert.True(ok) + assert.Equal(span.(mocktracer.Span).ParentID(), pspan.(mocktracer.Span).SpanID()) + }) + + router.ServeHTTP(w, r) +} + +func TestAnalyticsSettings(t *testing.T) { + assertRate := func(t *testing.T, mt mocktracer.Tracer, rate interface{}, opts ...Option) { + router := chi.NewRouter() + router.Use(Middleware(opts...)) + router.Get("/user/{id}", func(w http.ResponseWriter, r *http.Request) { + _, ok := tracer.SpanFromContext(r.Context()) + assert.True(t, ok) + }) + + r := httptest.NewRequest("GET", "/user/123", nil) + w := httptest.NewRecorder() + + router.ServeHTTP(w, r) + spans := mt.FinishedSpans() + assert.Len(t, spans, 1) + s := spans[0] + assert.Equal(t, rate, s.Tag(ext.EventSampleRate)) + } + + t.Run("defaults", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + assertRate(t, mt, nil) + }) + + t.Run("global", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + rate := globalconfig.AnalyticsRate() + defer globalconfig.SetAnalyticsRate(rate) + globalconfig.SetAnalyticsRate(0.4) + + assertRate(t, mt, 0.4) + }) + + t.Run("enabled", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + assertRate(t, mt, 1.0, WithAnalytics(true)) + }) + + t.Run("disabled", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + assertRate(t, mt, nil, WithAnalytics(false)) + }) + + t.Run("override", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + rate := globalconfig.AnalyticsRate() + defer globalconfig.SetAnalyticsRate(rate) + globalconfig.SetAnalyticsRate(0.4) + + assertRate(t, mt, 0.23, WithAnalyticsRate(0.23)) + }) +} + +func TestIgnoreRequest(t *testing.T) { + router := chi.NewRouter() + router.Use(Middleware( + WithIgnoreRequest(func(r *http.Request) bool { + return strings.HasPrefix(r.URL.Path, "/skip") + }), + )) + + router.Get("/ok", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + }) + + router.Get("/skip", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("skip")) + }) + + for path, shouldSkip := range map[string]bool{ + "/ok": false, + "/skip": true, + "/skipfoo": true, + } { + mt := mocktracer.Start() + defer mt.Reset() + + r := httptest.NewRequest("GET", "http://localhost"+path, nil) + router.ServeHTTP(httptest.NewRecorder(), r) + assert.Equal(t, shouldSkip, len(mt.FinishedSpans()) == 0) + } +} + +func TestAppSec(t *testing.T) { + appsec.Start() + defer appsec.Stop() + + if !appsec.Enabled() { + t.Skip("appsec disabled") + } + + // Start and trace an HTTP server with some testing routes + router := chi.NewRouter().With(Middleware()) + router.HandleFunc("/path0.0/{myPathParam0}/path0.1/{myPathParam1}/path0.2/{myPathParam2}/path0.3/*", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("Hello World!\n")) + require.NoError(t, err) + }) + router.HandleFunc("/*", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("Hello World!\n")) + require.NoError(t, err) + }) + + srv := httptest.NewServer(router) + defer srv.Close() + + // Test an LFI attack via path parameters + t.Run("request-uri", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + // Send an LFI attack (according to appsec rule id crs-930-100) + req, err := http.NewRequest("POST", srv.URL+"/../../../secret.txt", nil) + if err != nil { + panic(err) + } + res, err := srv.Client().Do(req) + require.NoError(t, err) + // Check that the server behaved as intended + require.Equal(t, http.StatusOK, res.StatusCode) + b, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, "Hello World!\n", string(b)) + // The span should contain the security event + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + + // The first 301 redirection should contain the attack via the request uri + event := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "server.request.uri.raw")) + require.True(t, strings.Contains(event, "crs-930-100")) + }) + + // Test a security scanner attack via path parameters + t.Run("path-params", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + // Send a security scanner attack (according to appsec rule id crs-913-120) + req, err := http.NewRequest("POST", srv.URL+"/path0.0/param0/path0.1/param1/path0.2/appscan_fingerprint/path0.3/param3", nil) + if err != nil { + panic(err) + } + res, err := srv.Client().Do(req) + require.NoError(t, err) + // Check that the handler was properly called + b, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, "Hello World!\n", string(b)) + require.Equal(t, http.StatusOK, res.StatusCode) + // The span should contain the security event + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + event := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "crs-913-120")) + require.True(t, strings.Contains(event, "myPathParam2")) + require.True(t, strings.Contains(event, "server.request.path_params")) + }) +} diff --git a/contrib/go-chi/chi.v4/example_test.go b/contrib/go-chi/chi.v4/example_test.go new file mode 100644 index 0000000000..2a5002a412 --- /dev/null +++ b/contrib/go-chi/chi.v4/example_test.go @@ -0,0 +1,55 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package chi_test + +import ( + "net/http" + + "github.com/go-chi/chi/v4" + + chitrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-chi/chi.v4" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" +) + +func handler(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello World!\n")) +} + +func Example() { + // Start the tracer + tracer.Start() + defer tracer.Stop() + + // Create a chi Router + router := chi.NewRouter() + + // Use the tracer middleware with the default service name "chi.router". + router.Use(chitrace.Middleware()) + + // Set up some endpoints. + router.Get("/", handler) + + // And start gathering request traces + http.ListenAndServe(":8080", router) +} + +func Example_withServiceName() { + // Start the tracer + tracer.Start() + defer tracer.Stop() + + // Create a chi Router + router := chi.NewRouter() + + // Use the tracer middleware with your desired service name. + router.Use(chitrace.Middleware(chitrace.WithServiceName("chi-server"))) + + // Set up some endpoints. + router.Get("/", handler) + + // And start gathering request traces + http.ListenAndServe(":8080", router) +} diff --git a/contrib/go-chi/chi.v4/option.go b/contrib/go-chi/chi.v4/option.go new file mode 100644 index 0000000000..6c19dd3cc7 --- /dev/null +++ b/contrib/go-chi/chi.v4/option.go @@ -0,0 +1,98 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package chi + +import ( + "math" + "net/http" + + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" + "gopkg.in/DataDog/dd-trace-go.v1/internal" + "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" +) + +type config struct { + serviceName string + spanOpts []ddtrace.StartSpanOption // additional span options to be applied + analyticsRate float64 + isStatusError func(statusCode int) bool + ignoreRequest func(r *http.Request) bool +} + +// Option represents an option that can be passed to NewRouter. +type Option func(*config) + +func defaults(cfg *config) { + cfg.serviceName = "chi.router" + if svc := globalconfig.ServiceName(); svc != "" { + cfg.serviceName = svc + } + if internal.BoolEnv("DD_TRACE_CHI_ANALYTICS_ENABLED", false) { + cfg.analyticsRate = 1.0 + } else { + cfg.analyticsRate = globalconfig.AnalyticsRate() + } + cfg.isStatusError = isServerError + cfg.ignoreRequest = func(_ *http.Request) bool { return false } +} + +// WithServiceName sets the given service name for the router. +func WithServiceName(name string) Option { + return func(cfg *config) { + cfg.serviceName = name + } +} + +// WithSpanOptions applies the given set of options to the spans started +// by the router. +func WithSpanOptions(opts ...ddtrace.StartSpanOption) Option { + return func(cfg *config) { + cfg.spanOpts = opts + } +} + +// WithAnalytics enables Trace Analytics for all started spans. +func WithAnalytics(on bool) Option { + return func(cfg *config) { + if on { + cfg.analyticsRate = 1.0 + } else { + cfg.analyticsRate = math.NaN() + } + } +} + +// WithAnalyticsRate sets the sampling rate for Trace Analytics events +// correlated to started spans. +func WithAnalyticsRate(rate float64) Option { + return func(cfg *config) { + if rate >= 0.0 && rate <= 1.0 { + cfg.analyticsRate = rate + } else { + cfg.analyticsRate = math.NaN() + } + } +} + +// WithStatusCheck specifies a function fn which reports whether the passed +// statusCode should be considered an error. +func WithStatusCheck(fn func(statusCode int) bool) Option { + return func(cfg *config) { + cfg.isStatusError = fn + } +} + +func isServerError(statusCode int) bool { + return statusCode >= 500 && statusCode < 600 +} + +// WithIgnoreRequest specifies a function to use for determining if the +// incoming HTTP request tracing should be skipped. +func WithIgnoreRequest(fn func(r *http.Request) bool) Option { + return func(cfg *config) { + cfg.ignoreRequest = fn + } +} diff --git a/contrib/go-chi/chi.v5/appsec.go b/contrib/go-chi/chi.v5/appsec.go new file mode 100644 index 0000000000..4a4166d096 --- /dev/null +++ b/contrib/go-chi/chi.v5/appsec.go @@ -0,0 +1,32 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package chi + +import ( + "net/http" + + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec" + + "github.com/go-chi/chi/v5" +) + +func withAppsec(next http.Handler, r *http.Request, span tracer.Span) http.Handler { + rctx := chi.RouteContext(r.Context()) + if rctx == nil { + return httpsec.WrapHandler(next, span, nil) + } + var pathParams map[string]string + keys := rctx.URLParams.Keys + values := rctx.URLParams.Values + if len(keys) > 0 && len(keys) == len(values) { + pathParams = make(map[string]string, len(keys)) + for i, key := range keys { + pathParams[key] = values[i] + } + } + return httpsec.WrapHandler(next, span, pathParams) +} diff --git a/contrib/go-chi/chi.v5/chi.go b/contrib/go-chi/chi.v5/chi.go index 477f8e9636..3c71a309d4 100644 --- a/contrib/go-chi/chi.v5/chi.go +++ b/contrib/go-chi/chi.v5/chi.go @@ -15,6 +15,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" "github.com/go-chi/chi/v5" @@ -52,6 +53,12 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler { span, ctx := tracer.StartSpanFromContext(r.Context(), "http.request", opts...) defer span.Finish() + if appsec.Enabled() { + next = withAppsec(next, r, span) + // Note that the following response writer passed to the handler + // implements the `interface { Status() int }` expected by httpsec. + } + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) // pass the span through the request context and serve the request to the next middleware diff --git a/contrib/go-chi/chi.v5/chi_test.go b/contrib/go-chi/chi.v5/chi_test.go index 3cd1c3294e..35ff5b3afb 100644 --- a/contrib/go-chi/chi.v5/chi_test.go +++ b/contrib/go-chi/chi.v5/chi_test.go @@ -7,6 +7,7 @@ package chi import ( "fmt" + "io/ioutil" "net/http" "net/http/httptest" "strconv" @@ -16,10 +17,12 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestChildSpan(t *testing.T) { @@ -301,3 +304,79 @@ func TestIgnoreRequest(t *testing.T) { assert.Equal(t, shouldSkip, len(mt.FinishedSpans()) == 0) } } + +func TestAppSec(t *testing.T) { + appsec.Start() + defer appsec.Stop() + + if !appsec.Enabled() { + t.Skip("appsec disabled") + } + + // Start and trace an HTTP server with some testing routes + router := chi.NewRouter().With(Middleware()) + router.HandleFunc("/path0.0/{myPathParam0}/path0.1/{myPathParam1}/path0.2/{myPathParam2}/path0.3/*", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("Hello World!\n")) + require.NoError(t, err) + }) + router.HandleFunc("/*", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("Hello World!\n")) + require.NoError(t, err) + }) + + srv := httptest.NewServer(router) + defer srv.Close() + + // Test an LFI attack via path parameters + t.Run("request-uri", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + // Send an LFI attack (according to appsec rule id crs-930-100) + req, err := http.NewRequest("POST", srv.URL+"/../../../secret.txt", nil) + if err != nil { + panic(err) + } + res, err := srv.Client().Do(req) + require.NoError(t, err) + // Check that the server behaved as intended + require.Equal(t, http.StatusOK, res.StatusCode) + b, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, "Hello World!\n", string(b)) + // The span should contain the security event + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + + // The first 301 redirection should contain the attack via the request uri + event := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "server.request.uri.raw")) + require.True(t, strings.Contains(event, "crs-930-100")) + }) + + // Test a security scanner attack via path parameters + t.Run("path-params", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + // Send a security scanner attack (according to appsec rule id crs-913-120) + req, err := http.NewRequest("POST", srv.URL+"/path0.0/param0/path0.1/param1/path0.2/appscan_fingerprint/path0.3/param3", nil) + if err != nil { + panic(err) + } + res, err := srv.Client().Do(req) + require.NoError(t, err) + // Check that the handler was properly called + b, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, "Hello World!\n", string(b)) + require.Equal(t, http.StatusOK, res.StatusCode) + // The span should contain the security event + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + event := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "crs-913-120")) + require.True(t, strings.Contains(event, "myPathParam2")) + require.True(t, strings.Contains(event, "server.request.path_params")) + }) +} diff --git a/contrib/go-chi/chi/appsec.go b/contrib/go-chi/chi/appsec.go new file mode 100644 index 0000000000..0ff5149ba9 --- /dev/null +++ b/contrib/go-chi/chi/appsec.go @@ -0,0 +1,32 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package chi + +import ( + "net/http" + + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec" + + "github.com/go-chi/chi" +) + +func withAppsec(next http.Handler, r *http.Request, span tracer.Span) http.Handler { + rctx := chi.RouteContext(r.Context()) + if rctx == nil { + return httpsec.WrapHandler(next, span, nil) + } + var pathParams map[string]string + keys := rctx.URLParams.Keys + values := rctx.URLParams.Values + if len(keys) > 0 && len(keys) == len(values) { + pathParams = make(map[string]string, len(keys)) + for i, key := range keys { + pathParams[key] = values[i] + } + } + return httpsec.WrapHandler(next, span, pathParams) +} diff --git a/contrib/go-chi/chi/chi.go b/contrib/go-chi/chi/chi.go index c5866dd146..3f692b830d 100644 --- a/contrib/go-chi/chi/chi.go +++ b/contrib/go-chi/chi/chi.go @@ -15,6 +15,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" "github.com/go-chi/chi" @@ -52,6 +53,12 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler { span, ctx := tracer.StartSpanFromContext(r.Context(), "http.request", opts...) defer span.Finish() + if appsec.Enabled() { + next = withAppsec(next, r, span) + // Note that the following response writer passed to the handler + // implements the `interface { Status() int }` expected by httpsec. + } + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) // pass the span through the request context and serve the request to the next middleware diff --git a/contrib/go-chi/chi/chi_test.go b/contrib/go-chi/chi/chi_test.go index 2fcf68977e..a77cd8b6d5 100644 --- a/contrib/go-chi/chi/chi_test.go +++ b/contrib/go-chi/chi/chi_test.go @@ -7,6 +7,7 @@ package chi import ( "fmt" + "io/ioutil" "net/http" "net/http/httptest" "strconv" @@ -16,10 +17,12 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" "github.com/go-chi/chi" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestChildSpan(t *testing.T) { @@ -301,3 +304,79 @@ func TestIgnoreRequest(t *testing.T) { assert.Equal(t, shouldSkip, len(mt.FinishedSpans()) == 0) } } + +func TestAppSec(t *testing.T) { + appsec.Start() + defer appsec.Stop() + + if !appsec.Enabled() { + t.Skip("appsec disabled") + } + + // Start and trace an HTTP server with some testing routes + router := chi.NewRouter().With(Middleware()) + router.HandleFunc("/path0.0/{myPathParam0}/path0.1/{myPathParam1}/path0.2/{myPathParam2}/path0.3/*", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("Hello World!\n")) + require.NoError(t, err) + }) + router.HandleFunc("/*", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("Hello World!\n")) + require.NoError(t, err) + }) + + srv := httptest.NewServer(router) + defer srv.Close() + + // Test an LFI attack via path parameters + t.Run("request-uri", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + // Send an LFI attack (according to appsec rule id crs-930-100) + req, err := http.NewRequest("POST", srv.URL+"/../../../secret.txt", nil) + if err != nil { + panic(err) + } + res, err := srv.Client().Do(req) + require.NoError(t, err) + // Check that the server behaved as intended + require.Equal(t, http.StatusOK, res.StatusCode) + b, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, "Hello World!\n", string(b)) + // The span should contain the security event + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + + // The first 301 redirection should contain the attack via the request uri + event := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "server.request.uri.raw")) + require.True(t, strings.Contains(event, "crs-930-100")) + }) + + // Test a security scanner attack via path parameters + t.Run("path-params", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + // Send a security scanner attack (according to appsec rule id crs-913-120) + req, err := http.NewRequest("POST", srv.URL+"/path0.0/param0/path0.1/param1/path0.2/appscan_fingerprint/path0.3/param3", nil) + if err != nil { + panic(err) + } + res, err := srv.Client().Do(req) + require.NoError(t, err) + // Check that the handler was properly called + b, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, "Hello World!\n", string(b)) + require.Equal(t, http.StatusOK, res.StatusCode) + // The span should contain the security event + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + event := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "crs-913-120")) + require.True(t, strings.Contains(event, "myPathParam2")) + require.True(t, strings.Contains(event, "server.request.path_params")) + }) +}