diff --git a/CHANGELOG.md b/CHANGELOG.md index 52e6fe3..3ae8b5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ Versioning](https://semver.org/spec/v2.0.0.html). refresh token flow if the plugin user specifies a nonce to validate against; otherwise, it is assumed that the nonce data is invalid or non-conforming to the OpenID Connect Core specification. +* Per the OpenID Connect Core specification, ID tokens will only be revalidated + during refresh if the server sends a new ID token. Otherwise, they are passed + through unmodified from the original exchange. ### Changed diff --git a/pkg/provider/oidc.go b/pkg/provider/oidc.go index f7461eb..2173411 100644 --- a/pkg/provider/oidc.go +++ b/pkg/provider/oidc.go @@ -34,9 +34,9 @@ type oidcOperations struct { extraDataFields []string } -func (oo *oidcOperations) verifyUpdateToken(ctx context.Context, t *Token, nonce string) error { +func (oo *oidcOperations) verifyUpdateIDToken(ctx context.Context, t *Token, nonce string) error { rawIDToken, ok := t.Extra("id_token").(string) - if !ok { + if !ok || rawIDToken == "" { return ErrOIDCMissingIDToken } @@ -54,34 +54,50 @@ func (oo *oidcOperations) verifyUpdateToken(ctx context.Context, t *Token, nonce return ErrOIDCNonceMismatch } - if len(oo.extraDataFields) > 0 { - t.ExtraData = make(map[string]interface{}) - - for _, field := range oo.extraDataFields { - switch field { - case oidcExtraDataFieldIDToken: - t.ExtraData[field] = rawIDToken - case oidcExtraDataFieldIDTokenClaims: - claims := make(map[string]interface{}) - if err := idToken.Claims(&claims); err != nil { - return fmt.Errorf("oidc: error parsing token claims: %w", err) - } - - t.ExtraData[field] = claims - case oidcExtraDataFieldUserInfo: - userInfo, err := oo.p.UserInfo(ctx, oo.basicOperations.base.TokenSource(ctx, t.Token)) - if err != nil { - return fmt.Errorf("oidc: error fetching user info: %w", err) - } - - claims := make(map[string]interface{}) - if err := userInfo.Claims(&claims); err != nil { - return fmt.Errorf("oidc: error parsing user info: %w", err) - } - - t.ExtraData[field] = claims + for _, field := range oo.extraDataFields { + switch field { + case oidcExtraDataFieldIDToken: + t.ExtraData[field] = rawIDToken + case oidcExtraDataFieldIDTokenClaims: + claims := make(map[string]interface{}) + if err := idToken.Claims(&claims); err != nil { + return fmt.Errorf("oidc: error parsing token claims: %w", err) } + + t.ExtraData[field] = claims + } + } + + return nil +} + +func (oo *oidcOperations) copyIDToken(ctx context.Context, p, n *Token) { + for _, field := range oo.extraDataFields { + switch field { + case oidcExtraDataFieldIDToken, oidcExtraDataFieldIDTokenClaims: + n.ExtraData[field] = p.ExtraData[field] + } + } +} + +func (oo *oidcOperations) updateUserInfo(ctx context.Context, t *Token) error { + for _, field := range oo.extraDataFields { + if field != oidcExtraDataFieldUserInfo { + continue + } + + userInfo, err := oo.p.UserInfo(ctx, oauth2.StaticTokenSource(t.Token)) + if err != nil { + return fmt.Errorf("oidc: error fetching user info: %w", err) + } + + claims := make(map[string]interface{}) + if err := userInfo.Claims(&claims); err != nil { + return fmt.Errorf("oidc: error parsing user info: %w", err) } + + t.ExtraData[field] = claims + break } return nil @@ -96,7 +112,15 @@ func (oo *oidcOperations) AuthCodeExchange(ctx context.Context, code string, opt return nil, err } - if err := oo.verifyUpdateToken(ctx, t, o.ProviderOptions["nonce"]); err != nil { + if t.ExtraData == nil { + t.ExtraData = make(map[string]interface{}) + } + + if err := oo.verifyUpdateIDToken(ctx, t, o.ProviderOptions["nonce"]); err != nil { + return nil, errmark.MarkUser(err) + } + + if err := oo.updateUserInfo(ctx, t); err != nil { return nil, errmark.MarkUser(err) } @@ -107,16 +131,32 @@ func (oo *oidcOperations) RefreshToken(ctx context.Context, t *Token, opts ...Re o := &RefreshTokenOptions{} o.ApplyOptions(opts) - t, err := oo.basicOperations.RefreshToken(ctx, t, opts...) + nt, err := oo.basicOperations.RefreshToken(ctx, t, opts...) if err != nil { return nil, err } - if err := oo.verifyUpdateToken(ctx, t, o.ProviderOptions["nonce"]); err != nil { + if nt.ExtraData == nil { + nt.ExtraData = make(map[string]interface{}) + } + + // Per OpenID Connect Core 1.0 ยง 12.2 + // (https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse), + // providing an ID token as part of a refresh is optional. We will only + // revalidate the token if a new one is provided. + if rawIDToken, ok := nt.Extra("id_token").(string); ok && rawIDToken != "" { + if err := oo.verifyUpdateIDToken(ctx, nt, o.ProviderOptions["nonce"]); err != nil { + return nil, errmark.MarkUser(err) + } + } else { + oo.copyIDToken(ctx, t, nt) + } + + if err := oo.updateUserInfo(ctx, t); err != nil { return nil, errmark.MarkUser(err) } - return t, nil + return nt, nil } type oidc struct { diff --git a/pkg/provider/oidc_test.go b/pkg/provider/oidc_test.go index 32d5a98..2d0bacd 100644 --- a/pkg/provider/oidc_test.go +++ b/pkg/provider/oidc_test.go @@ -145,3 +145,207 @@ func TestOIDCFlow(t *testing.T) { assert.Equal(t, "test-user", userInfo["sub"]) assert.Equal(t, "test-user@example.com", userInfo["email"]) } + +func TestOIDCRefreshWithIDToken(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.RS256, + Key: privateKey, + }, (&jose.SignerOptions{}).WithType("JWT")) + require.NoError(t, err) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + _, _ = io.WriteString(w, testOIDCConfiguration) + case "/.well-known/jwks.json": + _ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: &privateKey.PublicKey, + KeyID: "key", + Use: "sig", + }, + }, + }) + case "/token": + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + data, err := url.ParseQuery(string(b)) + require.NoError(t, err) + + resp := make(url.Values) + + idClaims := jwt.Claims{ + Issuer: "http://localhost", + Audience: jwt.Audience{"foo"}, + Subject: "test-user", + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + } + + idToken, err := jwt.Signed(signer). + Claims(idClaims). + Claims(map[string]interface{}{"grant_type": data.Get("grant_type")}). + CompactSerialize() + require.NoError(t, err) + + resp.Set("token_type", "bearer") + resp.Set("id_token", idToken) + + switch data.Get("grant_type") { + case "authorization_code": + assert.Equal(t, "foo", data.Get("client_id")) + assert.Equal(t, "bar", data.Get("client_secret")) + + resp.Set("access_token", "abcd") + resp.Set("refresh_token", "efgh") + resp.Set("expires_in", "1") + case "refresh_token": + resp.Set("access_token", "ijkl") + resp.Set("refresh_token", "mnop") + resp.Set("expires_in", "900") + default: + assert.Fail(t, "unexpected grant type %q", data.Get("grant_type")) + } + + _, _ = io.WriteString(w, resp.Encode()) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + c := &http.Client{Transport: &testutil.MockRoundTripper{Handler: h}} + ctx = context.WithValue(ctx, oauth2.HTTPClient, c) + + oidcTest, err := provider.GlobalRegistry.New(ctx, "oidc", map[string]string{ + "issuer_url": "http://localhost", + "extra_data_fields": "id_token,id_token_claims", + }) + require.NoError(t, err) + + ops := oidcTest.Private("foo", "bar") + + token, err := ops.AuthCodeExchange(ctx, "123456") + require.NoError(t, err) + require.NotNil(t, token) + assert.Equal(t, "abcd", token.AccessToken) + require.Contains(t, token.ExtraData, "id_token") + require.Contains(t, token.ExtraData, "id_token_claims") + require.NotEmpty(t, token.ExtraData["id_token"]) + initialIDToken := token.ExtraData["id_token"] + + token, err = ops.RefreshToken(ctx, token) + require.NoError(t, err) + require.NotNil(t, token) + assert.Equal(t, "ijkl", token.AccessToken) + require.Contains(t, token.ExtraData, "id_token") + require.Contains(t, token.ExtraData, "id_token_claims") + assert.NotEqual(t, initialIDToken, token.ExtraData["id_token"]) +} + +func TestOIDCRefreshWithoutIDToken(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.RS256, + Key: privateKey, + }, (&jose.SignerOptions{}).WithType("JWT")) + require.NoError(t, err) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + _, _ = io.WriteString(w, testOIDCConfiguration) + case "/.well-known/jwks.json": + _ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: &privateKey.PublicKey, + KeyID: "key", + Use: "sig", + }, + }, + }) + case "/token": + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + data, err := url.ParseQuery(string(b)) + require.NoError(t, err) + + resp := make(url.Values) + + switch data.Get("grant_type") { + case "authorization_code": + assert.Equal(t, "foo", data.Get("client_id")) + assert.Equal(t, "bar", data.Get("client_secret")) + + idClaims := jwt.Claims{ + Issuer: "http://localhost", + Audience: jwt.Audience{"foo"}, + Subject: "test-user", + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + } + + idToken, err := jwt.Signed(signer). + Claims(idClaims). + Claims(map[string]interface{}{"grant_type": data.Get("grant_type")}). + CompactSerialize() + require.NoError(t, err) + + resp.Set("access_token", "abcd") + resp.Set("refresh_token", "efgh") + resp.Set("token_type", "bearer") + resp.Set("id_token", idToken) + resp.Set("expires_in", "1") + case "refresh_token": + resp.Set("access_token", "ijkl") + resp.Set("refresh_token", "mnop") + resp.Set("token_type", "bearer") + resp.Set("expires_in", "900") + default: + assert.Fail(t, "unexpected grant type %q", data.Get("grant_type")) + } + + _, _ = io.WriteString(w, resp.Encode()) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + c := &http.Client{Transport: &testutil.MockRoundTripper{Handler: h}} + ctx = context.WithValue(ctx, oauth2.HTTPClient, c) + + oidcTest, err := provider.GlobalRegistry.New(ctx, "oidc", map[string]string{ + "issuer_url": "http://localhost", + "extra_data_fields": "id_token,id_token_claims", + }) + require.NoError(t, err) + + ops := oidcTest.Private("foo", "bar") + + token, err := ops.AuthCodeExchange(ctx, "123456") + require.NoError(t, err) + require.NotNil(t, token) + assert.Equal(t, "abcd", token.AccessToken) + require.Contains(t, token.ExtraData, "id_token") + require.Contains(t, token.ExtraData, "id_token_claims") + require.NotEmpty(t, token.ExtraData["id_token"]) + initialIDToken := token.ExtraData["id_token"] + + token, err = ops.RefreshToken(ctx, token) + require.NoError(t, err) + require.NotNil(t, token) + assert.Equal(t, "ijkl", token.AccessToken) + require.Contains(t, token.ExtraData, "id_token") + require.Contains(t, token.ExtraData, "id_token_claims") + assert.Equal(t, initialIDToken, token.ExtraData["id_token"]) +}