diff --git a/oidc/jwks.go b/oidc/jwks.go index 7df3e9f7..b1e3f7e3 100644 --- a/oidc/jwks.go +++ b/oidc/jwks.go @@ -159,7 +159,7 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ( // https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys keys, err := r.keysFromRemote(ctx) if err != nil { - return nil, fmt.Errorf("fetching keys %v", err) + return nil, fmt.Errorf("fetching keys %w", err) } for _, key := range keys { @@ -228,7 +228,7 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) { resp, err := doRequest(r.ctx, req) if err != nil { - return nil, fmt.Errorf("oidc: get keys failed %v", err) + return nil, fmt.Errorf("oidc: get keys failed %w", err) } defer resp.Body.Close() diff --git a/oidc/jwks_test.go b/oidc/jwks_test.go index 704bdc33..b9545ea6 100644 --- a/oidc/jwks_test.go +++ b/oidc/jwks_test.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -138,6 +139,41 @@ func TestMismatchedKeyID(t *testing.T) { testKeyVerify(t, key2, bad, key1, key2) } +func TestKeyVerifyContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + payload := []byte("a secret") + + good := newECDSAKey(t) + jws, err := jose.ParseSigned(good.sign(t, payload)) + if err != nil { + t.Fatal(err) + } + + ch := make(chan struct{}) + defer close(ch) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-ch + })) + defer s.Close() + + rks := newRemoteKeySet(ctx, s.URL, nil) + + cancel() + + // Ensure the token verifies. + _, err = rks.verify(ctx, jws) + if err == nil { + t.Fatal("expected context canceled, got nil error") + } + + if !errors.Is(err, context.Canceled) { + t.Errorf("expected error to be %q got %q", context.Canceled, err) + } +} + func testKeyVerify(t *testing.T, good, bad *signingKey, verification ...*signingKey) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -259,7 +295,7 @@ func BenchmarkVerify(b *testing.B) { key := newRSAKey(b) - now := time.Date(2022, 01, 29, 0, 0, 0, 0, time.UTC) + now := time.Date(2022, 1, 29, 0, 0, 0, 0, time.UTC) exp := now.Add(time.Hour) payload := []byte(fmt.Sprintf(`{ "iss": "https://example.com",