From 0c33d0fdaacb3f1b5f635b5e3ea6aad3a0fcde7d Mon Sep 17 00:00:00 2001 From: "Adam T. Williams" Date: Thu, 7 Nov 2024 11:48:38 -0700 Subject: [PATCH] fix: rm locks and suppress duplicate writes --- cmd/server/helper_cert.go | 2 +- jwk/handler.go | 13 ++----- jwk/helper.go | 79 +++++++++++++++++++-------------------- jwk/helper_test.go | 23 ++++++------ jwk/jwt_strategy.go | 3 +- 5 files changed, 57 insertions(+), 63 deletions(-) diff --git a/cmd/server/helper_cert.go b/cmd/server/helper_cert.go index 6cef67bc362..e2012292be2 100644 --- a/cmd/server/helper_cert.go +++ b/cmd/server/helper_cert.go @@ -58,7 +58,7 @@ func GetOrCreateTLSCertificate(ctx context.Context, d driver.Registry, iface con } // no certificates configured: self-sign a new cert - priv, err := jwk.GetOrGenerateKeys(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256") + priv, err := jwk.GetOrGenerateKeySetPrivateKey(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256") if err != nil { d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair") return nil // in case Fatal is hooked diff --git a/jwk/handler.go b/jwk/handler.go index 7d48445321e..5e12f9f31ad 100644 --- a/jwk/handler.go +++ b/jwk/handler.go @@ -13,7 +13,6 @@ import ( "github.com/ory/x/httprouterx" "github.com/gofrs/uuid" - "github.com/pkg/errors" "github.com/ory/x/urlx" @@ -101,17 +100,11 @@ func (h *Handler) discoverJsonWebKeys(w http.ResponseWriter, r *http.Request) { for _, set := range wellKnownKeys { set := set eg.Go(func() error { - k, err := h.r.KeyManager().GetKeySet(ctx, set) - if errors.Is(err, x.ErrNotFound) { - h.r.Logger().Warnf("JSON Web Key Set %q does not exist yet, generating new key pair...", set) - k, err = h.r.KeyManager().GenerateAndPersistKeySet(ctx, set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256), "sig") - if err != nil { - return err - } - } else if err != nil { + keySet, err := GetOrGenerateKeySet(ctx, h.r, h.r.KeyManager(), set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256)) + if err != nil { return err } - keys <- ExcludePrivateKeys(k) + keys <- ExcludePrivateKeys(keySet) return nil }) } diff --git a/jwk/helper.go b/jwk/helper.go index 50f3a28b2d2..5b3e3347b87 100644 --- a/jwk/helper.go +++ b/jwk/helper.go @@ -12,69 +12,68 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" - "sync" + + "golang.org/x/sync/singleflight" hydra "github.com/ory/hydra-client-go/v2" + "github.com/ory/hydra/v2/x" "github.com/ory/x/josex" "github.com/ory/x/errorsx" - "github.com/ory/hydra/v2/x" - jose "github.com/go-jose/go-jose/v3" "github.com/pkg/errors" ) -var mapLock sync.RWMutex -var locks = map[string]*sync.RWMutex{} - -func getLock(set string) *sync.RWMutex { - mapLock.Lock() - defer mapLock.Unlock() - if _, ok := locks[set]; !ok { - locks[set] = new(sync.RWMutex) - } - return locks[set] -} +var jwkGenFlightGroup singleflight.Group func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, set string) error { - _, err := GetOrGenerateKeys(ctx, r, r.KeyManager(), set, set, alg) + _, err := GetOrGenerateKeySetPrivateKey(ctx, r, r.KeyManager(), set, set, alg) return err } -func GetOrGenerateKeys(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) { - getLock(set).Lock() - defer getLock(set).Unlock() - - keys, err := m.GetKeySet(ctx, set) - if errors.Is(err, x.ErrNotFound) || keys != nil && len(keys.Keys) == 0 { - r.Logger().Warnf("JSON Web Key Set \"%s\" does not exist yet, generating new key pair...", set) - keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") - if err != nil { - return nil, err - } - } else if err != nil { +func GetOrGenerateKeySetPrivateKey(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) { + keySet, err := GetOrGenerateKeySet(ctx, r, m, set, kid, alg) + if err != nil { return nil, err } - privKey, privKeyErr := FindPrivateKey(keys) - if privKeyErr == nil { + privKey, err := FindPrivateKey(keySet) + if err == nil { return privKey, nil - } else { - r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set) + } - keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") - if err != nil { - return nil, err - } + keySet, err = generateKeySet(ctx, r, m, set, kid, alg) + if err != nil { + return nil, err + } - privKey, err := FindPrivateKey(keys) - if err != nil { - return nil, err - } - return privKey, nil + return FindPrivateKey(keySet) +} + +func GetOrGenerateKeySet(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) { + keys, err := m.GetKeySet(ctx, set) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return nil, err + } else if keys != nil && len(keys.Keys) > 0 { + jwkGenFlightGroup.Forget(set + alg + kid) + return keys, nil + } + + return generateKeySet(ctx, r, m, set, kid, alg) +} + +func generateKeySet(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) { + // Suppress duplicate key set generation jobs where the set+alg match. + keysResult, err, _ := jwkGenFlightGroup.Do(set+alg+kid, func() (any, error) { + r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set) + return m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") + }) + if err != nil { + return nil, err } + return keysResult.(*jose.JSONWebKeySet), nil } func First(keys []jose.JSONWebKey) *jose.JSONWebKey { diff --git a/jwk/helper_test.go b/jwk/helper_test.go index c1a5ee46387..c4f4d6e18b5 100644 --- a/jwk/helper_test.go +++ b/jwk/helper_test.go @@ -27,10 +27,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ory/x/contextx" + "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/x" - "github.com/ory/x/contextx" ) type fakeSigner struct { @@ -226,46 +227,46 @@ func TestGetOrGenerateKeys(t *testing.T) { return NewMockManager(ctrl) } - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.Wrap(x.ErrNotFound, "")) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError")) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "GetKeySetError") }) - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySet, nil) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.NoError(t, err) assert.Equal(t, privKey, &keySet.Keys[0]) }) - t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) { + t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) { keyManager := km(t) keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySetWithoutPrivateKey, nil).Times(1) - privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256") assert.Nil(t, privKey) assert.EqualError(t, err, "key not found") }) diff --git a/jwk/jwt_strategy.go b/jwk/jwt_strategy.go index 6154066459b..9fdd6c48374 100644 --- a/jwk/jwt_strategy.go +++ b/jwk/jwt_strategy.go @@ -13,6 +13,7 @@ import ( "github.com/gofrs/uuid" "github.com/ory/fosite" + "github.com/ory/hydra/v2/driver/config" "github.com/pkg/errors" @@ -40,7 +41,7 @@ func NewDefaultJWTSigner(c *config.DefaultProvider, r InternalRegistry, setID st } func (j *DefaultJWTSigner) getKeys(ctx context.Context) (private *jose.JSONWebKey, err error) { - private, err = GetOrGenerateKeys(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256)) + private, err = GetOrGenerateKeySetPrivateKey(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256)) if err == nil { return private, nil }