From 9b22a17d7a231c23ee70f9055d4b9d0a09620838 Mon Sep 17 00:00:00 2001 From: Trevor Foster Date: Sun, 10 Nov 2024 16:37:20 -0500 Subject: [PATCH] fix: cpu contention when reading JWKs and suppress generating duplicate JWKs Previously each concurrent caller would need to lock a shared mutex when reading or writing a given JWK set. The read path now doesn't require locking a mutex at all and instead returns valid query results directly. The write path is now protected by a concurrency control mechanism (using x/sync/singleflight) to ensure only one JWK set is generated and persisted. Note: Duplicate JWK sets may still be improperly generated if running more than one Hydra instance in a high traffic environment. --- Makefile | 1 + cmd/server/helper_cert.go | 4 +- hsm/manager_hsm.go | 4 +- hsm/manager_nohsm.go | 3 +- internal/mock/config_cookie.go | 4 +- jwk/handler.go | 15 +---- jwk/helper.go | 68 ++++++++------------- jwk/helper_test.go | 30 +++++---- jwk/jwt_strategy.go | 4 +- jwk/manager.go | 3 +- jwk/manager_mock_test.go | 4 +- jwk/manager_strategy.go | 3 +- jwk/registry_mock_test.go | 4 +- oauth2/oauth2_provider_mock_test.go | 2 +- persistence/sql/migratest/migration_test.go | 6 +- persistence/sql/persister_jwk.go | 36 ++++++++--- 16 files changed, 94 insertions(+), 97 deletions(-) diff --git a/Makefile b/Makefile index 75b912e0521..09e1b7d574f 100644 --- a/Makefile +++ b/Makefile @@ -109,6 +109,7 @@ format: .bin/goimports .bin/ory node_modules mocks: .bin/mockgen mockgen -package oauth2_test -destination oauth2/oauth2_provider_mock_test.go github.com/ory/fosite OAuth2Provider mockgen -package jwk_test -destination jwk/registry_mock_test.go -source=jwk/registry.go + mockgen -package jwk_test -destination jwk/manager_mock_test.go -source=jwk/manager.go go generate ./... # Generates the SDKs diff --git a/cmd/server/helper_cert.go b/cmd/server/helper_cert.go index 6cef67bc362..7c5579e9636 100644 --- a/cmd/server/helper_cert.go +++ b/cmd/server/helper_cert.go @@ -12,8 +12,6 @@ import ( "encoding/pem" "sync" - "github.com/gofrs/uuid" - "github.com/go-jose/go-jose/v3" "github.com/ory/hydra/v2/driver" @@ -58,7 +56,7 @@ func GetOrCreateTLSCertificate(ctx context.Context, d driver.Registry, iface con } // no certificates configured: self-sign a new cert - priv, err := jwk.GetOrGenerateKeys(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256") + priv, err := jwk.GetOrGenerateKeySetPrivateKey(ctx, d.SoftwareKeyManager(), TlsKeyName, "", "RS256") if err != nil { d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair") return nil // in case Fatal is hooked diff --git a/hsm/manager_hsm.go b/hsm/manager_hsm.go index 75badb1cc5f..1a8301b026b 100644 --- a/hsm/manager_hsm.go +++ b/hsm/manager_hsm.go @@ -16,14 +16,16 @@ import ( "net/http" "sync" - "github.com/ory/hydra/v2/driver/config" "github.com/ory/x/otelx" + "github.com/ory/hydra/v2/driver/config" + "github.com/pkg/errors" "github.com/pborman/uuid" "github.com/ory/fosite" + "github.com/ory/hydra/v2/jwk" "github.com/miekg/pkcs11" diff --git a/hsm/manager_nohsm.go b/hsm/manager_nohsm.go index 4a28bc425e9..bdef2305855 100644 --- a/hsm/manager_nohsm.go +++ b/hsm/manager_nohsm.go @@ -10,9 +10,10 @@ import ( "context" "sync" - "github.com/ory/hydra/v2/driver/config" "github.com/ory/x/logrusx" + "github.com/ory/hydra/v2/driver/config" + "github.com/pkg/errors" "github.com/ory/hydra/v2/jwk" diff --git a/internal/mock/config_cookie.go b/internal/mock/config_cookie.go index 5fab6d1d7dc..d146e10cd6e 100644 --- a/internal/mock/config_cookie.go +++ b/internal/mock/config_cookie.go @@ -1,8 +1,8 @@ -// Copyright © 2022 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ory/hydra/x (interfaces: CookieConfigProvider) +// Source: github.com/ory/hydra/v2/x (interfaces: CookieConfigProvider) // Package mock is a generated GoMock package. package mock diff --git a/jwk/handler.go b/jwk/handler.go index 7d48445321e..b8028aad205 100644 --- a/jwk/handler.go +++ b/jwk/handler.go @@ -12,9 +12,6 @@ import ( "github.com/ory/herodot" "github.com/ory/x/httprouterx" - "github.com/gofrs/uuid" - "github.com/pkg/errors" - "github.com/ory/x/urlx" "github.com/ory/x/errorsx" @@ -101,17 +98,11 @@ func (h *Handler) discoverJsonWebKeys(w http.ResponseWriter, r *http.Request) { for _, set := range wellKnownKeys { set := set eg.Go(func() error { - k, err := h.r.KeyManager().GetKeySet(ctx, set) - if errors.Is(err, x.ErrNotFound) { - h.r.Logger().Warnf("JSON Web Key Set %q does not exist yet, generating new key pair...", set) - k, err = h.r.KeyManager().GenerateAndPersistKeySet(ctx, set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256), "sig") - if err != nil { - return err - } - } else if err != nil { + keySet, err := GetOrGenerateKeySet(ctx, h.r.KeyManager(), set, "", string(jose.RS256)) + if err != nil { return err } - keys <- ExcludePrivateKeys(k) + keys <- ExcludePrivateKeys(keySet) return nil }) } diff --git a/jwk/helper.go b/jwk/helper.go index 50f3a28b2d2..b04ce1b06ed 100644 --- a/jwk/helper.go +++ b/jwk/helper.go @@ -12,69 +12,51 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" - "sync" - - hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/x/josex" - "github.com/ory/x/errorsx" - + hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/x" + "github.com/ory/x/errorsx" + jose "github.com/go-jose/go-jose/v3" "github.com/pkg/errors" ) -var mapLock sync.RWMutex -var locks = map[string]*sync.RWMutex{} - -func getLock(set string) *sync.RWMutex { - mapLock.Lock() - defer mapLock.Unlock() - if _, ok := locks[set]; !ok { - locks[set] = new(sync.RWMutex) - } - return locks[set] -} - func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, set string) error { - _, err := GetOrGenerateKeys(ctx, r, r.KeyManager(), set, set, alg) + _, err := GetOrGenerateKeySetPrivateKey(ctx, r.KeyManager(), set, set, alg) return err } -func GetOrGenerateKeys(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) { - getLock(set).Lock() - defer getLock(set).Unlock() - - keys, err := m.GetKeySet(ctx, set) - if errors.Is(err, x.ErrNotFound) || keys != nil && len(keys.Keys) == 0 { - r.Logger().Warnf("JSON Web Key Set \"%s\" does not exist yet, generating new key pair...", set) - keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") - if err != nil { - return nil, err - } - } else if err != nil { +func GetOrGenerateKeySetPrivateKey(ctx context.Context, m Manager, set, kid, alg string) (*jose.JSONWebKey, error) { + keySet, err := GetOrGenerateKeySet(ctx, m, set, kid, alg) + if err != nil { return nil, err } - privKey, privKeyErr := FindPrivateKey(keys) - if privKeyErr == nil { + privKey, err := FindPrivateKey(keySet) + if err == nil { return privKey, nil - } else { - r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set) + } - keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") - if err != nil { - return nil, err - } + keySet, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") + if err != nil { + return nil, err + } - privKey, err := FindPrivateKey(keys) - if err != nil { - return nil, err - } - return privKey, nil + return FindPrivateKey(keySet) +} + +func GetOrGenerateKeySet(ctx context.Context, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) { + keys, err := m.GetKeySet(ctx, set) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return nil, err + } else if keys != nil && len(keys.Keys) > 0 { + return keys, nil } + + return m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") } func First(keys []jose.JSONWebKey) *jose.JSONWebKey { diff --git a/jwk/helper_test.go b/jwk/helper_test.go index c1a5ee46387..c217c36f931 100644 --- a/jwk/helper_test.go +++ b/jwk/helper_test.go @@ -17,20 +17,19 @@ import ( "strings" "testing" + gomock "github.com/golang/mock/gomock" + "github.com/pborman/uuid" + "github.com/pkg/errors" + hydra "github.com/ory/hydra-client-go/v2" "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/cryptosigner" - "github.com/golang/mock/gomock" - "github.com/pborman/uuid" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" - "github.com/ory/x/contextx" ) type fakeSigner struct { @@ -210,7 +209,6 @@ func TestExcludeOpaquePrivateKeys(t *testing.T) { func TestGetOrGenerateKeys(t *testing.T) { t.Parallel() - reg := internal.NewMockedRegistry(t, &contextx.Default{}) setId := uuid.NewUUID().String() keyId := uuid.NewUUID().String() @@ -226,46 +224,46 @@ func TestGetOrGenerateKeys(t *testing.T) { return NewMockManager(ctrl) } - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.Wrap(x.ErrNotFound, "")) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySet, nil) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256") assert.NoError(t, err) assert.Equal(t, privKey, &keySet.Keys[0]) }) - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySetWithoutPrivateKey, nil).Times(1) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "key not found") }) diff --git a/jwk/jwt_strategy.go b/jwk/jwt_strategy.go index 6154066459b..3d9f9790c91 100644 --- a/jwk/jwt_strategy.go +++ b/jwk/jwt_strategy.go @@ -10,9 +10,9 @@ import ( "github.com/ory/x/josex" "github.com/go-jose/go-jose/v3" - "github.com/gofrs/uuid" "github.com/ory/fosite" + "github.com/ory/hydra/v2/driver/config" "github.com/pkg/errors" @@ -40,7 +40,7 @@ func NewDefaultJWTSigner(c *config.DefaultProvider, r InternalRegistry, setID st } func (j *DefaultJWTSigner) getKeys(ctx context.Context) (private *jose.JSONWebKey, err error) { - private, err = GetOrGenerateKeys(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256)) + private, err = GetOrGenerateKeySetPrivateKey(ctx, j.r.KeyManager(), j.setID, "", string(jose.RS256)) if err == nil { return private, nil } diff --git a/jwk/manager.go b/jwk/manager.go index a8f3c6aacb1..d061c8b53ad 100644 --- a/jwk/manager.go +++ b/jwk/manager.go @@ -11,9 +11,10 @@ import ( "github.com/pkg/errors" + "github.com/ory/x/errorsx" + "github.com/ory/hydra/v2/aead" "github.com/ory/hydra/v2/x" - "github.com/ory/x/errorsx" jose "github.com/go-jose/go-jose/v3" "github.com/gofrs/uuid" diff --git a/jwk/manager_mock_test.go b/jwk/manager_mock_test.go index 65627a0e595..a7b98168a88 100644 --- a/jwk/manager_mock_test.go +++ b/jwk/manager_mock_test.go @@ -1,10 +1,10 @@ -// Copyright © 2022 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 // Code generated by MockGen. DO NOT EDIT. // Source: jwk/manager.go -// Package mock_jwk is a generated GoMock package. +// Package jwk_test is a generated GoMock package. package jwk_test import ( diff --git a/jwk/manager_strategy.go b/jwk/manager_strategy.go index 2519ba3d151..386b8bb44e2 100644 --- a/jwk/manager_strategy.go +++ b/jwk/manager_strategy.go @@ -12,8 +12,9 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "github.com/ory/hydra/v2/x" "github.com/ory/x/otelx" + + "github.com/ory/hydra/v2/x" ) const tracingComponent = "github.com/ory/hydra/v2/jwk" diff --git a/jwk/registry_mock_test.go b/jwk/registry_mock_test.go index c305fd18167..f9624dc2b75 100644 --- a/jwk/registry_mock_test.go +++ b/jwk/registry_mock_test.go @@ -1,4 +1,4 @@ -// Copyright © 2022 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 // Code generated by MockGen. DO NOT EDIT. @@ -13,7 +13,7 @@ import ( gomock "github.com/golang/mock/gomock" herodot "github.com/ory/herodot" - "github.com/ory/hydra/v2/aead" + aead "github.com/ory/hydra/v2/aead" config "github.com/ory/hydra/v2/driver/config" jwk "github.com/ory/hydra/v2/jwk" logrusx "github.com/ory/x/logrusx" diff --git a/oauth2/oauth2_provider_mock_test.go b/oauth2/oauth2_provider_mock_test.go index 83d584eb12f..a5af84c5d5b 100644 --- a/oauth2/oauth2_provider_mock_test.go +++ b/oauth2/oauth2_provider_mock_test.go @@ -1,4 +1,4 @@ -// Copyright © 2022 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 // Code generated by MockGen. DO NOT EDIT. diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 8564cfab969..4bb8cbf7161 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -13,9 +13,10 @@ import ( "testing" "time" - "github.com/ory/hydra/v2/internal" "github.com/ory/x/contextx" + "github.com/ory/hydra/v2/internal" + "github.com/bradleyjkemp/cupaloy/v2" "github.com/fatih/structs" "github.com/gofrs/uuid" @@ -28,10 +29,11 @@ import ( "github.com/ory/x/networkx" "github.com/ory/x/sqlxx" + "github.com/ory/x/popx" + "github.com/ory/hydra/v2/flow" testhelpersuuid "github.com/ory/hydra/v2/internal/testhelpers/uuid" "github.com/ory/hydra/v2/persistence/sql" - "github.com/ory/x/popx" "github.com/ory/x/sqlcon/dockertest" diff --git a/persistence/sql/persister_jwk.go b/persistence/sql/persister_jwk.go index 27a6e184a2b..64215d1c140 100644 --- a/persistence/sql/persister_jwk.go +++ b/persistence/sql/persister_jwk.go @@ -6,19 +6,22 @@ package sql import ( "context" "encoding/json" + "strings" "github.com/go-jose/go-jose/v3" "github.com/gobuffalo/pop/v6" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/singleflight" "github.com/ory/x/errorsx" "github.com/ory/x/otelx" "github.com/pkg/errors" - "github.com/ory/hydra/v2/jwk" "github.com/ory/x/sqlcon" + + "github.com/ory/hydra/v2/jwk" ) var _ jwk.Manager = &Persister{} @@ -31,17 +34,34 @@ func (p *Persister) GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, attribute.String("alg", alg))) defer otelx.End(span, &err) - keys, err := jwk.GenerateJWK(ctx, jose.SignatureAlgorithm(alg), kid, use) - if err != nil { - return nil, errors.Wrapf(jwk.ErrUnsupportedKeyAlgorithm, "%s", err) - } + return p.generateKeySet(ctx, set, kid, alg, use) +} + +var jwkGenFlightGroup singleflight.Group + +func (p *Persister) generateKeySet(ctx context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) { + networkID := p.NetworkID(ctx) - err = p.AddKeySet(ctx, set, keys) + concurrencyKey := strings.Join([]string{networkID.String(), set, kid, alg, use}, ":") + + // Suppress duplicate key set generation jobs where the networkID,set,kid,alg,use match. + keysResult, err, _ := jwkGenFlightGroup.Do(concurrencyKey, func() (any, error) { + keys, err := jwk.GenerateJWK(ctx, jose.SignatureAlgorithm(alg), kid, use) + if err != nil { + return nil, errors.Wrapf(jwk.ErrUnsupportedKeyAlgorithm, "%s", err) + } + + err = p.AddKeySet(ctx, set, keys) + if err != nil { + return nil, err + } + + return keys, nil + }) if err != nil { return nil, err } - - return keys, nil + return keysResult.(*jose.JSONWebKeySet), nil } func (p *Persister) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) (err error) {