diff --git a/auth/authzserver/provider.go b/auth/authzserver/provider.go index 19b4e9d61..ebd79bddc 100644 --- a/auth/authzserver/provider.go +++ b/auth/authzserver/provider.go @@ -146,15 +146,6 @@ func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{} if foundAudIndex < 0 { return nil, fmt.Errorf("invalid audience [%v]", claims) } - // - //if expiryClaim, found := claimsRaw[ExpiryClaim]; !found { - // return nil, fmt.Errorf("missing expiry claim") - //} else { - // expiry := expiryClaim.(float64) - // if expiry < float64(time.Now().Unix()) { - // return nil, fmt.Errorf("token has expired") - // } - //} userInfo := &service.UserInfoResponse{} if userInfoClaim, found := claimsRaw[UserIDClaim]; found && userInfoClaim != nil { diff --git a/auth/authzserver/resource_server.go b/auth/authzserver/resource_server.go index 78e895297..5c609f46c 100644 --- a/auth/authzserver/resource_server.go +++ b/auth/authzserver/resource_server.go @@ -13,6 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/sets" "github.com/flyteorg/flytestdlib/config" + jwtgo "github.com/golang-jwt/jwt/v4" "github.com/coreos/go-oidc" authConfig "github.com/flyteorg/flyteadmin/auth/config" @@ -28,17 +29,21 @@ type ResourceServer struct { } func (r ResourceServer) ValidateAccessToken(ctx context.Context, expectedAudience, tokenStr string) (interfaces.IdentityContext, error) { - raw, err := r.signatureVerifier.VerifySignature(ctx, tokenStr) + _, err := r.signatureVerifier.VerifySignature(ctx, tokenStr) if err != nil { return nil, err } - claimsRaw := map[string]interface{}{} - if err = json.Unmarshal(raw, &claimsRaw); err != nil { - return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err) + t, _, err := jwtgo.NewParser().ParseUnverified(tokenStr, jwtgo.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %v", err) + } + + if err = t.Claims.Valid(); err != nil { + return nil, fmt.Errorf("failed to validate token: %v", err) } - return verifyClaims(sets.NewString(append(r.allowedAudience, expectedAudience)...), claimsRaw) + return verifyClaims(sets.NewString(append(r.allowedAudience, expectedAudience)...), t.Claims.(jwtgo.MapClaims)) } func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) { diff --git a/auth/authzserver/resource_server_test.go b/auth/authzserver/resource_server_test.go index 38a0d0453..e17027440 100644 --- a/auth/authzserver/resource_server_test.go +++ b/auth/authzserver/resource_server_test.go @@ -169,7 +169,7 @@ func TestResourceServer_ValidateAccessToken(t *testing.T) { t.FailNow() } - assert.Contains(t, err.Error(), "failed to verify id token signature") + assert.Contains(t, err.Error(), "failed to validate token: Token is expired") }) }