diff --git a/backend/error_source.go b/backend/error_source.go index 8c157cf30..526334f40 100644 --- a/backend/error_source.go +++ b/backend/error_source.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "net/http" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" ) // ErrorSource type defines the source of the error @@ -25,6 +27,13 @@ func (es ErrorSource) IsValid() bool { return es == ErrorSourceDownstream || es == ErrorSourcePlugin } +func ErrorSourceFromHTTPError(err error) ErrorSource { + if httpclient.IsDownstreamHTTPError(err) { + return ErrorSourceDownstream + } + return ErrorSourcePlugin +} + // ErrorSourceFromStatus returns an [ErrorSource] based on provided HTTP status code. func ErrorSourceFromHTTPStatus(statusCode int) ErrorSource { switch statusCode { diff --git a/backend/httpclient/error_source_middleware.go b/backend/httpclient/error_source_middleware.go new file mode 100644 index 000000000..0d7a383ad --- /dev/null +++ b/backend/httpclient/error_source_middleware.go @@ -0,0 +1,136 @@ +package httpclient + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + "syscall" + + grpccodes "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" +) + +const ErrorSourceMiddlewareName = "ErrorSource" + +func ErrorSourceMiddleware() Middleware { + return NamedMiddlewareFunc(ErrorSourceMiddlewareName, func(_ Options, next http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + res, err := next.RoundTrip(req) + if err != nil && IsDownstreamHTTPError(err) { + return res, DownstreamError(err) + } + + return res, err + }) + }) +} + +type ErrorSource string + +const ( + ErrorSourcePlugin ErrorSource = "plugin" + ErrorSourceDownstream ErrorSource = "downstream" +) + +type errorWithSourceImpl struct { + source ErrorSource + err error +} + +func IsDownstreamHTTPError(err error) bool { + e := errorWithSourceImpl{ + source: ErrorSourceDownstream, + } + if errors.Is(err, e) { + return true + } + + // nolint:errorlint + if errWithSource, ok := err.(errorWithSourceImpl); ok && errWithSource.ErrorSource() == ErrorSourceDownstream { + return true + } + + // Check if the error is a HTTP timeout error or a context cancelled error + if isHTTPTimeoutError(err) { + return true + } + + if isCancelledError(err) { + return true + } + + if isConnectionResetOrRefusedError(err) { + return true + } + + if isDNSNotFoundError(err) { + return true + } + + return false +} + +func isCancelledError(err error) bool { + return errors.Is(err, context.Canceled) || grpcstatus.Code(err) == grpccodes.Canceled +} + +func isHTTPTimeoutError(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + return errors.Is(err, os.ErrDeadlineExceeded) // replacement for os.IsTimeout(err) +} + +func isConnectionResetOrRefusedError(err error) bool { + var netErr *net.OpError + if errors.As(err, &netErr) { + var sysErr *os.SyscallError + if errors.As(netErr.Err, &sysErr) { + return errors.Is(sysErr.Err, syscall.ECONNRESET) || errors.Is(sysErr.Err, syscall.ECONNREFUSED) + } + } + + return false +} + +func isDNSNotFoundError(err error) bool { + var dnsError *net.DNSError + if errors.As(err, &dnsError) && dnsError.IsNotFound { + return true + } + + return false +} + +func (e errorWithSourceImpl) ErrorSource() ErrorSource { + return e.source +} + +func (e errorWithSourceImpl) Error() string { + return fmt.Errorf("%s error: %w", e.source, e.err).Error() +} + +// Implements the interface used by [errors.Is]. +func (e errorWithSourceImpl) Is(err error) bool { + if errWithSource, ok := err.(errorWithSourceImpl); ok { + return errWithSource.ErrorSource() == e.source + } + + return false +} + +func (e errorWithSourceImpl) Unwrap() error { + return e.err +} + +func DownstreamError(err error) error { + return errorWithSourceImpl{ + source: ErrorSourceDownstream, + err: err, + } +} diff --git a/backend/httpclient/http_client.go b/backend/httpclient/http_client.go index 40a4d39d4..0d88f72e8 100644 --- a/backend/httpclient/http_client.go +++ b/backend/httpclient/http_client.go @@ -210,6 +210,7 @@ func DefaultMiddlewares() []Middleware { BasicAuthenticationMiddleware(), CustomHeadersMiddleware(), ContextualMiddleware(), + ErrorSourceMiddleware(), } } diff --git a/backend/httpclient/http_client_test.go b/backend/httpclient/http_client_test.go index 570fe0833..e8875d180 100644 --- a/backend/httpclient/http_client_test.go +++ b/backend/httpclient/http_client_test.go @@ -55,11 +55,12 @@ func TestNewClient(t *testing.T) { require.NoError(t, err) require.NotNil(t, client) - require.Len(t, usedMiddlewares, 4) + require.Len(t, usedMiddlewares, 5) require.Equal(t, TracingMiddlewareName, usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, usedMiddlewares[2].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, usedMiddlewares[3].(MiddlewareName).MiddlewareName()) + require.Equal(t, ErrorSourceMiddlewareName, usedMiddlewares[4].(MiddlewareName).MiddlewareName()) }) t.Run("New() with opts middleware should return expected http.Client", func(t *testing.T) { diff --git a/backend/httpclient/provider_test.go b/backend/httpclient/provider_test.go index b5331321a..8d5984fb1 100644 --- a/backend/httpclient/provider_test.go +++ b/backend/httpclient/provider_test.go @@ -24,7 +24,7 @@ func TestProvider(t *testing.T) { client, err := ctx.provider.New() require.NoError(t, err) require.NotNil(t, client) - require.Len(t, ctx.usedMiddlewares, 4) + require.Len(t, ctx.usedMiddlewares, 5) require.Equal(t, TracingMiddlewareName, ctx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, ctx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, ctx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) @@ -36,7 +36,7 @@ func TestProvider(t *testing.T) { transport, err := ctx.provider.GetTransport() require.NoError(t, err) require.NotNil(t, transport) - require.Len(t, ctx.usedMiddlewares, 4) + require.Len(t, ctx.usedMiddlewares, 5) require.Equal(t, TracingMiddlewareName, ctx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, ctx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, ctx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) @@ -81,7 +81,7 @@ func TestProvider(t *testing.T) { require.Equal(t, DefaultTimeoutOptions.Timeout, client.Timeout) t.Run("Should use configured middlewares and implement MiddlewareName", func(t *testing.T) { - require.Len(t, pCtx.usedMiddlewares, 7) + require.Len(t, pCtx.usedMiddlewares, 8) require.Equal(t, "mw1", pCtx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, "mw2", pCtx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, "mw3", pCtx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) diff --git a/experimental/errorsource/error_source_middleware.go b/experimental/errorsource/error_source_middleware.go index 95ff79b10..db9ef5dfa 100644 --- a/experimental/errorsource/error_source_middleware.go +++ b/experimental/errorsource/error_source_middleware.go @@ -2,9 +2,7 @@ package errorsource import ( "errors" - "net" "net/http" - "syscall" "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" @@ -26,11 +24,7 @@ func RoundTripper(_ httpclient.Options, next http.RoundTripper) http.RoundTrippe } return res, Error{source: errorSource, err: err} } - if errors.Is(err, syscall.ECONNREFUSED) { - return res, Error{source: backend.ErrorSourceDownstream, err: err} - } - var dnsError *net.DNSError - if errors.As(err, &dnsError) && dnsError.IsNotFound { + if httpclient.IsDownstreamHTTPError(err) { return res, Error{source: backend.ErrorSourceDownstream, err: err} } return res, err