Skip to content

Commit

Permalink
Initial refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mgyucht committed Nov 4, 2024
1 parent 41df2f4 commit be93274
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 74 deletions.
3 changes: 1 addition & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ func (c *DatabricksClient) GetOAuthToken(ctx context.Context, authDetails string

// Do sends an HTTP request against path.
func (c *DatabricksClient) Do(ctx context.Context, method, path string,
headers map[string]string, request, response any,
visitors ...func(*http.Request) error) error {
headers map[string]string, request, response any, visitors ...func(*http.Request) error) error {
opts := []httpclient.DoOption{}
for _, v := range visitors {
opts = append(opts, httpclient.WithRequestVisitor(v))
Expand Down
22 changes: 12 additions & 10 deletions config/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/common"
"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/useragent"
Expand Down Expand Up @@ -73,17 +74,18 @@ func (c *Config) NewApiClient() (*httpclient.ApiClient, error) {
return nil
},
},
TransientErrors: []string{
"REQUEST_LIMIT_EXCEEDED", // This is temporary workaround for SCIM API returning 500. Remove when it's fixed
},
ErrorMapper: apierr.GetAPIError,
ErrorRetriable: func(ctx context.Context, err error) bool {
var apiErr *apierr.APIError
if errors.As(err, &apiErr) {
return apiErr.IsRetriable(ctx)
}
return false
},
ErrorRetriable: httpclient.CombineRetriers(
func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool {
var apiErr *apierr.APIError
if errors.As(err, &apiErr) {
return apiErr.IsRetriable(ctx)
}
return false
},
httpclient.RetryUrlErrors,
httpclient.RetryTransientErrors([]string{"REQUEST_LIMIT_EXCEEDED"}),
),
}), nil
}

Expand Down
17 changes: 10 additions & 7 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,16 @@ func (c *Config) EnsureResolved() error {
HTTPTimeout: time.Duration(c.HTTPTimeoutSeconds) * time.Second,
Transport: c.HTTPTransport,
ErrorMapper: c.refreshTokenErrorMapper,
TransientErrors: []string{
"throttled",
"too many requests",
"429",
"request limit exceeded",
"rate limit",
},
ErrorRetriable: httpclient.CombineRetriers(
httpclient.DefaultErrorRetriable,
httpclient.RetryTransientErrors([]string{
"throttled",
"too many requests",
"429",
"request limit exceeded",
"rate limit",
}),
),
})
if c.azureTenantIdFetchClient == nil {
c.azureTenantIdFetchClient = &http.Client{
Expand Down
57 changes: 12 additions & 45 deletions httpclient/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"net/url"
"runtime"
"strings"
"time"

"github.com/databricks/databricks-sdk-go/common"
Expand All @@ -35,9 +34,8 @@ type ClientConfig struct {
DebugTruncateBytes int
RateLimitPerSecond int

ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error
ErrorRetriable func(ctx context.Context, err error) bool
TransientErrors []string
ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error
ErrorRetriable ErrorRetryer

Transport http.RoundTripper
}
Expand Down Expand Up @@ -130,7 +128,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio
// merge client-wide and request-specific visitors
visitors = append(visitors, o.in)
}

}
// Use default AuthVisitor if none is provided
if authVisitor == nil {
Expand Down Expand Up @@ -170,45 +167,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio
return nil
}

func (c *ApiClient) isRetriable(ctx context.Context, err error) bool {
if c.config.ErrorRetriable(ctx, err) {
return true
}
if isRetriableUrlError(err) {
// all IO errors are retriable
logger.Debugf(ctx, "Attempting retry because of IO error: %s", err)
return true
}
message := err.Error()
// Handle transient errors for retries
for _, substring := range c.config.TransientErrors {
if strings.Contains(message, substring) {
logger.Debugf(ctx, "Attempting retry because of %#v", substring)
return true
}
}
// some API's recommend retries on HTTP 500, but we'll add that later
return false
}

// Common error-handling logic for all responses that may need to be retried.
//
// If the error is retriable, return a retries.Err to retry the request. However, as the request body will have been consumed
// by the first attempt, the body must be reset before retrying. If the body cannot be reset, return a retries.Err to halt.
//
// Always returns nil for the first parameter as there is no meaningful response body to return in the error case.
//
// If it is certain that an error should not be retried, use failRequest() instead.
func (c *ApiClient) handleError(ctx context.Context, err error, body common.RequestBody) (*common.ResponseWrapper, *retries.Err) {
if !c.isRetriable(ctx, err) {
return nil, retries.Halt(err)
}
if resetErr := body.Reset(); resetErr != nil {
return nil, retries.Halt(resetErr)
}
return nil, retries.Continue(err)
}

// Fails the request with a retries.Err to halt future retries.
func (c *ApiClient) failRequest(msg string, err error) (*common.ResponseWrapper, *retries.Err) {
err = fmt.Errorf("%s: %w", msg, err)
Expand Down Expand Up @@ -299,7 +257,16 @@ func (c *ApiClient) attempt(

// proactively release the connections in HTTP connection pool
c.httpClient.CloseIdleConnections()
return c.handleError(ctx, err, requestBody)

// Non-retriable errors can be returned immediately.
if !c.config.ErrorRetriable(ctx, request, &responseWrapper, err) {
return nil, retries.Halt(err)
}
// Retriable errors may require the request body to be reset.
if resetErr := requestBody.Reset(); resetErr != nil {
return nil, retries.Halt(resetErr)
}
return nil, retries.Continue(err)
}
}

Expand Down
58 changes: 48 additions & 10 deletions httpclient/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"

"github.com/databricks/databricks-sdk-go/common"
"github.com/databricks/databricks-sdk-go/logger"
)

type HttpError struct {
Expand Down Expand Up @@ -45,17 +46,39 @@ func DefaultErrorMapper(ctx context.Context, resp common.ResponseWrapper) error
}
}

func DefaultErrorRetriable(ctx context.Context, err error) bool {
var httpError *HttpError
if errors.As(err, &httpError) {
if httpError.StatusCode == http.StatusTooManyRequests {
return true
}
if httpError.StatusCode == http.StatusGatewayTimeout {
return true
type ErrorRetryer func(context.Context, *http.Request, *common.ResponseWrapper, error) bool

func DefaultErrorRetriable(ctx context.Context, req *http.Request, resp *common.ResponseWrapper, err error) bool {
return CombineRetriers(
RetryOnTooManyRequests,
RetryOnGatewayTimeout,
RetryUrlErrors,
)(ctx, req, resp, err)
}

func RetryOnTooManyRequests(ctx context.Context, _ *http.Request, resp *common.ResponseWrapper, err error) bool {
if resp.Response == nil {
return false
}
return resp.Response.StatusCode == http.StatusTooManyRequests
}

func RetryOnGatewayTimeout(ctx context.Context, _ *http.Request, resp *common.ResponseWrapper, err error) bool {
if resp.Response == nil {
return false
}
return resp.Response.StatusCode == http.StatusGatewayTimeout
}

func CombineRetriers(retriers ...ErrorRetryer) ErrorRetryer {
return func(ctx context.Context, req *http.Request, resp *common.ResponseWrapper, err error) bool {
for _, retrier := range retriers {
if retrier(ctx, req, resp, err) {
return true
}
}
return false
}
return false
}

var urlErrorTransientErrorMessages = []string{
Expand All @@ -66,15 +89,30 @@ var urlErrorTransientErrorMessages = []string{
"i/o timeout",
}

func isRetriableUrlError(err error) bool {
func RetryUrlErrors(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool {
var urlError *url.Error
if !errors.As(err, &urlError) {
return false
}
for _, msg := range urlErrorTransientErrorMessages {
if strings.Contains(err.Error(), msg) {
logger.Debugf(ctx, "Attempting retry because of IO error: %s", err)
return true
}
}
return false
}

func RetryTransientErrors(errors []string) ErrorRetryer {
return func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool {
message := err.Error()
// Handle transient errors for retries
for _, substring := range errors {
if strings.Contains(message, substring) {
logger.Debugf(ctx, "Attempting retry because of %#v", substring)
return true
}
}
return false
}
}

0 comments on commit be93274

Please sign in to comment.