diff --git a/flyteadmin/auth/authzserver/metadata_provider.go b/flyteadmin/auth/authzserver/metadata_provider.go index 8a0acd906c..23095b0ea6 100644 --- a/flyteadmin/auth/authzserver/metadata_provider.go +++ b/flyteadmin/auth/authzserver/metadata_provider.go @@ -31,14 +31,14 @@ func (s OAuth2MetadataProvider) AuthFuncOverride(ctx context.Context, fullMethod } func (s OAuth2MetadataProvider) GetOAuth2Metadata(ctx context.Context, r *service.OAuth2MetadataRequest) (*service.OAuth2MetadataResponse, error) { + publicURL := auth.GetPublicURL(ctx, nil, s.cfg) switch s.cfg.AppAuth.AuthServerType { case authConfig.AuthorizationServerTypeSelf: - u := auth.GetPublicURL(ctx, nil, s.cfg) doc := &service.OAuth2MetadataResponse{ Issuer: GetIssuer(ctx, nil, s.cfg), - AuthorizationEndpoint: u.ResolveReference(authorizeRelativeURL).String(), - TokenEndpoint: u.ResolveReference(tokenRelativeURL).String(), - JwksUri: u.ResolveReference(jsonWebKeysURL).String(), + AuthorizationEndpoint: publicURL.ResolveReference(authorizeRelativeURL).String(), + TokenEndpoint: publicURL.ResolveReference(tokenRelativeURL).String(), + JwksUri: publicURL.ResolveReference(jsonWebKeysURL).String(), CodeChallengeMethodsSupported: []string{"S256"}, ResponseTypesSupported: []string{ "code", @@ -96,6 +96,18 @@ func (s OAuth2MetadataProvider) GetOAuth2Metadata(ctx context.Context, r *servic return nil, err } + if len(s.cfg.TokenEndpointProxyPath) > 0 { + tokenEndpoint, err := url.Parse(resp.TokenEndpoint) + if err != nil { + return nil, flyteErrors.NewFlyteAdminError(codes.Internal, fmt.Sprintf("Failed to parse token endpoint [%v], err: %v", resp.TokenEndpoint, err)) + } + + tokenEndpoint.Host = publicURL.Host + tokenEndpoint.Path = s.cfg.TokenEndpointProxyPath + tokenEndpoint.Path + tokenEndpoint.RawPath = s.cfg.TokenEndpointProxyPath + tokenEndpoint.RawPath + resp.TokenEndpoint = tokenEndpoint.String() + } + return resp, nil } } diff --git a/flyteadmin/auth/authzserver/metadata_provider_test.go b/flyteadmin/auth/authzserver/metadata_provider_test.go index c8f92fe8cc..e7ba7088a6 100644 --- a/flyteadmin/auth/authzserver/metadata_provider_test.go +++ b/flyteadmin/auth/authzserver/metadata_provider_test.go @@ -91,6 +91,28 @@ func TestOAuth2MetadataProvider_OAuth2Metadata(t *testing.T) { ctx := context.Background() resp, err := provider.GetOAuth2Metadata(ctx, &service.OAuth2MetadataRequest{}) assert.NoError(t, err) + assert.Equal(t, "https://example.com/auth", resp.AuthorizationEndpoint) + assert.Equal(t, "https://example.com/token", resp.TokenEndpoint) + assert.Equal(t, "https://dev-14186422.okta.com", resp.Issuer) + }) + + t.Run("External AuthServer with proxy", func(t *testing.T) { + provider := NewService(&authConfig.Config{ + AuthorizedURIs: []config2.URL{{URL: *config.MustParseURL("https://issuer/")}}, + AppAuth: authConfig.OAuth2Options{ + AuthServerType: authConfig.AuthorizationServerTypeExternal, + ExternalAuthServer: authConfig.ExternalAuthorizationServer{ + BaseURL: config2.URL{URL: *config.MustParseURL(s.URL)}, + }, + }, + TokenEndpointProxyPath: "/my-proxy", + }) + + ctx := context.Background() + resp, err := provider.GetOAuth2Metadata(ctx, &service.OAuth2MetadataRequest{}) + assert.NoError(t, err) + assert.Equal(t, "https://example.com/auth", resp.AuthorizationEndpoint) + assert.Equal(t, "https://issuer/my-proxy/token", resp.TokenEndpoint) assert.Equal(t, "https://dev-14186422.okta.com", resp.Issuer) }) diff --git a/flyteadmin/auth/config/config.go b/flyteadmin/auth/config/config.go index f96c5cf0ae..88ebc8f2c1 100644 --- a/flyteadmin/auth/config/config.go +++ b/flyteadmin/auth/config/config.go @@ -164,6 +164,9 @@ type Config struct { // AppAuth settings used to authenticate and control/limit access scopes for apps. AppAuth OAuth2Options `json:"appAuth" pflag:",Defines Auth options for apps. UserAuth must be enabled for AppAuth to work."` + + // TokenEndpointProxyPath, if set, configures admin to proxy calls to the TokenURL using this path prefix. + TokenEndpointProxyPath string `json:"tokenEndpointProxyPath" pflag:",The path used to proxy calls to the TokenURL"` } type AuthorizationServer struct { diff --git a/flyteadmin/auth/config/config_flags.go b/flyteadmin/auth/config/config_flags.go index b95beb23f3..fd0248d1f8 100755 --- a/flyteadmin/auth/config/config_flags.go +++ b/flyteadmin/auth/config/config_flags.go @@ -84,5 +84,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.thirdPartyConfig.flyteClient.redirectUri"), DefaultConfig.AppAuth.ThirdParty.FlyteClientConfig.RedirectURI, "This is the callback uri registered with the app which handles authorization for a Flyte deployment") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "appAuth.thirdPartyConfig.flyteClient.scopes"), DefaultConfig.AppAuth.ThirdParty.FlyteClientConfig.Scopes, "Recommended scopes for the client to request.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.thirdPartyConfig.flyteClient.audience"), DefaultConfig.AppAuth.ThirdParty.FlyteClientConfig.Audience, "Audience to use when initiating OAuth2 authorization requests.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "tokenEndpointProxyPath"), DefaultConfig.TokenEndpointProxyPath, "The path used to proxy calls to the TokenURL") return cmdFlags } diff --git a/flyteadmin/auth/config/config_flags_test.go b/flyteadmin/auth/config/config_flags_test.go index 25db81d2d3..adba6bd4e0 100755 --- a/flyteadmin/auth/config/config_flags_test.go +++ b/flyteadmin/auth/config/config_flags_test.go @@ -575,4 +575,18 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_tokenEndpointProxyPath", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("tokenEndpointProxyPath", testValue) + if vString, err := cmdFlags.GetString("tokenEndpointProxyPath"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.TokenEndpointProxyPath) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) }