diff --git a/flyteadmin/auth/authzserver/metadata_provider.go b/flyteadmin/auth/authzserver/metadata_provider.go
index b1887ef4c2..8a0acd906c 100644
--- a/flyteadmin/auth/authzserver/metadata_provider.go
+++ b/flyteadmin/auth/authzserver/metadata_provider.go
@@ -2,14 +2,22 @@ package authzserver
 
 import (
 	"context"
+	"fmt"
 	"io/ioutil"
 	"net/http"
 	"net/url"
 	"strings"
+	"time"
+
+	"google.golang.org/grpc/codes"
+	"k8s.io/apimachinery/pkg/util/wait"
+	"k8s.io/client-go/util/retry"
 
 	"github.com/flyteorg/flyte/flyteadmin/auth"
 	authConfig "github.com/flyteorg/flyte/flyteadmin/auth/config"
+	flyteErrors "github.com/flyteorg/flyte/flyteadmin/pkg/errors"
 	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
+	"github.com/flyteorg/flyte/flytestdlib/logger"
 )
 
 type OAuth2MetadataProvider struct {
@@ -72,7 +80,7 @@ func (s OAuth2MetadataProvider) GetOAuth2Metadata(ctx context.Context, r *servic
 			httpClient.Transport = transport
 		}
 
-		response, err := httpClient.Get(externalMetadataURL.String())
+		response, err := sendAndRetryHTTPRequest(ctx, httpClient, externalMetadataURL.String(), s.cfg.AppAuth.ExternalAuthServer.RetryAttempts, s.cfg.AppAuth.ExternalAuthServer.RetryDelay.Duration)
 		if err != nil {
 			return nil, err
 		}
@@ -107,3 +115,59 @@ func NewService(config *authConfig.Config) OAuth2MetadataProvider {
 		cfg: config,
 	}
 }
+
+// sendAndRetryHTTPRequest sends a GET request for url through client, bound to ctx, and retries it while the server
+// answers with a status code between 401 and 511 (inclusive). Up to retryAttempts retries follow the initial attempt
+// (negative values are treated as zero), with retryDelay between consecutive attempts. Errors from the HTTP client
+// itself are returned without retrying. On success the 200 response is returned and the caller must close its body;
+// on any failure the response is nil and a body that was received has already been closed.
+func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, url string, retryAttempts int, retryDelay time.Duration) (*http.Response, error) {
+	if retryAttempts < 0 {
+		// A backoff with no steps would never send the request and would yield a nil response with a nil error.
+		retryAttempts = 0
+	}
+
+	backoff := wait.Backoff{
+		Duration: retryDelay,
+		Steps:    retryAttempts + 1, // Add one for initial http request attempt
+	}
+
+	retryableOauthMetadataError := flyteErrors.NewFlyteAdminError(codes.Internal, "Failed to get oauth metadata.")
+
+	var response *http.Response
+	err := retry.OnError(backoff,
+		func(err error) bool { // Determine if error is retryable
+			return err == retryableOauthMetadataError
+		}, func() error { // Send HTTP request
+			request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+			if err != nil {
+				return err
+			}
+
+			response, err = client.Do(request)
+			if err != nil {
+				logger.Errorf(ctx, "Failed to send oauth metadata HTTP request. Err: %v", err)
+				return err
+			}
+
+			if response.StatusCode >= http.StatusUnauthorized && response.StatusCode <= http.StatusNetworkAuthenticationRequired {
+				logger.Errorf(ctx, "Failed to get oauth metadata, going to retry. StatusCode: %v", response.StatusCode)
+				// This response is discarded, so release its connection before the next attempt.
+				_ = response.Body.Close()
+				response = nil
+				return retryableOauthMetadataError
+			}
+			return nil
+		})
+	if err != nil {
+		return nil, err
+	}
+
+	if response.StatusCode != http.StatusOK {
+		// The caller never sees this response, so close it here instead of leaking the body.
+		_ = response.Body.Close()
+		return nil, fmt.Errorf("failed to get oauth metadata with status code %v", response.StatusCode)
+	}
+
+	return response, nil
+}
diff --git a/flyteadmin/auth/authzserver/metadata_provider_test.go b/flyteadmin/auth/authzserver/metadata_provider_test.go
index f1b244012e..cb8a2b5f97 100644
--- a/flyteadmin/auth/authzserver/metadata_provider_test.go
+++ b/flyteadmin/auth/authzserver/metadata_provider_test.go
@@ -16,6 +16,8 @@ import (
 	config2 "github.com/flyteorg/flyte/flytestdlib/config"
 )
 
+var oauthMetadataFailureErrorMessage = "Failed to get oauth metadata."
+
 func TestOAuth2MetadataProvider_FlyteClient(t *testing.T) {
 	provider := NewService(&authConfig.Config{
 		AppAuth: authConfig.OAuth2Options{
@@ -111,3 +113,87 @@ func TestOAuth2MetadataProvider_OAuth2Metadata(t *testing.T) {
 		assert.Equal(t, "https://dev-14186422.okta.com", resp.Issuer)
 	})
 }
+
+// TestSendAndRetryHttpRequest runs sendAndRetryHTTPRequest against a local test server and covers three cases:
+// retries that are exhausted, retries that end in a 200 response, and a 200 response on the first attempt.
+func TestSendAndRetryHttpRequest(t *testing.T) {
+	t.Run("Retry into failure", func(t *testing.T) {
+		requestAttempts := 0
+		hf := func(w http.ResponseWriter, r *http.Request) {
+			switch strings.TrimSpace(r.URL.Path) {
+			case "/":
+				w.WriteHeader(500)
+				requestAttempts++
+			default:
+				http.NotFoundHandler().ServeHTTP(w, r)
+			}
+		}
+
+		server := httptest.NewServer(http.HandlerFunc(hf))
+		defer server.Close()
+		retryAttempts := 5
+		totalAttempts := retryAttempts + 1 // 1 for the initial try
+
+		resp, err := sendAndRetryHTTPRequest(context.Background(), server.Client(), server.URL, retryAttempts, 0 /* for testing */)
+		assert.Error(t, err)
+		assert.Equal(t, oauthMetadataFailureErrorMessage, err.Error())
+		assert.Nil(t, resp)
+		assert.Equal(t, totalAttempts, requestAttempts)
+	})
+
+	t.Run("Retry into success", func(t *testing.T) {
+		requestAttempts := 0 // counts only the attempts answered with a 500
+		hf := func(w http.ResponseWriter, r *http.Request) {
+			switch strings.TrimSpace(r.URL.Path) {
+			case "/":
+				if requestAttempts > 2 {
+					w.WriteHeader(200)
+				} else {
+					requestAttempts++
+					w.WriteHeader(500)
+				}
+			default:
+				http.NotFoundHandler().ServeHTTP(w, r)
+			}
+		}
+
+		server := httptest.NewServer(http.HandlerFunc(hf))
+		defer server.Close()
+		retryAttempts := 5
+		expectedRequestAttempts := 3
+
+		resp, err := sendAndRetryHTTPRequest(context.Background(), server.Client(), server.URL, retryAttempts, 0 /* for testing */)
+		assert.NoError(t, err)
+		if assert.NotNil(t, resp) {
+			defer resp.Body.Close()
+			assert.Equal(t, 200, resp.StatusCode)
+		}
+		assert.Equal(t, expectedRequestAttempts, requestAttempts)
+	})
+
+	t.Run("Success", func(t *testing.T) {
+		requestAttempts := 0
+		hf := func(w http.ResponseWriter, r *http.Request) {
+			switch strings.TrimSpace(r.URL.Path) {
+			case "/":
+				requestAttempts++
+				w.WriteHeader(200)
+			default:
+				http.NotFoundHandler().ServeHTTP(w, r)
+			}
+		}
+
+		server := httptest.NewServer(http.HandlerFunc(hf))
+		defer server.Close()
+		retryAttempts := 5
+		expectedRequestAttempts := 1 // a 200 on the first attempt must not trigger any retry
+
+		resp, err := sendAndRetryHTTPRequest(context.Background(), server.Client(), server.URL, retryAttempts, 0 /* for testing */)
+		assert.NoError(t, err)
+		if assert.NotNil(t, resp) {
+			defer resp.Body.Close()
+			assert.Equal(t, 200, resp.StatusCode)
+		}
+		assert.Equal(t, expectedRequestAttempts, requestAttempts)
+	})
+}
diff --git a/flyteadmin/auth/config/config.go b/flyteadmin/auth/config/config.go
index 217983683e..f8c30745bb 100644
--- a/flyteadmin/auth/config/config.go
+++ b/flyteadmin/auth/config/config.go
@@ -79,6 +79,10 @@ var (
 			},
 		},
 		AppAuth: OAuth2Options{
+			ExternalAuthServer: ExternalAuthorizationServer{
+				RetryAttempts: 5,
+				RetryDelay:    config.Duration{Duration: 1 * time.Second},
+			},
 			AuthServerType: AuthorizationServerTypeSelf,
 			ThirdParty: ThirdPartyConfigOptions{
 				FlyteClientConfig: FlyteClientConfig{
@@ -191,7 +195,9 @@ type ExternalAuthorizationServer struct {
 	AllowedAudience []string `json:"allowedAudience" pflag:",Optional: A list of allowed audiences. 
If not provided, the audience is expected to be the public Uri of the service."` MetadataEndpointURL config.URL `json:"metadataUrl" pflag:",Optional: If the server doesn't support /.well-known/oauth-authorization-server, you can set a custom metadata url here.'"` // HTTPProxyURL allows operators to access external OAuth2 servers using an external HTTP Proxy - HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",OPTIONAL: HTTP Proxy to be used for OAuth requests."` + HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",OPTIONAL: HTTP Proxy to be used for OAuth requests."` + RetryAttempts int `json:"retryAttempts" pflag:", Optional: The number of attempted retries on a transient failure to get the OAuth metadata"` + RetryDelay config.Duration `json:"retryDelay" pflag:", Optional, Duration to wait between retries"` } // OAuth2Options defines settings for app auth. diff --git a/flyteadmin/auth/config/config_flags.go b/flyteadmin/auth/config/config_flags.go index 225e8a5c9d..4012f98f5d 100755 --- a/flyteadmin/auth/config/config_flags.go +++ b/flyteadmin/auth/config/config_flags.go @@ -77,6 +77,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "appAuth.externalAuthServer.allowedAudience"), DefaultConfig.AppAuth.ExternalAuthServer.AllowedAudience, "Optional: A list of allowed audiences. 
If not provided, the audience is expected to be the public Uri of the service.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.externalAuthServer.metadataUrl"), DefaultConfig.AppAuth.ExternalAuthServer.MetadataEndpointURL.String(), "Optional: If the server doesn't support /.well-known/oauth-authorization-server, you can set a custom metadata url here.'") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.externalAuthServer.httpProxyURL"), DefaultConfig.AppAuth.ExternalAuthServer.HTTPProxyURL.String(), "OPTIONAL: HTTP Proxy to be used for OAuth requests.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "appAuth.externalAuthServer.retryAttempts"), DefaultConfig.AppAuth.ExternalAuthServer.RetryAttempts, " Optional: The number of attempted retries on a transient failure to get the OAuth metadata") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.externalAuthServer.retryDelay"), DefaultConfig.AppAuth.ExternalAuthServer.RetryDelay.String(), " Optional, Duration to wait between retries") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.thirdPartyConfig.flyteClient.clientId"), DefaultConfig.AppAuth.ThirdParty.FlyteClientConfig.ClientID, "public identifier for the app which handles authorization for a Flyte deployment") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.thirdPartyConfig.flyteClient.redirectUri"), DefaultConfig.AppAuth.ThirdParty.FlyteClientConfig.RedirectURI, "This is the callback uri registered with the app which handles authorization for a Flyte deployment") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "appAuth.thirdPartyConfig.flyteClient.scopes"), DefaultConfig.AppAuth.ThirdParty.FlyteClientConfig.Scopes, "Recommended scopes for the client to request.") diff --git a/flyteadmin/auth/config/config_flags_test.go b/flyteadmin/auth/config/config_flags_test.go index 28efafc380..26fe17dd0e 100755 --- a/flyteadmin/auth/config/config_flags_test.go +++ b/flyteadmin/auth/config/config_flags_test.go @@ -477,6 +477,34 @@ func 
TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_appAuth.externalAuthServer.retryAttempts", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("appAuth.externalAuthServer.retryAttempts", testValue) + if vInt, err := cmdFlags.GetInt("appAuth.externalAuthServer.retryAttempts"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.AppAuth.ExternalAuthServer.RetryAttempts) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_appAuth.externalAuthServer.retryDelay", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := DefaultConfig.AppAuth.ExternalAuthServer.RetryDelay.String() + + cmdFlags.Set("appAuth.externalAuthServer.retryDelay", testValue) + if vString, err := cmdFlags.GetString("appAuth.externalAuthServer.retryDelay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AppAuth.ExternalAuthServer.RetryDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_appAuth.thirdPartyConfig.flyteClient.clientId", func(t *testing.T) { t.Run("Override", func(t *testing.T) {