Skip to content
This repository has been archived by the owner on Nov 18, 2024. It is now read-only.

Commit

Permalink
Merge pull request #32 from puppetlabs/bugs/id-token-refresh-revalida…
Browse files Browse the repository at this point in the history
…tion

Only revalidate ID tokens if provided on refresh
  • Loading branch information
impl authored Jan 19, 2021
2 parents 7ec3146 + 0c6c739 commit a98a1e2
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 32 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Versioning](https://semver.org/spec/v2.0.0.html).
refresh token flow if the plugin user specifies a nonce to validate against;
otherwise, it is assumed that the nonce data is invalid or non-conforming to
the OpenID Connect Core specification.
* Per the OpenID Connect Core specification, ID tokens will only be revalidated
during refresh if the server sends a new ID token. Otherwise, they are passed
through unmodified from the original exchange.

### Changed

Expand Down
104 changes: 72 additions & 32 deletions pkg/provider/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ type oidcOperations struct {
extraDataFields []string
}

func (oo *oidcOperations) verifyUpdateToken(ctx context.Context, t *Token, nonce string) error {
func (oo *oidcOperations) verifyUpdateIDToken(ctx context.Context, t *Token, nonce string) error {
rawIDToken, ok := t.Extra("id_token").(string)
if !ok {
if !ok || rawIDToken == "" {
return ErrOIDCMissingIDToken
}

Expand All @@ -54,34 +54,50 @@ func (oo *oidcOperations) verifyUpdateToken(ctx context.Context, t *Token, nonce
return ErrOIDCNonceMismatch
}

if len(oo.extraDataFields) > 0 {
t.ExtraData = make(map[string]interface{})

for _, field := range oo.extraDataFields {
switch field {
case oidcExtraDataFieldIDToken:
t.ExtraData[field] = rawIDToken
case oidcExtraDataFieldIDTokenClaims:
claims := make(map[string]interface{})
if err := idToken.Claims(&claims); err != nil {
return fmt.Errorf("oidc: error parsing token claims: %w", err)
}

t.ExtraData[field] = claims
case oidcExtraDataFieldUserInfo:
userInfo, err := oo.p.UserInfo(ctx, oo.basicOperations.base.TokenSource(ctx, t.Token))
if err != nil {
return fmt.Errorf("oidc: error fetching user info: %w", err)
}

claims := make(map[string]interface{})
if err := userInfo.Claims(&claims); err != nil {
return fmt.Errorf("oidc: error parsing user info: %w", err)
}

t.ExtraData[field] = claims
for _, field := range oo.extraDataFields {
switch field {
case oidcExtraDataFieldIDToken:
t.ExtraData[field] = rawIDToken
case oidcExtraDataFieldIDTokenClaims:
claims := make(map[string]interface{})
if err := idToken.Claims(&claims); err != nil {
return fmt.Errorf("oidc: error parsing token claims: %w", err)
}

t.ExtraData[field] = claims
}
}

return nil
}

func (oo *oidcOperations) copyIDToken(ctx context.Context, p, n *Token) {
for _, field := range oo.extraDataFields {
switch field {
case oidcExtraDataFieldIDToken, oidcExtraDataFieldIDTokenClaims:
n.ExtraData[field] = p.ExtraData[field]
}
}
}

func (oo *oidcOperations) updateUserInfo(ctx context.Context, t *Token) error {
for _, field := range oo.extraDataFields {
if field != oidcExtraDataFieldUserInfo {
continue
}

userInfo, err := oo.p.UserInfo(ctx, oauth2.StaticTokenSource(t.Token))
if err != nil {
return fmt.Errorf("oidc: error fetching user info: %w", err)
}

claims := make(map[string]interface{})
if err := userInfo.Claims(&claims); err != nil {
return fmt.Errorf("oidc: error parsing user info: %w", err)
}

t.ExtraData[field] = claims
break
}

return nil
Expand All @@ -96,7 +112,15 @@ func (oo *oidcOperations) AuthCodeExchange(ctx context.Context, code string, opt
return nil, err
}

if err := oo.verifyUpdateToken(ctx, t, o.ProviderOptions["nonce"]); err != nil {
if t.ExtraData == nil {
t.ExtraData = make(map[string]interface{})
}

if err := oo.verifyUpdateIDToken(ctx, t, o.ProviderOptions["nonce"]); err != nil {
return nil, errmark.MarkUser(err)
}

if err := oo.updateUserInfo(ctx, t); err != nil {
return nil, errmark.MarkUser(err)
}

Expand All @@ -107,16 +131,32 @@ func (oo *oidcOperations) RefreshToken(ctx context.Context, t *Token, opts ...Re
o := &RefreshTokenOptions{}
o.ApplyOptions(opts)

t, err := oo.basicOperations.RefreshToken(ctx, t, opts...)
nt, err := oo.basicOperations.RefreshToken(ctx, t, opts...)
if err != nil {
return nil, err
}

if err := oo.verifyUpdateToken(ctx, t, o.ProviderOptions["nonce"]); err != nil {
if nt.ExtraData == nil {
nt.ExtraData = make(map[string]interface{})
}

// Per OpenID Connect Core 1.0 § 12.2
// (https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse),
// providing an ID token as part of a refresh is optional. We will only
// revalidate the token if a new one is provided.
if rawIDToken, ok := nt.Extra("id_token").(string); ok && rawIDToken != "" {
if err := oo.verifyUpdateIDToken(ctx, nt, o.ProviderOptions["nonce"]); err != nil {
return nil, errmark.MarkUser(err)
}
} else {
oo.copyIDToken(ctx, t, nt)
}

if err := oo.updateUserInfo(ctx, t); err != nil {
return nil, errmark.MarkUser(err)
}

return t, nil
return nt, nil
}

type oidc struct {
Expand Down
204 changes: 204 additions & 0 deletions pkg/provider/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,207 @@ func TestOIDCFlow(t *testing.T) {
assert.Equal(t, "test-user", userInfo["sub"])
assert.Equal(t, "[email protected]", userInfo["email"])
}

func TestOIDCRefreshWithIDToken(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: privateKey,
}, (&jose.SignerOptions{}).WithType("JWT"))
require.NoError(t, err)

h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
_, _ = io.WriteString(w, testOIDCConfiguration)
case "/.well-known/jwks.json":
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: &privateKey.PublicKey,
KeyID: "key",
Use: "sig",
},
},
})
case "/token":
b, err := ioutil.ReadAll(r.Body)
require.NoError(t, err)

data, err := url.ParseQuery(string(b))
require.NoError(t, err)

resp := make(url.Values)

idClaims := jwt.Claims{
Issuer: "http://localhost",
Audience: jwt.Audience{"foo"},
Subject: "test-user",
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
}

idToken, err := jwt.Signed(signer).
Claims(idClaims).
Claims(map[string]interface{}{"grant_type": data.Get("grant_type")}).
CompactSerialize()
require.NoError(t, err)

resp.Set("token_type", "bearer")
resp.Set("id_token", idToken)

switch data.Get("grant_type") {
case "authorization_code":
assert.Equal(t, "foo", data.Get("client_id"))
assert.Equal(t, "bar", data.Get("client_secret"))

resp.Set("access_token", "abcd")
resp.Set("refresh_token", "efgh")
resp.Set("expires_in", "1")
case "refresh_token":
resp.Set("access_token", "ijkl")
resp.Set("refresh_token", "mnop")
resp.Set("expires_in", "900")
default:
assert.Fail(t, "unexpected grant type %q", data.Get("grant_type"))
}

_, _ = io.WriteString(w, resp.Encode())
default:
w.WriteHeader(http.StatusNotFound)
}
})
c := &http.Client{Transport: &testutil.MockRoundTripper{Handler: h}}
ctx = context.WithValue(ctx, oauth2.HTTPClient, c)

oidcTest, err := provider.GlobalRegistry.New(ctx, "oidc", map[string]string{
"issuer_url": "http://localhost",
"extra_data_fields": "id_token,id_token_claims",
})
require.NoError(t, err)

ops := oidcTest.Private("foo", "bar")

token, err := ops.AuthCodeExchange(ctx, "123456")
require.NoError(t, err)
require.NotNil(t, token)
assert.Equal(t, "abcd", token.AccessToken)
require.Contains(t, token.ExtraData, "id_token")
require.Contains(t, token.ExtraData, "id_token_claims")
require.NotEmpty(t, token.ExtraData["id_token"])
initialIDToken := token.ExtraData["id_token"]

token, err = ops.RefreshToken(ctx, token)
require.NoError(t, err)
require.NotNil(t, token)
assert.Equal(t, "ijkl", token.AccessToken)
require.Contains(t, token.ExtraData, "id_token")
require.Contains(t, token.ExtraData, "id_token_claims")
assert.NotEqual(t, initialIDToken, token.ExtraData["id_token"])
}

func TestOIDCRefreshWithoutIDToken(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: privateKey,
}, (&jose.SignerOptions{}).WithType("JWT"))
require.NoError(t, err)

h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
_, _ = io.WriteString(w, testOIDCConfiguration)
case "/.well-known/jwks.json":
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: &privateKey.PublicKey,
KeyID: "key",
Use: "sig",
},
},
})
case "/token":
b, err := ioutil.ReadAll(r.Body)
require.NoError(t, err)

data, err := url.ParseQuery(string(b))
require.NoError(t, err)

resp := make(url.Values)

switch data.Get("grant_type") {
case "authorization_code":
assert.Equal(t, "foo", data.Get("client_id"))
assert.Equal(t, "bar", data.Get("client_secret"))

idClaims := jwt.Claims{
Issuer: "http://localhost",
Audience: jwt.Audience{"foo"},
Subject: "test-user",
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
}

idToken, err := jwt.Signed(signer).
Claims(idClaims).
Claims(map[string]interface{}{"grant_type": data.Get("grant_type")}).
CompactSerialize()
require.NoError(t, err)

resp.Set("access_token", "abcd")
resp.Set("refresh_token", "efgh")
resp.Set("token_type", "bearer")
resp.Set("id_token", idToken)
resp.Set("expires_in", "1")
case "refresh_token":
resp.Set("access_token", "ijkl")
resp.Set("refresh_token", "mnop")
resp.Set("token_type", "bearer")
resp.Set("expires_in", "900")
default:
assert.Fail(t, "unexpected grant type %q", data.Get("grant_type"))
}

_, _ = io.WriteString(w, resp.Encode())
default:
w.WriteHeader(http.StatusNotFound)
}
})
c := &http.Client{Transport: &testutil.MockRoundTripper{Handler: h}}
ctx = context.WithValue(ctx, oauth2.HTTPClient, c)

oidcTest, err := provider.GlobalRegistry.New(ctx, "oidc", map[string]string{
"issuer_url": "http://localhost",
"extra_data_fields": "id_token,id_token_claims",
})
require.NoError(t, err)

ops := oidcTest.Private("foo", "bar")

token, err := ops.AuthCodeExchange(ctx, "123456")
require.NoError(t, err)
require.NotNil(t, token)
assert.Equal(t, "abcd", token.AccessToken)
require.Contains(t, token.ExtraData, "id_token")
require.Contains(t, token.ExtraData, "id_token_claims")
require.NotEmpty(t, token.ExtraData["id_token"])
initialIDToken := token.ExtraData["id_token"]

token, err = ops.RefreshToken(ctx, token)
require.NoError(t, err)
require.NotNil(t, token)
assert.Equal(t, "ijkl", token.AccessToken)
require.Contains(t, token.ExtraData, "id_token")
require.Contains(t, token.ExtraData, "id_token_claims")
assert.Equal(t, initialIDToken, token.ExtraData["id_token"])
}

0 comments on commit a98a1e2

Please sign in to comment.