diff --git a/backend/adapter_utils.go b/backend/adapter_utils.go deleted file mode 100644 index 269e3f128..000000000 --- a/backend/adapter_utils.go +++ /dev/null @@ -1,173 +0,0 @@ -package backend - -import ( - "context" - "errors" - "fmt" - "net/http" - "sync" - "time" - - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana-plugin-sdk-go/backend/tracing" - "github.com/prometheus/client_golang/prometheus" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -type handlerWrapperFunc func(ctx context.Context) (RequestStatus, error) - -func setupContext(ctx context.Context, endpoint Endpoint) context.Context { - ctx = WithEndpoint(ctx, endpoint) - ctx = propagateTenantIDIfPresent(ctx) - - return ctx -} - -func wrapHandler(ctx context.Context, pluginCtx PluginContext, next handlerWrapperFunc) error { - ctx = setupHandlerContext(ctx, pluginCtx) - wrapper := errorWrapper(next) - wrapper = logWrapper(wrapper) - wrapper = metricWrapper(wrapper) - wrapper = tracingWrapper(wrapper) - _, err := wrapper(ctx) - return err -} - -func setupHandlerContext(ctx context.Context, pluginCtx PluginContext) context.Context { - ctx = initErrorSource(ctx) - ctx = WithGrafanaConfig(ctx, pluginCtx.GrafanaConfig) - ctx = WithPluginContext(ctx, pluginCtx) - ctx = WithUser(ctx, pluginCtx.User) - ctx = withContextualLogAttributes(ctx, pluginCtx) - ctx = WithUserAgent(ctx, pluginCtx.UserAgent) - return ctx -} - -func errorWrapper(next handlerWrapperFunc) handlerWrapperFunc { - return func(ctx context.Context) (RequestStatus, error) { - status, err := next(ctx) - - if err != nil && IsDownstreamError(err) { - if innerErr := WithDownstreamErrorSource(ctx); innerErr != nil { - return RequestStatusError, fmt.Errorf("failed to set downstream status source: %w", errors.Join(innerErr, err)) - } - } - - return status, err - } -} - -var pluginRequestCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ - Namespace: "grafana_plugin", - Name: "request_total", - Help: "The total amount of plugin requests", -}, []string{"endpoint", "status", "status_source"}) - -var once = sync.Once{} - -func metricWrapper(next handlerWrapperFunc) handlerWrapperFunc { - once.Do(func() { - prometheus.MustRegister(pluginRequestCounter) - }) - - return func(ctx context.Context) (RequestStatus, error) { - endpoint := EndpointFromContext(ctx) - status, err := next(ctx) - - pluginRequestCounter.WithLabelValues(endpoint.String(), status.String(), string(ErrorSourceFromContext(ctx))).Inc() - - return status, err - } -} - -func tracingWrapper(next handlerWrapperFunc) handlerWrapperFunc { - return func(ctx context.Context) (RequestStatus, error) { - endpoint := EndpointFromContext(ctx) - pluginCtx := PluginConfigFromContext(ctx) - ctx, span := tracing.DefaultTracer().Start(ctx, fmt.Sprintf("sdk.%s", endpoint), trace.WithAttributes( - attribute.String("plugin_id", pluginCtx.PluginID), - attribute.Int64("org_id", pluginCtx.OrgID), - )) - defer span.End() - - if pluginCtx.DataSourceInstanceSettings != nil { - span.SetAttributes( - attribute.String("datasource_name", pluginCtx.DataSourceInstanceSettings.Name), - attribute.String("datasource_uid", pluginCtx.DataSourceInstanceSettings.UID), - ) - } - - if u := pluginCtx.User; u != nil { - span.SetAttributes(attribute.String("user", pluginCtx.User.Name)) - } - - status, err := next(ctx) - - span.SetAttributes( - attribute.String("request_status", status.String()), - attribute.String("status_source", string(ErrorSourceFromContext(ctx))), - ) - - if err != nil { - return status, tracing.Error(span, err) - } - - return status, err - } -} - -func logWrapper(next handlerWrapperFunc) handlerWrapperFunc { - return func(ctx context.Context) (RequestStatus, error) { - start := time.Now() - - ctxLogger := Logger.FromContext(ctx) - logFunc := ctxLogger.Debug - logFunc("Plugin Request Started") - - status, err := next(ctx) - - logParams := []any{ - "status", status.String(), - "duration", time.Since(start).String(), - } - - if err != nil { - logParams = append(logParams, "error", err) - } - - logParams = append(logParams, "statusSource", string(ErrorSourceFromContext(ctx))) - - if status > RequestStatusCancelled { - logFunc = ctxLogger.Error - } - - logFunc("Plugin Request Completed", logParams...) - - return status, err - } -} - -func withHeaderMiddleware(ctx context.Context, headers http.Header) context.Context { - if len(headers) > 0 { - ctx = httpclient.WithContextualMiddleware(ctx, - httpclient.MiddlewareFunc(func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { - if !opts.ForwardHTTPHeaders { - return next - } - - return httpclient.RoundTripperFunc(func(qreq *http.Request) (*http.Response, error) { - // Only set a header if it is not already set. - for k, v := range headers { - if qreq.Header.Get(k) == "" { - for _, vv := range v { - qreq.Header.Add(k, vv) - } - } - } - return next.RoundTrip(qreq) - }) - })) - } - return ctx -} diff --git a/backend/adapter_utils_test.go b/backend/adapter_utils_test.go deleted file mode 100644 index 0161a2e41..000000000 --- a/backend/adapter_utils_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package backend - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestErrorWrapper(t *testing.T) { - t.Run("No downstream error should not set downstream error source in context", func(t *testing.T) { - ctx := initErrorSource(context.Background()) - - actualErr := errors.New("BOOM") - wrapper := errorWrapper(func(_ context.Context) (RequestStatus, error) { - return RequestStatusError, actualErr - }) - status, err := wrapper(ctx) - require.ErrorIs(t, err, actualErr) - require.Equal(t, RequestStatusError, status) - require.Equal(t, DefaultErrorSource, ErrorSourceFromContext(ctx)) - }) - - t.Run("Downstream error should set downstream error source in context", func(t *testing.T) { - ctx := initErrorSource(context.Background()) - - actualErr := errors.New("BOOM") - wrapper := errorWrapper(func(_ context.Context) (RequestStatus, error) { - return RequestStatusError, DownstreamError(actualErr) - }) - status, err := wrapper(ctx) - require.ErrorIs(t, err, actualErr) - require.Equal(t, RequestStatusError, status) - require.Equal(t, ErrorSourceDownstream, ErrorSourceFromContext(ctx)) - }) -} diff --git a/backend/admission_adapter.go b/backend/admission_adapter.go index 700fd31b8..1e90c41f1 100644 --- a/backend/admission_adapter.go +++ b/backend/admission_adapter.go @@ -18,15 +18,8 @@ func newAdmissionSDKAdapter(handler AdmissionHandler) *admissionSDKAdapter { } func (a *admissionSDKAdapter) ValidateAdmission(ctx context.Context, req *pluginv2.AdmissionRequest) (*pluginv2.ValidationResponse, error) { - ctx = setupContext(ctx, EndpointValidateAdmission) parsedReq := FromProto().AdmissionRequest(req) - - var resp *ValidationResponse - err := wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - var innerErr error - resp, innerErr = a.handler.ValidateAdmission(ctx, parsedReq) - return RequestStatusFromError(innerErr), innerErr - }) + resp, err := a.handler.ValidateAdmission(ctx, parsedReq) if err != nil { return nil, err } @@ -35,15 +28,8 @@ func (a *admissionSDKAdapter) ValidateAdmission(ctx context.Context, req *plugin } func (a *admissionSDKAdapter) MutateAdmission(ctx context.Context, req *pluginv2.AdmissionRequest) (*pluginv2.MutationResponse, error) { - ctx = setupContext(ctx, EndpointMutateAdmission) parsedReq := FromProto().AdmissionRequest(req) - - var resp *MutationResponse - err := wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - var innerErr error - resp, innerErr = a.handler.MutateAdmission(ctx, parsedReq) - return RequestStatusFromError(innerErr), innerErr - }) + resp, err := a.handler.MutateAdmission(ctx, parsedReq) if err != nil { return nil, err } diff --git a/backend/common.go b/backend/common.go index 3c1bfc295..885d00486 100644 --- a/backend/common.go +++ b/backend/common.go @@ -297,13 +297,6 @@ func SecureJSONDataFromHTTPClientOptions(opts httpclient.Options) (res map[strin return secureJSONData } -func propagateTenantIDIfPresent(ctx context.Context) context.Context { - if tid, exists := tenant.IDFromIncomingGRPCContext(ctx); exists { - ctx = tenant.WithTenant(ctx, tid) - } - return ctx -} - func (s *DataSourceInstanceSettings) ProxyOptionsFromContext(ctx context.Context) (*proxy.Options, error) { cfg := GrafanaConfigFromContext(ctx) p, err := cfg.proxy() @@ -383,3 +376,5 @@ func (s *DataSourceInstanceSettings) ProxyClient(ctx context.Context) (proxy.Cli func WithTenant(ctx context.Context, tenantID string) context.Context { return tenant.WithTenant(ctx, tenantID) } + +type handlerWrapperFunc func(ctx context.Context) (RequestStatus, error) diff --git a/backend/conversion_adapter.go b/backend/conversion_adapter.go index ba4eec0fd..e0b54be4a 100644 --- a/backend/conversion_adapter.go +++ b/backend/conversion_adapter.go @@ -64,24 +64,23 @@ func (a *conversionSDKAdapter) convertQueryDataRequest(ctx context.Context, requ } func (a *conversionSDKAdapter) ConvertObjects(ctx context.Context, req *pluginv2.ConversionRequest) (*pluginv2.ConversionResponse, error) { - ctx = setupContext(ctx, EndpointConvertObjects) parsedReq := FromProto().ConversionRequest(req) - resp := &ConversionResponse{} - err := wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - var innerErr error - if a.queryConversionHandler != nil { - // Try to parse it as a query data request - reqs, err := parseAsQueryRequest(parsedReq) - if err == nil { - resp, innerErr = a.convertQueryDataRequest(ctx, reqs) - return RequestStatusFromError(innerErr), innerErr - } + var resp *ConversionResponse + var err error + if a.queryConversionHandler != nil { + // Try to parse it as a query data request + var reqs []*QueryDataRequest + reqs, err = parseAsQueryRequest(parsedReq) + if err == nil { + resp, err = a.convertQueryDataRequest(ctx, reqs) + } else { // The object cannot be parsed as a query data request, so we will try to convert it as a generic object + resp, err = a.handler.ConvertObjects(ctx, parsedReq) } - resp, innerErr = a.handler.ConvertObjects(ctx, parsedReq) - return RequestStatusFromError(innerErr), innerErr - }) + } else { + resp, err = a.handler.ConvertObjects(ctx, parsedReq) + } if err != nil { return nil, err } diff --git a/backend/data_adapter.go b/backend/data_adapter.go index 4b8071796..ec20ee908 100644 --- a/backend/data_adapter.go +++ b/backend/data_adapter.go @@ -3,9 +3,7 @@ package backend import ( "context" "errors" - "fmt" - "github.com/grafana/grafana-plugin-sdk-go/experimental/status" "github.com/grafana/grafana-plugin-sdk-go/genproto/pluginv2" ) @@ -21,78 +19,15 @@ func newDataSDKAdapter(handler QueryDataHandler) *dataSDKAdapter { } func (a *dataSDKAdapter) QueryData(ctx context.Context, req *pluginv2.QueryDataRequest) (*pluginv2.QueryDataResponse, error) { - ctx = setupContext(ctx, EndpointQueryData) parsedReq := FromProto().QueryDataRequest(req) - - var resp *QueryDataResponse - err := wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - ctx = withHeaderMiddleware(ctx, parsedReq.GetHTTPHeaders()) - var innerErr error - resp, innerErr = a.queryDataHandler.QueryData(ctx, parsedReq) - - requestStatus := RequestStatusFromQueryDataResponse(resp, innerErr) - if innerErr != nil { - return requestStatus, innerErr - } else if resp == nil { - return RequestStatusError, errors.New("both response and error are nil, but one must be provided") - } - ctxLogger := Logger.FromContext(ctx) - - // Set downstream status source in the context if there's at least one response with downstream status source, - // and if there's no plugin error - var hasPluginError, hasDownstreamError bool - for refID, r := range resp.Responses { - if r.Error == nil || status.IsCancelledError(r.Error) { - continue - } - - if !r.ErrorSource.IsValid() { - // if the error is a downstream error, set error source to downstream, otherwise plugin. - if IsDownstreamError(r.Error) { - r.ErrorSource = ErrorSourceDownstream - } else { - r.ErrorSource = ErrorSourcePlugin - } - resp.Responses[refID] = r - } - - if !r.Status.IsValid() { - r.Status = statusFromError(r.Error) - resp.Responses[refID] = r - } - - if r.ErrorSource == ErrorSourceDownstream { - hasDownstreamError = true - } else { - hasPluginError = true - } - - logParams := []any{ - "refID", refID, - "status", int(r.Status), - "error", r.Error, - "statusSource", string(r.ErrorSource), - } - ctxLogger.Error("Partial data response error", logParams...) - } - - // A plugin error has higher priority than a downstream error, - // so set to downstream only if there's no plugin error - if hasPluginError { - if err := WithErrorSource(ctx, ErrorSourcePlugin); err != nil { - return RequestStatusError, fmt.Errorf("failed to set plugin status source: %w", errors.Join(innerErr, err)) - } - } else if hasDownstreamError { - if err := WithDownstreamErrorSource(ctx); err != nil { - return RequestStatusError, fmt.Errorf("failed to set downstream status source: %w", errors.Join(innerErr, err)) - } - } - - return requestStatus, nil - }) + resp, err := a.queryDataHandler.QueryData(ctx, parsedReq) if err != nil { return nil, err } + if resp == nil { + return nil, errors.New("both response and error are nil, but one must be provided") + } + return ToProto().QueryDataResponse(resp) } diff --git a/backend/data_adapter_test.go b/backend/data_adapter_test.go index b61468d63..e4addcfa1 100644 --- a/backend/data_adapter_test.go +++ b/backend/data_adapter_test.go @@ -3,7 +3,6 @@ package backend import ( "bytes" "context" - "errors" "fmt" "io" "net/http" @@ -69,8 +68,13 @@ func TestQueryData(t *testing.T) { t.Run("When forward HTTP headers enabled should forward headers", func(t *testing.T) { ctx := context.Background() handler := newFakeDataHandlerWithOAuth() - adapter := newDataSDKAdapter(handler) - _, err := adapter.QueryData(ctx, &pluginv2.QueryDataRequest{ + handlers := Handlers{ + QueryDataHandler: handler, + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newHeaderMiddleware()) + require.NoError(t, err) + adapter := newDataSDKAdapter(handlerWithMw) + _, err = adapter.QueryData(ctx, &pluginv2.QueryDataRequest{ Headers: map[string]string{ "Authorization": "Bearer 123", }, @@ -95,8 +99,13 @@ func TestQueryData(t *testing.T) { t.Run("When forward HTTP headers disable should not forward headers", func(t *testing.T) { ctx := context.Background() handler := newFakeDataHandlerWithOAuth() - adapter := newDataSDKAdapter(handler) - _, err := adapter.QueryData(ctx, &pluginv2.QueryDataRequest{ + handlers := Handlers{ + QueryDataHandler: handler, + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newHeaderMiddleware()) + require.NoError(t, err) + adapter := newDataSDKAdapter(handlerWithMw) + _, err = adapter.QueryData(ctx, &pluginv2.QueryDataRequest{ Headers: map[string]string{ "Authorization": "Bearer 123", }, @@ -119,135 +128,24 @@ func TestQueryData(t *testing.T) { t.Run("When tenant information is attached to incoming context, it is propagated from adapter to handler", func(t *testing.T) { tid := "123456" - a := newDataSDKAdapter(QueryDataHandlerFunc(func(ctx context.Context, _ *QueryDataRequest) (*QueryDataResponse, error) { - require.Equal(t, tid, tenant.IDFromContext(ctx)) - return NewQueryDataResponse(), nil - })) + handlers := Handlers{ + QueryDataHandler: QueryDataHandlerFunc(func(ctx context.Context, _ *QueryDataRequest) (*QueryDataResponse, error) { + require.Equal(t, tid, tenant.IDFromContext(ctx)) + return NewQueryDataResponse(), nil + }), + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware()) + require.NoError(t, err) + a := newDataSDKAdapter(handlerWithMw) ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ tenant.CtxKey: tid, })) - _, err := a.QueryData(ctx, &pluginv2.QueryDataRequest{ - PluginContext: &pluginv2.PluginContext{}, - }) - require.NoError(t, err) - }) - - t.Run("TestQueryDataResponse", func(t *testing.T) { - someErr := errors.New("oops") - - for _, tc := range []struct { - name string - queryDataResponse *QueryDataResponse - expErrorSource ErrorSource - expError bool - }{ - { - name: `single downstream error should be "downstream" error source`, - queryDataResponse: &QueryDataResponse{ - Responses: map[string]DataResponse{ - "A": {Error: someErr, ErrorSource: ErrorSourceDownstream}, - }, - }, - expErrorSource: ErrorSourceDownstream, - }, - { - name: `single plugin error should be "plugin" error source`, - queryDataResponse: &QueryDataResponse{ - Responses: map[string]DataResponse{ - "A": {Error: someErr, ErrorSource: ErrorSourcePlugin}, - }, - }, - expErrorSource: ErrorSourcePlugin, - }, - { - name: `multiple downstream errors should be "downstream" error source`, - queryDataResponse: &QueryDataResponse{ - Responses: map[string]DataResponse{ - "A": {Error: someErr, ErrorSource: ErrorSourceDownstream}, - "B": {Error: someErr, ErrorSource: ErrorSourceDownstream}, - }, - }, - expErrorSource: ErrorSourceDownstream, - }, - { - name: `single plugin error mixed with downstream errors should be "plugin" error source`, - queryDataResponse: &QueryDataResponse{ - Responses: map[string]DataResponse{ - "A": {Error: someErr, ErrorSource: ErrorSourceDownstream}, - "B": {Error: someErr, ErrorSource: ErrorSourcePlugin}, - "C": {Error: someErr, ErrorSource: ErrorSourceDownstream}, - }, - }, - expErrorSource: ErrorSourcePlugin, - }, - { - name: `single downstream error without error source should be "downstream" error source`, - queryDataResponse: &QueryDataResponse{ - Responses: map[string]DataResponse{ - "A": {Error: DownstreamErrorf("boom")}, - }, - }, - expErrorSource: ErrorSourceDownstream, - }, - { - name: `multiple downstream error without error source and single plugin error should be "plugin" error source`, - queryDataResponse: &QueryDataResponse{ - Responses: map[string]DataResponse{ - "A": {Error: DownstreamErrorf("boom")}, - "B": {Error: someErr}, - "C": {Error: DownstreamErrorf("boom")}, - }, - }, - expErrorSource: ErrorSourcePlugin, - }, - { - name: "nil queryDataResponse and nil error should throw error", - queryDataResponse: nil, - expErrorSource: ErrorSourcePlugin, - expError: true, - }, - } { - t.Run(tc.name, func(t *testing.T) { - var actualCtx context.Context - a := newDataSDKAdapter(QueryDataHandlerFunc(func(ctx context.Context, _ *QueryDataRequest) (*QueryDataResponse, error) { - actualCtx = ctx - return tc.queryDataResponse, nil - })) - _, err := a.QueryData(context.Background(), &pluginv2.QueryDataRequest{ - PluginContext: &pluginv2.PluginContext{}, - }) - if tc.expError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - - ss := ErrorSourceFromContext(actualCtx) - require.Equal(t, tc.expErrorSource, ss) - }) - } - }) - - t.Run("QueryData response without valid error source error should set error source", func(t *testing.T) { - someErr := errors.New("oops") - downstreamErr := DownstreamError(someErr) - a := newDataSDKAdapter(QueryDataHandlerFunc(func(_ context.Context, _ *QueryDataRequest) (*QueryDataResponse, error) { - return &QueryDataResponse{ - Responses: map[string]DataResponse{ - "A": {Error: someErr}, - "B": {Error: downstreamErr}, - }, - }, nil - })) - resp, err := a.QueryData(context.Background(), &pluginv2.QueryDataRequest{ + _, err = a.QueryData(ctx, &pluginv2.QueryDataRequest{ PluginContext: &pluginv2.PluginContext{}, }) - require.NoError(t, err) - require.Equal(t, ErrorSourcePlugin, ErrorSource(resp.Responses["A"].ErrorSource)) - require.Equal(t, ErrorSourceDownstream, ErrorSource(resp.Responses["B"].ErrorSource)) }) } diff --git a/backend/diagnostics_adapter.go b/backend/diagnostics_adapter.go index 97c3310ea..e6d0164cd 100644 --- a/backend/diagnostics_adapter.go +++ b/backend/diagnostics_adapter.go @@ -46,16 +46,8 @@ func (a *diagnosticsSDKAdapter) CollectMetrics(_ context.Context, _ *pluginv2.Co func (a *diagnosticsSDKAdapter) CheckHealth(ctx context.Context, protoReq *pluginv2.CheckHealthRequest) (*pluginv2.CheckHealthResponse, error) { if a.checkHealthHandler != nil { - ctx = setupContext(ctx, EndpointCheckHealth) parsedReq := FromProto().CheckHealthRequest(protoReq) - - var resp *CheckHealthResult - err := wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - ctx = withHeaderMiddleware(ctx, parsedReq.GetHTTPHeaders()) - var innerErr error - resp, innerErr = a.checkHealthHandler.CheckHealth(ctx, parsedReq) - return RequestStatusFromError(innerErr), innerErr - }) + resp, err := a.checkHealthHandler.CheckHealth(ctx, parsedReq) if err != nil { return nil, err } diff --git a/backend/diagnostics_adapter_test.go b/backend/diagnostics_adapter_test.go index f6ba95dfe..e1b2aecc8 100644 --- a/backend/diagnostics_adapter_test.go +++ b/backend/diagnostics_adapter_test.go @@ -112,9 +112,12 @@ func TestCheckHealth(t *testing.T) { }) t.Run("When headers are present", func(t *testing.T) { - adapter := &diagnosticsSDKAdapter{ - checkHealthHandler: &testCheckHealthHandlerWithHeaders{}, + handlers := Handlers{ + CheckHealthHandler: &testCheckHealthHandlerWithHeaders{}, } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newHeaderMiddleware()) + require.NoError(t, err) + adapter := newDiagnosticsSDKAdapter(nil, handlerWithMw) res, err := adapter.CheckHealth(context.Background(), &pluginv2.CheckHealthRequest{ Headers: map[string]string{ "Authorization": "Bearer 123", @@ -128,15 +131,20 @@ func TestCheckHealth(t *testing.T) { t.Run("When tenant information is attached to incoming context, it is propagated from adapter to handler", func(t *testing.T) { tid := "123456" - a := newDiagnosticsSDKAdapter(nil, CheckHealthHandlerFunc(func(ctx context.Context, _ *CheckHealthRequest) (*CheckHealthResult, error) { - require.Equal(t, tid, tenant.IDFromContext(ctx)) - return &CheckHealthResult{}, nil - })) + handlers := Handlers{ + CheckHealthHandler: CheckHealthHandlerFunc(func(ctx context.Context, _ *CheckHealthRequest) (*CheckHealthResult, error) { + require.Equal(t, tid, tenant.IDFromContext(ctx)) + return &CheckHealthResult{}, nil + }), + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware()) + require.NoError(t, err) + a := newDiagnosticsSDKAdapter(nil, handlerWithMw) ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ tenant.CtxKey: tid, })) - _, err := a.CheckHealth(ctx, &pluginv2.CheckHealthRequest{ + _, err = a.CheckHealth(ctx, &pluginv2.CheckHealthRequest{ PluginContext: &pluginv2.PluginContext{}, }) require.NoError(t, err) diff --git a/backend/handler.go b/backend/handler.go index d79b3f630..eb7c21558 100644 --- a/backend/handler.go +++ b/backend/handler.go @@ -68,3 +68,16 @@ func (m *BaseHandler) MutateAdmission(ctx context.Context, req *AdmissionRequest func (m *BaseHandler) ConvertObjects(ctx context.Context, req *ConversionRequest) (*ConversionResponse, error) { return m.next.ConvertObjects(ctx, req) } + +// Handlers implements Handler. +type Handlers struct { + QueryDataHandler + CheckHealthHandler + CallResourceHandler + CollectMetricsHandler + StreamHandler + AdmissionHandler + ConversionHandler +} + +var _ Handler = &Handlers{} diff --git a/backend/http_headers.go b/backend/http_headers.go index 1e416d5fc..2f0bcaddc 100644 --- a/backend/http_headers.go +++ b/backend/http_headers.go @@ -1,10 +1,13 @@ package backend import ( + "context" "fmt" "net/http" "net/textproto" "strings" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" ) const ( @@ -94,3 +97,70 @@ func deleteHTTPHeaderInStringMap(headers map[string]string, key string) { } } } + +// newHeaderMiddleware creates a new handler middleware that forwards HTTP headers to outgoing +// HTTP request sent using the HTTP client from the httpclient package. +func newHeaderMiddleware() HandlerMiddleware { + return HandlerMiddlewareFunc(func(next Handler) Handler { + return &headerMiddleware{ + BaseHandler: NewBaseHandler(next), + } + }) +} + +// headerMiddleware a handler middleware that forwards HTTP headers to outgoing +// HTTP request sent using the HTTP client from the httpclient package. +type headerMiddleware struct { + BaseHandler +} + +func (m headerMiddleware) applyHeaders(ctx context.Context, headers http.Header) context.Context { + if len(headers) > 0 { + ctx = httpclient.WithContextualMiddleware(ctx, + httpclient.MiddlewareFunc(func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { + if !opts.ForwardHTTPHeaders { + return next + } + + return httpclient.RoundTripperFunc(func(qreq *http.Request) (*http.Response, error) { + // Only set a header if it is not already set. + for k, v := range headers { + if qreq.Header.Get(k) == "" { + for _, vv := range v { + qreq.Header.Add(k, vv) + } + } + } + return next.RoundTrip(qreq) + }) + })) + } + return ctx +} + +func (m *headerMiddleware) QueryData(ctx context.Context, req *QueryDataRequest) (*QueryDataResponse, error) { + if req == nil { + return m.BaseHandler.QueryData(ctx, req) + } + + ctx = m.applyHeaders(ctx, req.GetHTTPHeaders()) + return m.BaseHandler.QueryData(ctx, req) +} + +func (m *headerMiddleware) CallResource(ctx context.Context, req *CallResourceRequest, sender CallResourceResponseSender) error { + if req == nil { + return m.BaseHandler.CallResource(ctx, req, sender) + } + + ctx = m.applyHeaders(ctx, req.GetHTTPHeaders()) + return m.BaseHandler.CallResource(ctx, req, sender) +} + +func (m *headerMiddleware) CheckHealth(ctx context.Context, req *CheckHealthRequest) (*CheckHealthResult, error) { + if req == nil { + return m.BaseHandler.CheckHealth(ctx, req) + } + + ctx = m.applyHeaders(ctx, req.GetHTTPHeaders()) + return m.BaseHandler.CheckHealth(ctx, req) +} diff --git a/backend/log.go b/backend/log.go index ffef48ccd..5a5a88df0 100644 --- a/backend/log.go +++ b/backend/log.go @@ -9,7 +9,7 @@ import ( ) // Logger is the default logger instance. This can be used directly to log from -// your plugin to grafana-server with calls like backend.Logger.Debug(...). +// your plugin to grafana-server with calls like Logger.Debug(...). var Logger = log.DefaultLogger // NewLoggerWith creates a new logger with the given arguments. @@ -46,3 +46,71 @@ func withContextualLogAttributes(ctx context.Context, pCtx PluginContext) contex ctx = log.WithContextualAttributes(ctx, args) return ctx } + +// newContextualLoggerMiddleware creates a new handler middleware that setup contextual logging. +func newContextualLoggerMiddleware() HandlerMiddleware { + return HandlerMiddlewareFunc(func(next Handler) Handler { + return &contextualLoggerMiddleware{ + BaseHandler: NewBaseHandler(next), + } + }) +} + +// contextualLoggerMiddleware a handler middleware that setup contextual logging. +type contextualLoggerMiddleware struct { + BaseHandler +} + +func (m *contextualLoggerMiddleware) setup(ctx context.Context, pCtx PluginContext) context.Context { + return withContextualLogAttributes(ctx, pCtx) +} + +func (m *contextualLoggerMiddleware) QueryData(ctx context.Context, req *QueryDataRequest) (*QueryDataResponse, error) { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.QueryData(ctx, req) +} + +func (m *contextualLoggerMiddleware) CallResource(ctx context.Context, req *CallResourceRequest, sender CallResourceResponseSender) error { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.CallResource(ctx, req, sender) +} + +func (m *contextualLoggerMiddleware) CheckHealth(ctx context.Context, req *CheckHealthRequest) (*CheckHealthResult, error) { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.CheckHealth(ctx, req) +} + +func (m *contextualLoggerMiddleware) CollectMetrics(ctx context.Context, req *CollectMetricsRequest) (*CollectMetricsResult, error) { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.CollectMetrics(ctx, req) +} + +func (m *contextualLoggerMiddleware) SubscribeStream(ctx context.Context, req *SubscribeStreamRequest) (*SubscribeStreamResponse, error) { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.SubscribeStream(ctx, req) +} + +func (m *contextualLoggerMiddleware) PublishStream(ctx context.Context, req *PublishStreamRequest) (*PublishStreamResponse, error) { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.PublishStream(ctx, req) +} + +func (m *contextualLoggerMiddleware) RunStream(ctx context.Context, req *RunStreamRequest, sender *StreamSender) error { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.RunStream(ctx, req, sender) +} + +func (m *contextualLoggerMiddleware) ValidateAdmission(ctx context.Context, req *AdmissionRequest) (*ValidationResponse, error) { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.ValidateAdmission(ctx, req) +} + +func (m *contextualLoggerMiddleware) MutateAdmission(ctx context.Context, req *AdmissionRequest) (*MutationResponse, error) { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.MutateAdmission(ctx, req) +} + +func (m *contextualLoggerMiddleware) ConvertObjects(ctx context.Context, req *ConversionRequest) (*ConversionResponse, error) { + ctx = m.setup(ctx, req.PluginContext) + return m.BaseHandler.ConvertObjects(ctx, req) +} diff --git a/backend/log_test.go b/backend/log_test.go index 9639a1fc3..66c36e4c2 100644 --- a/backend/log_test.go +++ b/backend/log_test.go @@ -36,12 +36,18 @@ func TestContextualLogger(t *testing.T) { pCtx := &pluginv2.PluginContext{PluginId: pluginID} t.Run("DataSDKAdapter", func(t *testing.T) { run := make(chan struct{}, 1) - a := newDataSDKAdapter(QueryDataHandlerFunc(func(ctx context.Context, _ *QueryDataRequest) (*QueryDataResponse, error) { + handler := QueryDataHandlerFunc(func(ctx context.Context, _ *QueryDataRequest) (*QueryDataResponse, error) { checkCtxLogger(ctx, t, map[string]any{"endpoint": "queryData", "pluginId": pluginID}) run <- struct{}{} return NewQueryDataResponse(), nil - })) - _, err := a.QueryData(context.Background(), &pluginv2.QueryDataRequest{ + }) + handlers := Handlers{ + QueryDataHandler: handler, + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware(), newContextualLoggerMiddleware()) + require.NoError(t, err) + a := newDataSDKAdapter(handlerWithMw) + _, err = a.QueryData(context.Background(), &pluginv2.QueryDataRequest{ PluginContext: pCtx, }) require.NoError(t, err) @@ -50,12 +56,18 @@ func TestContextualLogger(t *testing.T) { t.Run("DiagnosticsSDKAdapter", func(t *testing.T) { run := make(chan struct{}, 1) - a := newDiagnosticsSDKAdapter(prometheus.DefaultGatherer, CheckHealthHandlerFunc(func(ctx context.Context, _ *CheckHealthRequest) (*CheckHealthResult, error) { + handler := CheckHealthHandlerFunc(func(ctx context.Context, _ *CheckHealthRequest) (*CheckHealthResult, error) { checkCtxLogger(ctx, t, map[string]any{"endpoint": "checkHealth", "pluginId": pluginID}) run <- struct{}{} return &CheckHealthResult{}, nil - })) - _, err := a.CheckHealth(context.Background(), &pluginv2.CheckHealthRequest{ + }) + handlers := Handlers{ + CheckHealthHandler: handler, + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware(), newContextualLoggerMiddleware()) + require.NoError(t, err) + a := newDiagnosticsSDKAdapter(prometheus.DefaultGatherer, handlerWithMw) + _, err = a.CheckHealth(context.Background(), &pluginv2.CheckHealthRequest{ PluginContext: pCtx, }) require.NoError(t, err) @@ -64,12 +76,18 @@ func TestContextualLogger(t *testing.T) { t.Run("ResourceSDKAdapter", func(t *testing.T) { run := make(chan struct{}, 1) - a := newResourceSDKAdapter(CallResourceHandlerFunc(func(ctx context.Context, _ *CallResourceRequest, _ CallResourceResponseSender) error { + handler := CallResourceHandlerFunc(func(ctx context.Context, _ *CallResourceRequest, _ CallResourceResponseSender) error { checkCtxLogger(ctx, t, map[string]any{"endpoint": "callResource", "pluginId": pluginID}) run <- struct{}{} return nil - })) - err := a.CallResource(&pluginv2.CallResourceRequest{ + }) + handlers := Handlers{ + CallResourceHandler: handler, + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware(), newContextualLoggerMiddleware()) + require.NoError(t, err) + a := newResourceSDKAdapter(handlerWithMw) + err = a.CallResource(&pluginv2.CallResourceRequest{ PluginContext: pCtx, }, newTestCallResourceServer()) require.NoError(t, err) @@ -80,23 +98,28 @@ func TestContextualLogger(t *testing.T) { subscribeStreamRun := make(chan struct{}, 1) publishStreamRun := make(chan struct{}, 1) runStreamRun := make(chan struct{}, 1) - a := newStreamSDKAdapter(&streamAdapter{ - subscribeStreamFunc: func(ctx context.Context, _ *SubscribeStreamRequest) (*SubscribeStreamResponse, error) { - checkCtxLogger(ctx, t, map[string]any{"endpoint": "subscribeStream", "pluginId": pluginID}) - subscribeStreamRun <- struct{}{} - return &SubscribeStreamResponse{}, nil - }, - publishStreamFunc: func(ctx context.Context, _ *PublishStreamRequest) (*PublishStreamResponse, error) { - checkCtxLogger(ctx, t, map[string]any{"endpoint": "publishStream", "pluginId": pluginID}) - publishStreamRun <- struct{}{} - return &PublishStreamResponse{}, nil - }, - runStreamFunc: func(ctx context.Context, _ *RunStreamRequest, _ *StreamSender) error { - checkCtxLogger(ctx, t, map[string]any{"endpoint": "runStream", "pluginId": pluginID}) - runStreamRun <- struct{}{} - return nil + handlers := Handlers{ + StreamHandler: &streamAdapter{ + subscribeStreamFunc: func(ctx context.Context, _ *SubscribeStreamRequest) (*SubscribeStreamResponse, error) { + checkCtxLogger(ctx, t, map[string]any{"endpoint": "subscribeStream", "pluginId": pluginID}) + subscribeStreamRun <- struct{}{} + return &SubscribeStreamResponse{}, nil + }, + publishStreamFunc: func(ctx context.Context, _ *PublishStreamRequest) (*PublishStreamResponse, error) { + checkCtxLogger(ctx, t, map[string]any{"endpoint": "publishStream", "pluginId": pluginID}) + publishStreamRun <- struct{}{} + return &PublishStreamResponse{}, nil + }, + runStreamFunc: func(ctx context.Context, _ *RunStreamRequest, _ *StreamSender) error { + checkCtxLogger(ctx, t, map[string]any{"endpoint": "runStream", "pluginId": pluginID}) + runStreamRun <- struct{}{} + return nil + }, }, - }) + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware(), newContextualLoggerMiddleware()) + require.NoError(t, err) + a := newStreamSDKAdapter(handlerWithMw) t.Run("SubscribeStream", func(t *testing.T) { _, err := a.SubscribeStream(context.Background(), &pluginv2.SubscribeStreamRequest{ diff --git a/backend/resource_adapter.go b/backend/resource_adapter.go index d33158482..86d74b1df 100644 --- a/backend/resource_adapter.go +++ b/backend/resource_adapter.go @@ -1,7 +1,6 @@ package backend import ( - "context" "net/http" "github.com/grafana/grafana-plugin-sdk-go/genproto/pluginv2" @@ -30,12 +29,6 @@ func (a *resourceSDKAdapter) CallResource(protoReq *pluginv2.CallResourceRequest }) ctx := protoSrv.Context() - ctx = setupContext(ctx, EndpointCallResource) parsedReq := FromProto().CallResourceRequest(protoReq) - - return wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - ctx = withHeaderMiddleware(ctx, parsedReq.GetHTTPHeaders()) - err := a.callResourceHandler.CallResource(ctx, parsedReq, fn) - return RequestStatusFromError(err), err - }) + return a.callResourceHandler.CallResource(ctx, parsedReq, fn) } diff --git a/backend/resource_adapter_test.go b/backend/resource_adapter_test.go index 42c00464d..f3c7d9a74 100644 --- a/backend/resource_adapter_test.go +++ b/backend/resource_adapter_test.go @@ -144,8 +144,13 @@ func TestCallResource(t *testing.T) { t.Run("When oauth headers are set it should set the middleware to set headers", func(t *testing.T) { testSender := newTestCallResourceServer() - adapter := newResourceSDKAdapter(&testCallResourceWithHeaders{}) - err := adapter.CallResource(&pluginv2.CallResourceRequest{ + handlers := Handlers{ + CallResourceHandler: &testCallResourceWithHeaders{}, + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newHeaderMiddleware()) + require.NoError(t, err) + adapter := newResourceSDKAdapter(handlerWithMw) + err = adapter.CallResource(&pluginv2.CallResourceRequest{ PluginContext: &pluginv2.PluginContext{}, Headers: map[string]*pluginv2.StringList{ "Authorization": { @@ -158,17 +163,22 @@ func TestCallResource(t *testing.T) { t.Run("When tenant information is attached to incoming context, it is propagated from adapter to handler", func(t *testing.T) { tid := "123456" - a := newResourceSDKAdapter(CallResourceHandlerFunc(func(ctx context.Context, _ *CallResourceRequest, _ CallResourceResponseSender) error { - require.Equal(t, tid, tenant.IDFromContext(ctx)) - return nil - })) + handlers := Handlers{ + CallResourceHandler: CallResourceHandlerFunc(func(ctx context.Context, _ *CallResourceRequest, _ CallResourceResponseSender) error { + require.Equal(t, tid, tenant.IDFromContext(ctx)) + return nil + }), + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware()) + require.NoError(t, err) + a := newResourceSDKAdapter(handlerWithMw) testSender := newTestCallResourceServer() testSender.WithContext(metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ tenant.CtxKey: tid, }))) - err := a.CallResource(&pluginv2.CallResourceRequest{ + err = a.CallResource(&pluginv2.CallResourceRequest{ PluginContext: &pluginv2.PluginContext{}, }, testSender) require.NoError(t, err) diff --git a/backend/serve.go b/backend/serve.go index 6480820ee..c9f3c2e56 100644 --- a/backend/serve.go +++ b/backend/serve.go @@ -23,6 +23,7 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend/grpcplugin" "github.com/grafana/grafana-plugin-sdk-go/backend/log" + "github.com/grafana/grafana-plugin-sdk-go/backend/tracing" "github.com/grafana/grafana-plugin-sdk-go/genproto/pluginv2" "github.com/grafana/grafana-plugin-sdk-go/internal/standalone" "github.com/grafana/grafana-plugin-sdk-go/internal/tracerprovider" @@ -71,33 +72,54 @@ type ServeOpts struct { // GRPCSettings settings for gPRC. GRPCSettings GRPCSettings + + // HandlerMiddlewares list of handler middlewares to decorate handlers with. + HandlerMiddlewares []HandlerMiddleware +} + +func (opts ServeOpts) HandlerWithMiddlewares() (Handler, error) { + handlers := Handlers{ + CheckHealthHandler: opts.CheckHealthHandler, + CallResourceHandler: opts.CallResourceHandler, + QueryDataHandler: opts.QueryDataHandler, + StreamHandler: opts.StreamHandler, + AdmissionHandler: opts.AdmissionHandler, + ConversionHandler: opts.ConversionHandler, + } + + return HandlerFromMiddlewares(handlers, opts.HandlerMiddlewares...) } -func GRPCServeOpts(opts ServeOpts) grpcplugin.ServeOpts { +func GRPCServeOpts(opts ServeOpts) (grpcplugin.ServeOpts, error) { + handler, err := opts.HandlerWithMiddlewares() + if err != nil { + return grpcplugin.ServeOpts{}, fmt.Errorf("failed to create handler with middlewares: %w", err) + } + pluginOpts := grpcplugin.ServeOpts{ - DiagnosticsServer: newDiagnosticsSDKAdapter(prometheus.DefaultGatherer, opts.CheckHealthHandler), + DiagnosticsServer: newDiagnosticsSDKAdapter(prometheus.DefaultGatherer, handler), } if opts.CallResourceHandler != nil { - pluginOpts.ResourceServer = newResourceSDKAdapter(opts.CallResourceHandler) + pluginOpts.ResourceServer = newResourceSDKAdapter(handler) } if opts.QueryDataHandler != nil { - pluginOpts.DataServer = newDataSDKAdapter(opts.QueryDataHandler) + pluginOpts.DataServer = newDataSDKAdapter(handler) } if opts.StreamHandler != nil { - pluginOpts.StreamServer = newStreamSDKAdapter(opts.StreamHandler) + pluginOpts.StreamServer = newStreamSDKAdapter(handler) } if opts.AdmissionHandler != nil { - pluginOpts.AdmissionServer = newAdmissionSDKAdapter(opts.AdmissionHandler) + pluginOpts.AdmissionServer = newAdmissionSDKAdapter(handler) } if opts.ConversionHandler != nil || opts.QueryConversionHandler != nil { - pluginOpts.ConversionServer = newConversionSDKAdapter(opts.ConversionHandler, opts.QueryConversionHandler) + pluginOpts.ConversionServer = newConversionSDKAdapter(handler, opts.QueryConversionHandler) } - return pluginOpts + return pluginOpts, nil } // grpcServerOptions returns a new []grpc.ServerOption that can be passed to grpc.NewServer. @@ -147,7 +169,11 @@ func defaultGRPCMiddlewares(opts ServeOpts) []grpc.ServerOption { // Deprecated: Serve exists for historical compatibility // and might be removed in a future version. Please migrate to use [Manage] instead. func Serve(opts ServeOpts) error { - pluginOpts := GRPCServeOpts(opts) + pluginOpts, err := GRPCServeOpts(opts) + if err != nil { + return err + } + pluginOpts.GRPCServer = func(customOptions []grpc.ServerOption) *grpc.Server { return grpc.NewServer(grpcServerOptions(opts, customOptions...)...) } @@ -198,7 +224,11 @@ func GracefulStandaloneServe(dsopts ServeOpts, info standalone.ServerSettings) e standalone.FindAndKillCurrentPlugin(info.Dir) // Start GRPC server - pluginOpts := GRPCServeOpts(dsopts) + pluginOpts, err := GRPCServeOpts(dsopts) + if err != nil { + return err + } + if pluginOpts.GRPCServer == nil { pluginOpts.GRPCServer = func(customOptions []grpc.ServerOption) *grpc.Server { return grpc.NewServer(grpcServerOptions(dsopts, customOptions...)...) @@ -291,6 +321,13 @@ func Manage(pluginID string, serveOpts ServeOpts) error { } }() + if serveOpts.HandlerMiddlewares == nil { + serveOpts.HandlerMiddlewares = make([]HandlerMiddleware, 0) + } + + middlewares := defaultHandlerMiddlewares() + serveOpts.HandlerMiddlewares = append(middlewares, serveOpts.HandlerMiddlewares...) + if s, enabled := standalone.ServerModeEnabled(pluginID); enabled { // Run the standalone GRPC server return GracefulStandaloneServe(serveOpts, s) @@ -310,7 +347,18 @@ func Manage(pluginID string, serveOpts ServeOpts) error { // TestStandaloneServe starts a gRPC server that is not managed by hashicorp. // The function returns the gRPC server which should be closed by the consumer. func TestStandaloneServe(opts ServeOpts, address string) (*grpc.Server, error) { - pluginOpts := GRPCServeOpts(opts) + if opts.HandlerMiddlewares == nil { + opts.HandlerMiddlewares = make([]HandlerMiddleware, 0) + } + + middlewares := defaultHandlerMiddlewares() + opts.HandlerMiddlewares = append(middlewares, opts.HandlerMiddlewares...) + + pluginOpts, err := GRPCServeOpts(opts) + if err != nil { + return nil, err + } + if pluginOpts.GRPCServer == nil { pluginOpts.GRPCServer = func(customOptions []grpc.ServerOption) *grpc.Server { return grpc.NewServer(grpcServerOptions(opts, customOptions...)...) @@ -373,3 +421,15 @@ func TestStandaloneServe(opts ServeOpts, address string) (*grpc.Server, error) { return server, nil } + +func defaultHandlerMiddlewares() []HandlerMiddleware { + return []HandlerMiddleware{ + newTenantIDMiddleware(), + newContextualLoggerMiddleware(), + NewTracingMiddleware(tracing.DefaultTracer()), + NewMetricsMiddleware(prometheus.DefaultRegisterer, "grafana", false), + NewLoggerMiddleware(Logger, nil), + newHeaderMiddleware(), + NewErrorSourceMiddleware(), + } +} diff --git a/backend/stream_adapter.go b/backend/stream_adapter.go index d67c68cda..6a0fa1c42 100644 --- a/backend/stream_adapter.go +++ b/backend/stream_adapter.go @@ -25,15 +25,8 @@ func (a *streamSDKAdapter) SubscribeStream(ctx context.Context, protoReq *plugin return nil, status.Error(codes.Unimplemented, "not implemented") } - ctx = setupContext(ctx, EndpointSubscribeStream) parsedReq := FromProto().SubscribeStreamRequest(protoReq) - - var resp *SubscribeStreamResponse - err := wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - var innerErr error - resp, innerErr = a.streamHandler.SubscribeStream(ctx, parsedReq) - return RequestStatusFromError(innerErr), innerErr - }) + resp, err := a.streamHandler.SubscribeStream(ctx, parsedReq) if err != nil { return nil, err } @@ -46,15 +39,8 @@ func (a *streamSDKAdapter) PublishStream(ctx context.Context, protoReq *pluginv2 return nil, status.Error(codes.Unimplemented, "not implemented") } - ctx = setupContext(ctx, EndpointPublishStream) parsedReq := FromProto().PublishStreamRequest(protoReq) - - var resp *PublishStreamResponse - err := wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - var innerErr error - resp, innerErr = a.streamHandler.PublishStream(ctx, parsedReq) - return RequestStatusFromError(innerErr), innerErr - }) + resp, err := a.streamHandler.PublishStream(ctx, parsedReq) if err != nil { return nil, err } @@ -75,12 +61,7 @@ func (a *streamSDKAdapter) RunStream(protoReq *pluginv2.RunStreamRequest, protoS return status.Error(codes.Unimplemented, "not implemented") } ctx := protoSrv.Context() - ctx = setupContext(ctx, EndpointRunStream) parsedReq := FromProto().RunStreamRequest(protoReq) - - return wrapHandler(ctx, parsedReq.PluginContext, func(ctx context.Context) (RequestStatus, error) { - sender := NewStreamSender(&runStreamServer{protoSrv: protoSrv}) - err := a.streamHandler.RunStream(ctx, parsedReq, sender) - return RequestStatusFromError(err), err - }) + sender := NewStreamSender(&runStreamServer{protoSrv: protoSrv}) + return a.streamHandler.RunStream(ctx, parsedReq, sender) } diff --git a/backend/stream_adapter_test.go b/backend/stream_adapter_test.go index a1784fe6f..0b1793735 100644 --- a/backend/stream_adapter_test.go +++ b/backend/stream_adapter_test.go @@ -14,18 +14,23 @@ import ( func TestSubscribeStream(t *testing.T) { t.Run("When tenant information is attached to incoming context, it is propagated from adapter to handler", func(t *testing.T) { tid := "123456" - a := newStreamSDKAdapter(&streamAdapter{ - subscribeStreamFunc: func(ctx context.Context, _ *SubscribeStreamRequest) (*SubscribeStreamResponse, error) { - require.Equal(t, tid, tenant.IDFromContext(ctx)) - return &SubscribeStreamResponse{}, nil + handlers := Handlers{ + StreamHandler: &streamAdapter{ + subscribeStreamFunc: func(ctx context.Context, _ *SubscribeStreamRequest) (*SubscribeStreamResponse, error) { + require.Equal(t, tid, tenant.IDFromContext(ctx)) + return &SubscribeStreamResponse{}, nil + }, }, - }) + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware()) + require.NoError(t, err) + a := newStreamSDKAdapter(handlerWithMw) ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ tenant.CtxKey: tid, })) - _, err := a.SubscribeStream(ctx, &pluginv2.SubscribeStreamRequest{ + _, err = a.SubscribeStream(ctx, &pluginv2.SubscribeStreamRequest{ PluginContext: &pluginv2.PluginContext{}, }) require.NoError(t, err) @@ -35,18 +40,23 @@ func TestSubscribeStream(t *testing.T) { func TestPublishStream(t *testing.T) { t.Run("When tenant information is attached to incoming context, it is propagated from adapter to handler", func(t *testing.T) { tid := "123456" - a := newStreamSDKAdapter(&streamAdapter{ - publishStreamFunc: func(ctx context.Context, _ *PublishStreamRequest) (*PublishStreamResponse, error) { - require.Equal(t, tid, tenant.IDFromContext(ctx)) - return &PublishStreamResponse{}, nil + handlers := Handlers{ + StreamHandler: &streamAdapter{ + publishStreamFunc: func(ctx context.Context, _ *PublishStreamRequest) (*PublishStreamResponse, error) { + require.Equal(t, tid, tenant.IDFromContext(ctx)) + return &PublishStreamResponse{}, nil + }, }, - }) + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware()) + require.NoError(t, err) + a := newStreamSDKAdapter(handlerWithMw) ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ tenant.CtxKey: tid, })) - _, err := a.PublishStream(ctx, &pluginv2.PublishStreamRequest{ + _, err = a.PublishStream(ctx, &pluginv2.PublishStreamRequest{ PluginContext: &pluginv2.PluginContext{}, }) require.NoError(t, err) @@ -56,19 +66,24 @@ func TestPublishStream(t *testing.T) { func TestRunStream(t *testing.T) { t.Run("When tenant information is attached to incoming context, it is propagated from adapter to handler", func(t *testing.T) { tid := "123456" - a := newStreamSDKAdapter(&streamAdapter{ - runStreamFunc: func(ctx context.Context, _ *RunStreamRequest, _ *StreamSender) error { - require.Equal(t, tid, tenant.IDFromContext(ctx)) - return nil + handlers := Handlers{ + StreamHandler: &streamAdapter{ + runStreamFunc: func(ctx context.Context, _ *RunStreamRequest, _ *StreamSender) error { + require.Equal(t, tid, tenant.IDFromContext(ctx)) + return nil + }, }, - }) + } + handlerWithMw, err := HandlerFromMiddlewares(handlers, newTenantIDMiddleware()) + require.NoError(t, err) + a := newStreamSDKAdapter(handlerWithMw) testSrv := newTestRunStreamServer() testSrv.WithContext(metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ tenant.CtxKey: tid, }))) - err := a.RunStream(&pluginv2.RunStreamRequest{ + err = a.RunStream(&pluginv2.RunStreamRequest{ PluginContext: &pluginv2.PluginContext{}, }, testSrv) require.NoError(t, err) diff --git a/backend/tenant_middleware.go b/backend/tenant_middleware.go new file mode 100644 index 000000000..c4a8692c6 --- /dev/null +++ b/backend/tenant_middleware.go @@ -0,0 +1,78 @@ +package backend + +import ( + "context" + + "github.com/grafana/grafana-plugin-sdk-go/internal/tenant" +) + +// newTenantIDMiddleware creates a new handler middleware that extract tenant ID from the incoming gRPC context, if available. +func newTenantIDMiddleware() HandlerMiddleware { + return HandlerMiddlewareFunc(func(next Handler) Handler { + return &tenantIDMiddleware{ + BaseHandler: NewBaseHandler(next), + } + }) +} + +// tenantIDMiddleware a handler middleware that extract tenant ID from the incoming gRPC context, if available. +type tenantIDMiddleware struct { + BaseHandler +} + +func (m *tenantIDMiddleware) setup(ctx context.Context) context.Context { + if tid, exists := tenant.IDFromIncomingGRPCContext(ctx); exists { + ctx = tenant.WithTenant(ctx, tid) + } + return ctx +} + +func (m *tenantIDMiddleware) QueryData(ctx context.Context, req *QueryDataRequest) (*QueryDataResponse, error) { + ctx = m.setup(ctx) + return m.BaseHandler.QueryData(ctx, req) +} + +func (m *tenantIDMiddleware) CallResource(ctx context.Context, req *CallResourceRequest, sender CallResourceResponseSender) error { + ctx = m.setup(ctx) + return m.BaseHandler.CallResource(ctx, req, sender) +} + +func (m *tenantIDMiddleware) CheckHealth(ctx context.Context, req *CheckHealthRequest) (*CheckHealthResult, error) { + ctx = m.setup(ctx) + return m.BaseHandler.CheckHealth(ctx, req) +} + +func (m *tenantIDMiddleware) CollectMetrics(ctx context.Context, req *CollectMetricsRequest) (*CollectMetricsResult, error) { + ctx = m.setup(ctx) + return m.BaseHandler.CollectMetrics(ctx, req) +} + +func (m *tenantIDMiddleware) SubscribeStream(ctx context.Context, req *SubscribeStreamRequest) (*SubscribeStreamResponse, error) { + ctx = m.setup(ctx) + return m.BaseHandler.SubscribeStream(ctx, req) +} + +func (m *tenantIDMiddleware) PublishStream(ctx context.Context, req *PublishStreamRequest) (*PublishStreamResponse, error) { + ctx = m.setup(ctx) + return m.BaseHandler.PublishStream(ctx, req) +} + +func (m *tenantIDMiddleware) RunStream(ctx context.Context, req *RunStreamRequest, sender *StreamSender) error { + ctx = m.setup(ctx) + return m.BaseHandler.RunStream(ctx, req, sender) +} + +func (m *tenantIDMiddleware) ValidateAdmission(ctx context.Context, req *AdmissionRequest) (*ValidationResponse, error) { + ctx = m.setup(ctx) + return m.BaseHandler.ValidateAdmission(ctx, req) +} + +func (m *tenantIDMiddleware) MutateAdmission(ctx context.Context, req *AdmissionRequest) (*MutationResponse, error) { + ctx = m.setup(ctx) + return m.BaseHandler.MutateAdmission(ctx, req) +} + +func (m *tenantIDMiddleware) ConvertObjects(ctx context.Context, req *ConversionRequest) (*ConversionResponse, error) { + ctx = m.setup(ctx) + return m.BaseHandler.ConvertObjects(ctx, req) +} diff --git a/backend/tracing_middleware.go b/backend/tracing_middleware.go new file mode 100644 index 000000000..c67e6afb6 --- /dev/null +++ b/backend/tracing_middleware.go @@ -0,0 +1,159 @@ +package backend + +import ( + "context" + "fmt" + + "github.com/grafana/grafana-plugin-sdk-go/backend/tracing" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +// NewTracingMiddleware creates a new HandlerMiddleware that will +// create traces/spans for requests. +func NewTracingMiddleware(tracer trace.Tracer) HandlerMiddleware { + return HandlerMiddlewareFunc(func(next Handler) Handler { + return &tracingMiddleware{ + BaseHandler: NewBaseHandler(next), + tracer: tracer, + } + }) +} + +type tracingMiddleware struct { + BaseHandler + tracer trace.Tracer +} + +func (m *tracingMiddleware) traceRequest(ctx context.Context, pCtx PluginContext, fn func(context.Context) (RequestStatus, error)) error { + endpoint := EndpointFromContext(ctx) + ctx, span := m.tracer.Start(ctx, fmt.Sprintf("sdk.%s", endpoint), trace.WithAttributes( + attribute.String("plugin_id", pCtx.PluginID), + attribute.Int64("org_id", pCtx.OrgID), + )) + defer span.End() + + if pCtx.DataSourceInstanceSettings != nil { + span.SetAttributes( + attribute.String("datasource_name", pCtx.DataSourceInstanceSettings.Name), + attribute.String("datasource_uid", pCtx.DataSourceInstanceSettings.UID), + ) + } + + if u := pCtx.User; u != nil { + span.SetAttributes(attribute.String("user", pCtx.User.Name)) + } + + status, err := fn(ctx) + + span.SetAttributes( + attribute.String("request_status", status.String()), + attribute.String("status_source", string(ErrorSourceFromContext(ctx))), + ) + + if err != nil { + return tracing.Error(span, err) + } + + return nil +} + +func (m *tracingMiddleware) QueryData(ctx context.Context, req *QueryDataRequest) (*QueryDataResponse, error) { + var resp *QueryDataResponse + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + var innerErr error + resp, innerErr = m.BaseHandler.QueryData(ctx, req) + return RequestStatusFromQueryDataResponse(resp, innerErr), innerErr + }) + + return resp, err +} + +func (m *tracingMiddleware) CallResource(ctx context.Context, req *CallResourceRequest, sender CallResourceResponseSender) error { + return m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + innerErr := m.BaseHandler.CallResource(ctx, req, sender) + return RequestStatusFromError(innerErr), innerErr + }) +} + +func (m *tracingMiddleware) CheckHealth(ctx context.Context, req *CheckHealthRequest) (*CheckHealthResult, error) { + var resp *CheckHealthResult + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + var innerErr error + resp, innerErr = m.BaseHandler.CheckHealth(ctx, req) + return RequestStatusFromError(innerErr), innerErr + }) + + return resp, err +} + +func (m *tracingMiddleware) CollectMetrics(ctx context.Context, req *CollectMetricsRequest) (*CollectMetricsResult, error) { + var resp *CollectMetricsResult + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + var innerErr error + resp, innerErr = m.BaseHandler.CollectMetrics(ctx, req) + return RequestStatusFromError(innerErr), innerErr + }) + return resp, err +} + +func (m *tracingMiddleware) SubscribeStream(ctx context.Context, req *SubscribeStreamRequest) (*SubscribeStreamResponse, error) { + var resp *SubscribeStreamResponse + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + var innerErr error + resp, innerErr = m.BaseHandler.SubscribeStream(ctx, req) + return RequestStatusFromError(innerErr), innerErr + }) + return resp, err +} + +func (m *tracingMiddleware) PublishStream(ctx context.Context, req *PublishStreamRequest) (*PublishStreamResponse, error) { + var resp *PublishStreamResponse + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + var innerErr error + resp, innerErr = m.BaseHandler.PublishStream(ctx, req) + return RequestStatusFromError(innerErr), innerErr + }) + return resp, err +} + +func (m *tracingMiddleware) RunStream(ctx context.Context, req *RunStreamRequest, sender *StreamSender) error { + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + innerErr := m.BaseHandler.RunStream(ctx, req, sender) + return RequestStatusFromError(innerErr), innerErr + }) + return err +} + +func (m *tracingMiddleware) ValidateAdmission(ctx context.Context, req *AdmissionRequest) (*ValidationResponse, error) { + var resp *ValidationResponse + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + var innerErr error + resp, innerErr = m.BaseHandler.ValidateAdmission(ctx, req) + return RequestStatusFromError(innerErr), innerErr + }) + + return resp, err +} + +func (m *tracingMiddleware) MutateAdmission(ctx context.Context, req *AdmissionRequest) (*MutationResponse, error) { + var resp *MutationResponse + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + var innerErr error + resp, innerErr = m.BaseHandler.MutateAdmission(ctx, req) + return RequestStatusFromError(innerErr), innerErr + }) + + return resp, err +} + +func (m *tracingMiddleware) ConvertObjects(ctx context.Context, req *ConversionRequest) (*ConversionResponse, error) { + var resp *ConversionResponse + err := m.traceRequest(ctx, req.PluginContext, func(ctx context.Context) (RequestStatus, error) { + var innerErr error + resp, innerErr = m.BaseHandler.ConvertObjects(ctx, req) + return RequestStatusFromError(innerErr), innerErr + }) + + return resp, err +}