Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix verification logic #455

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions auth/authzserver/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ func TestAuthEndpoint(t *testing.T) {
})
}

// #nosec
const sampleIDToken = `eyJraWQiOiJaNmRtWl9UWGhkdXctalVCWjZ1RUV6dm5oLWpoTk8wWWhlbUI3cWFfTE9jIiwiYWxnIjoiUlMyNTYifQ.eyJzdWIiOiIwMHVra2k0OHBzSDhMaWtZVjVkNiIsIm5hbWUiOiJIYXl0aGFtIEFidWVsZnV0dWgiLCJ2ZXIiOjEsImlzcyI6Imh0dHBzOi8vZGV2LTE0MTg2NDIyLm9rdGEuY29tL29hdXRoMi9hdXNrbmdubjd1QlZpUXE2YjVkNiIsImF1ZCI6IjBvYWtraGV0ZU5qQ01FUnN0NWQ2IiwiaWF0IjoxNjE4NDUzNjc5LCJleHAiOjE2MTg0NTcyNzksImp0aSI6IklELmE0YXpLdUphVFM2YzNTeHdpWWdTMHhPbTM2bVFnVlVVN0I4V2dEdk80dFkiLCJhbXIiOlsicHdkIl0sImlkcCI6IjBvYWtrbTFjaTFVZVBwTlUwNWQ2IiwicHJlZmVycmVkX3VzZXJuYW1lIjoiaGF5dGhhbUB1bmlvbi5haSIsImF1dGhfdGltZSI6MTYxODQ0NjI0NywiYXRfaGFzaCI6Ikg5Q0FweWlrQkpGYXJ4d1FUbnB6ZFEifQ.SJ3BTD_MFcrYvTnql181Ddeb_mOm81z_S7ZKQ6P8mMgWqn94LZ2nG8k8-_odaaNAAT-M1nAFKWqZAQGvliwS1_TsD8_j0cen5zYnGcz2Uu5fFlvoHwuPgy5JYYNOXkXYgPnIb3kNkgXKbkdjS9hdbMfvnPd9rr8v0yzqf0AQBnUe-cPrzY-ZJjvh80IWDZgSjoP244tTYppPkx8UtedJLJZ4tzB7aXlEyoRV-DpmOLfJkAmblRm4OsO1qjwmx3HSIy_T-0PANn-g4AS07rpoMYHRcqncdgcAsVfGxjyWiOg3kbymLqpGlkIZgzmev-TmpoDp0QkUVPOntuiB57GZ6g`

//func TestAuthCallbackEndpoint(t *testing.T) {
// originalURL := "http://localhost:8088/oauth2/authorize?client_id=my-client&redirect_uri=http%3A%2F%2Flocalhost%3A3846%2Fcallback&response_type=code&scope=photos+openid+offline&state=some-random-state-foobar&nonce=some-random-nonce&code_challenge=p0v_UR0KrXl4--BpxM2BQa7qIW5k3k4WauBhjmkVQw8&code_challenge_method=S256"
// req := httptest.NewRequest(http.MethodGet, originalURL, nil)
Expand Down
15 changes: 10 additions & 5 deletions auth/authzserver/resource_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"k8s.io/apimachinery/pkg/util/sets"

"github.com/flyteorg/flytestdlib/config"
jwtgo "github.com/golang-jwt/jwt/v4"

"github.com/coreos/go-oidc"
authConfig "github.com/flyteorg/flyteadmin/auth/config"
Expand All @@ -28,17 +29,21 @@ type ResourceServer struct {
}

func (r ResourceServer) ValidateAccessToken(ctx context.Context, expectedAudience, tokenStr string) (interfaces.IdentityContext, error) {
raw, err := r.signatureVerifier.VerifySignature(ctx, tokenStr)
_, err := r.signatureVerifier.VerifySignature(ctx, tokenStr)
if err != nil {
return nil, err
}

claimsRaw := map[string]interface{}{}
if err = json.Unmarshal(raw, &claimsRaw); err != nil {
return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err)
t, _, err := jwtgo.NewParser().ParseUnverified(tokenStr, jwtgo.MapClaims{})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %v", err)
}

if err = t.Claims.Valid(); err != nil {
return nil, fmt.Errorf("failed to validate token: %v", err)
}

return verifyClaims(sets.NewString(append(r.allowedAudience, expectedAudience)...), claimsRaw)
return verifyClaims(sets.NewString(append(r.allowedAudience, expectedAudience)...), t.Claims.(jwtgo.MapClaims))
}

func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
Expand Down
111 changes: 96 additions & 15 deletions auth/authzserver/resource_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package authzserver

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"io"
"net/http"
Expand All @@ -10,6 +12,9 @@ import (
"reflect"
"strings"
"testing"
"time"

"github.com/golang-jwt/jwt/v4"

"github.com/stretchr/testify/assert"

Expand All @@ -21,20 +26,20 @@ import (
stdlibConfig "github.com/flyteorg/flytestdlib/config"
)

func newMockResourceServer(t testing.TB) ResourceServer {
func newMockResourceServer(t testing.TB, publicKey rsa.PublicKey) (resourceServer ResourceServer, closer func()) {
ctx := context.Background()
dummy := ""
serverURL := &dummy
hf := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/oauth-authorization-server" {
w.Header().Set("Content-Type", "application/json")
_, err := io.WriteString(w, strings.ReplaceAll(`{
"issuer": "https://dev-14186422.okta.com",
"issuer": "https://whatever.okta.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"jwks_uri": "URL/keys",
"jwks_uri": "{URL}/keys",
"id_token_signing_alg_values_supported": ["RS256"]
}`, "URL", *serverURL))
}`, "{URL}", *serverURL))

if !assert.NoError(t, err) {
t.FailNow()
Expand All @@ -43,6 +48,14 @@ func newMockResourceServer(t testing.TB) ResourceServer {
return
} else if r.URL.Path == "/keys" {
keys := jwk.NewSet()
key := jwk.NewRSAPublicKey()
err := key.FromRaw(&publicKey)
if err != nil {
http.Error(w, err.Error(), 400)
return
}

keys.Add(key)
raw, err := json.Marshal(keys)
if err != nil {
http.Error(w, err.Error(), 400)
Expand All @@ -55,36 +68,104 @@ func newMockResourceServer(t testing.TB) ResourceServer {
if !assert.NoError(t, err) {
t.FailNow()
}

return
}

http.NotFound(w, r)
}

s := httptest.NewServer(http.HandlerFunc(hf))
defer s.Close()

*serverURL = s.URL

http.DefaultClient = s.Client()

r, err := NewOAuth2ResourceServer(ctx, authConfig.ExternalAuthorizationServer{
BaseURL: stdlibConfig.URL{URL: *config.MustParseURL(s.URL)},
BaseURL: stdlibConfig.URL{URL: *config.MustParseURL(s.URL)},
AllowedAudience: []string{"https://localhost"},
}, stdlibConfig.URL{})
if !assert.NoError(t, err) {
t.FailNow()
}

return r
}

func TestNewOAuth2ResourceServer(t *testing.T) {
newMockResourceServer(t)
return r, func() {
s.Close()
}
}

func TestResourceServer_ValidateAccessToken(t *testing.T) {
r := newMockResourceServer(t)
_, err := r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
assert.Error(t, err)
sampleRSAKey, err := rsa.GenerateKey(rand.Reader, 2048)
if !assert.NoError(t, err) {
t.FailNow()
}

r, closer := newMockResourceServer(t, sampleRSAKey.PublicKey)
defer closer()

t.Run("Invalid signature", func(t *testing.T) {
sampleRSAKey, err := rsa.GenerateKey(rand.Reader, 2048)
if !assert.NoError(t, err) {
t.FailNow()
}

sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.StandardClaims{
Audience: r.allowedAudience[0],
ExpiresAt: time.Now().Add(time.Hour).Unix(),
IssuedAt: time.Now().Unix(),
Issuer: "localhost",
Subject: "someone",
}).SignedString(sampleRSAKey)
if !assert.NoError(t, err) {
t.FailNow()
}

_, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
if !assert.Error(t, err) {
t.FailNow()
}

assert.Contains(t, err.Error(), "failed to verify id token signature")
})

t.Run("Invalid audience", func(t *testing.T) {
sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.StandardClaims{
Audience: "https://hello world",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
IssuedAt: time.Now().Unix(),
Issuer: "localhost",
Subject: "someone",
}).SignedString(sampleRSAKey)
if !assert.NoError(t, err) {
t.FailNow()
}

_, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
if !assert.Error(t, err) {
t.FailNow()
}

assert.Contains(t, err.Error(), "invalid audience")
})

t.Run("Expired token", func(t *testing.T) {
sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.StandardClaims{
Audience: r.allowedAudience[0],
ExpiresAt: time.Now().Add(-time.Hour).Unix(),
IssuedAt: time.Now().Add(-2 * time.Hour).Unix(),
Issuer: "localhost",
Subject: "someone",
}).SignedString(sampleRSAKey)
if !assert.NoError(t, err) {
t.FailNow()
}

_, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
if !assert.Error(t, err) {
t.FailNow()
}

assert.Contains(t, err.Error(), "failed to validate token: Token is expired")
})
}

func Test_doRequest(t *testing.T) {
Expand Down