From d91f12b3b34bb35b5188d77100a1fd878be5d840 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 25 Jul 2022 14:49:45 +1000 Subject: [PATCH] Fix racing when loading new JWKs from multiple threads The change ensure the mutation of JWKs is done in a single thread and visible to all other threads, which in turn ensures validation to be correctly performed concurrently. Relates: #88023 --- .../xpack/security/authc/jwt/JwtRealm.java | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java index 3486671bcaa58..907114a2dcc62 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java @@ -125,7 +125,7 @@ boolean isEmpty() { final List allowedJwksAlgsPkc; final List allowedJwksAlgsHmac; DelegatedAuthorizationSupport delegatedAuthorizationSupport = null; - ContentAndJwksAlgs contentAndJwksAlgsPkc; + volatile ContentAndJwksAlgs contentAndJwksAlgsPkc; ContentAndJwksAlgs contentAndJwksAlgsHmac; final URI jwkSetPathUri; @@ -616,14 +616,13 @@ private void validateSignature( primaryException ); - this.jwkSetLoader.load(ActionListener.wrap(newContentAndJwksAlgs -> { - if (Arrays.equals(this.contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256)) { + this.jwkSetLoader.reload(ActionListener.wrap(isUpdated -> { + if (false == isUpdated) { // No change in JWKSet logger.debug("Reloaded same PKC JWKs, can't retry verify JWT token=[{}]", tokenPrincipal); listener.onFailure(primaryException); return; } - this.contentAndJwksAlgsPkc = newContentAndJwksAlgs; // If all PKC JWKs were replaced, all PKC JWT cache entries need to be invalidated. // Enhancement idea: Use separate caches for PKC vs HMAC JWKs, so only PKC entries get invalidated. // Enhancement idea: When some JWKs are retained (ex: rotation), only invalidate for removed JWKs. @@ -663,11 +662,31 @@ public void usageStats(final ActionListener> listener) { private class JwkSetLoader { private final AtomicReference> reloadFutureRef = new AtomicReference<>(); + /** + * Load the JWK sets and pass its content to the specified listener. + */ void load(final ActionListener listener) { final ListenableFuture future = this.getFuture(); future.addListener(listener); } + /** + * Reload the JWK sets, compare to existing JWK sets and update it to the reloaded value if + * they are different. The listener is called with false if the reloaded content is the same + * as the existing one or true if they are different. + */ + void reload(final ActionListener listener) { + load(ActionListener.wrap(newContentAndJwksAlgs -> { + if (Arrays.equals(JwtRealm.this.contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256)) { + // No change in JWKSet + listener.onResponse(false); + } else { + JwtRealm.this.contentAndJwksAlgsPkc = newContentAndJwksAlgs; + listener.onResponse(true); + } + }, listener::onFailure)); + } + private ListenableFuture getFuture() { for (;;) { final ListenableFuture existingFuture = this.reloadFutureRef.get();