From b661697c2e015c89a61845b20f74002dd9d0d744 Mon Sep 17 00:00:00 2001 From: Prafulla Mahindrakar Date: Tue, 6 Feb 2024 14:47:37 -0800 Subject: [PATCH] [FLYTE-486] Support selecting IDP based on the query parameter (#4838) * Added config option for IDPQuery parameter Signed-off-by: pmahindrakar-oss * Using query.values Signed-off-by: pmahindrakar-oss * nit Signed-off-by: pmahindrakar-oss * nit Signed-off-by: pmahindrakar-oss * nit Signed-off-by: pmahindrakar-oss --------- Signed-off-by: pmahindrakar-oss --- flyteadmin/auth/config/config.go | 1 + flyteadmin/auth/config/config_flags.go | 1 + flyteadmin/auth/config/config_flags_test.go | 14 ++++++ flyteadmin/auth/handlers.go | 21 ++++++++- flyteadmin/auth/handlers_test.go | 49 +++++++++++++++++---- 5 files changed, 76 insertions(+), 10 deletions(-) diff --git a/flyteadmin/auth/config/config.go b/flyteadmin/auth/config/config.go index f8c30745bb..f96c5cf0ae 100644 --- a/flyteadmin/auth/config/config.go +++ b/flyteadmin/auth/config/config.go @@ -233,6 +233,7 @@ type UserAuthConfig struct { CookieHashKeySecretName string `json:"cookieHashKeySecretName" pflag:",OPTIONAL: Secret name to use for cookie hash key."` CookieBlockKeySecretName string `json:"cookieBlockKeySecretName" pflag:",OPTIONAL: Secret name to use for cookie block key."` CookieSetting CookieSettings `json:"cookieSetting" pflag:", settings used by cookies created for user auth"` + IDPQueryParameter string `json:"idpQueryParameter" pflag:", idp query parameter used for selecting a particular IDP for doing user authentication. Eg: for Okta passing idp= forces the authentication to happen with IDP-ID"` } //go:generate enumer --type=SameSite --trimprefix=SameSite -json diff --git a/flyteadmin/auth/config/config_flags.go b/flyteadmin/auth/config/config_flags.go index 4012f98f5d..b95beb23f3 100755 --- a/flyteadmin/auth/config/config_flags.go +++ b/flyteadmin/auth/config/config_flags.go @@ -66,6 +66,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "userAuth.cookieBlockKeySecretName"), DefaultConfig.UserAuth.CookieBlockKeySecretName, "OPTIONAL: Secret name to use for cookie block key.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "userAuth.cookieSetting.sameSitePolicy"), DefaultConfig.UserAuth.CookieSetting.SameSitePolicy.String(), "OPTIONAL: Allows you to declare if your cookie should be restricted to a first-party or same-site context.Wrapper around http.SameSite.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "userAuth.cookieSetting.domain"), DefaultConfig.UserAuth.CookieSetting.Domain, "OPTIONAL: Allows you to set the domain attribute on the auth cookies.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "userAuth.idpQueryParameter"), DefaultConfig.UserAuth.IDPQueryParameter, " idp query parameter used for selecting a particular IDP for doing user authentication. Eg: for Okta passing idp= forces the authentication to happen with IDP-ID") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.selfAuthServer.issuer"), DefaultConfig.AppAuth.SelfAuthServer.Issuer, "Defines the issuer to use when issuing and validating tokens. The default value is https:///") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.selfAuthServer.accessTokenLifespan"), DefaultConfig.AppAuth.SelfAuthServer.AccessTokenLifespan.String(), "Defines the lifespan of issued access tokens.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.selfAuthServer.refreshTokenLifespan"), DefaultConfig.AppAuth.SelfAuthServer.RefreshTokenLifespan.String(), "Defines the lifespan of issued access tokens.") diff --git a/flyteadmin/auth/config/config_flags_test.go b/flyteadmin/auth/config/config_flags_test.go index 26fe17dd0e..25db81d2d3 100755 --- a/flyteadmin/auth/config/config_flags_test.go +++ b/flyteadmin/auth/config/config_flags_test.go @@ -323,6 +323,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_userAuth.idpQueryParameter", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("userAuth.idpQueryParameter", testValue) + if vString, err := cmdFlags.GetString("userAuth.idpQueryParameter"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.UserAuth.IDPQueryParameter) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_appAuth.selfAuthServer.issuer", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/flyteadmin/auth/handlers.go b/flyteadmin/auth/handlers.go index 2aa232aadc..bb03e1a654 100644 --- a/flyteadmin/auth/handlers.go +++ b/flyteadmin/auth/handlers.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "strings" "time" @@ -139,7 +140,7 @@ func GetLoginHandler(ctx context.Context, authCtx interfaces.AuthenticationConte state := HashCsrfState(csrfToken) logger.Debugf(ctx, "Setting CSRF state cookie to %s and state to %s\n", csrfToken, state) - url := authCtx.OAuth2ClientConfig(GetPublicURL(ctx, request, authCtx.Options())).AuthCodeURL(state) + urlString := authCtx.OAuth2ClientConfig(GetPublicURL(ctx, request, authCtx.Options())).AuthCodeURL(state) queryParams := request.URL.Query() if !GetRedirectURLAllowed(ctx, queryParams.Get(RedirectURLParameter), authCtx.Options()) { logger.Infof(ctx, "unauthorized redirect URI") @@ -154,7 +155,23 @@ func GetLoginHandler(ctx context.Context, authCtx interfaces.AuthenticationConte logger.Errorf(ctx, "Was not able to create a redirect cookie") } } - http.Redirect(writer, request, url, http.StatusTemporaryRedirect) + + idpURL, err := url.Parse(urlString) + if err != nil { + logger.Errorf(ctx, "failed to parse url %q: %v", urlString, err) + writer.WriteHeader(http.StatusInternalServerError) + } + + // Add the IDPQueryParameter to the URL if it is present in the request + idpQueryParam := authCtx.Options().UserAuth.IDPQueryParameter + if len(idpQueryParam) > 0 && queryParams.Get(idpQueryParam) != "" { + logger.Infof(ctx, "Adding IDP Query Parameter to the URL") + query := idpURL.Query() // Gets a copy of query parameters + query.Add(idpQueryParam, queryParams.Get(idpQueryParam)) + // Updates the rawquery with the new query parameters + idpURL.RawQuery = query.Encode() + } + http.Redirect(writer, request, idpURL.String(), http.StatusTemporaryRedirect) } } diff --git a/flyteadmin/auth/handlers_test.go b/flyteadmin/auth/handlers_test.go index 90822cac0d..452f797d9f 100644 --- a/flyteadmin/auth/handlers_test.go +++ b/flyteadmin/auth/handlers_test.go @@ -245,16 +245,49 @@ func TestGetLoginHandler(t *testing.T) { Scopes: []string{"openid", "other"}, } mockAuthCtx := mocks.AuthenticationContext{} - mockAuthCtx.OnOptions().Return(&config.Config{}) + mockAuthCtx.OnOptions().Return(&config.Config{ + UserAuth: config.UserAuthConfig{ + IDPQueryParameter: "idp", + }, + }) mockAuthCtx.OnOAuth2ClientConfigMatch(mock.Anything).Return(&dummyOAuth2Config) handler := GetLoginHandler(ctx, &mockAuthCtx) - req, err := http.NewRequest("GET", "/login", nil) - assert.NoError(t, err) - w := httptest.NewRecorder() - handler(w, req) - assert.Equal(t, 307, w.Code) - assert.True(t, strings.Contains(w.Header().Get("Location"), "response_type=code&scope=openid+other")) - assert.True(t, strings.Contains(w.Header().Get("Set-Cookie"), "flyte_csrf_state=")) + + type test struct { + name string + url string + expectedStatusCode int + expectedLocation string + expectedSetCookie string + } + tests := []test{ + { + name: "no idp parameter", + url: "/login", + expectedStatusCode: 307, + expectedLocation: "response_type=code&scope=openid+other", + expectedSetCookie: "flyte_csrf_state=", + }, + { + name: "with idp parameter config", + url: "/login?idp=dummyIDP", + expectedStatusCode: 307, + expectedLocation: "dummyIDP", + expectedSetCookie: "flyte_csrf_state=", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest("GET", tt.url, nil) + assert.NoError(t, err) + w := httptest.NewRecorder() + handler(w, req) + assert.Equal(t, tt.expectedStatusCode, w.Code) + assert.True(t, strings.Contains(w.Header().Get("Location"), tt.expectedLocation)) + assert.True(t, strings.Contains(w.Header().Get("Set-Cookie"), tt.expectedSetCookie)) + }) + } } func TestGetLogoutHandler(t *testing.T) {