From d39836d11d1db073a375b273ac88ed7cc8560595 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Tue, 26 Jul 2022 22:21:05 +1000 Subject: [PATCH] Fix racing when loading new JWKs from multiple threads (#88753) This PR ensures the mutation of JWKs is done in a single thread and visible to all other threads, which in turn ensures validation to be correctly performed concurrently. Relates: #88023 --- .../xpack/security/authc/jwt/JwtRealm.java | 164 +++++++++++------- .../authc/jwt/JwtRealmAuthenticateTests.java | 4 +- .../security/authc/jwt/JwtRealmTestCase.java | 6 +- 3 files changed, 109 insertions(+), 65 deletions(-) diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java index 3486671bcaa58..d2efe1a6de76e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java @@ -65,6 +65,8 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable { private static final Logger LOGGER = LogManager.getLogger(JwtRealm.class); + private static final ContentAndJwksAlgs EMPTY_CONTENT_AND_JWKS_ALGS = new ContentAndJwksAlgs(null, new JwksAlgs(List.of(), List.of())); + // Cached authenticated users, and adjusted JWT expiration date (=exp+skew) for checking if the JWT expired before the cache entry record ExpiringUser(User user, Date exp) { ExpiringUser { @@ -109,7 +111,6 @@ boolean isEmpty() { final boolean isConfiguredJwkSetPkc; final boolean isConfiguredJwkSetHmac; final boolean isConfiguredJwkOidcHmac; - private final CloseableHttpAsyncClient httpClient; final JwkSetLoader jwkSetLoader; final TimeValue allowedClockSkew; final Boolean populateUserMetadata; @@ -125,9 +126,7 @@ boolean isEmpty() { final List allowedJwksAlgsPkc; final List allowedJwksAlgsHmac; DelegatedAuthorizationSupport delegatedAuthorizationSupport = null; - ContentAndJwksAlgs contentAndJwksAlgsPkc; ContentAndJwksAlgs contentAndJwksAlgsHmac; - final URI jwkSetPathUri; JwtRealm( final RealmConfig realmConfig, @@ -184,30 +183,17 @@ boolean isEmpty() { ); } - if (this.isConfiguredJwkSetPkc) { - final URI jwkSetPathUri = JwtUtil.parseHttpsUri(jwkSetPath); - if (jwkSetPathUri == null) { - this.jwkSetPathUri = null; // local file path - this.httpClient = null; - } else { - this.jwkSetPathUri = jwkSetPathUri; // HTTPS URL - this.httpClient = JwtUtil.createHttpClient(this.config, sslService); - } - this.jwkSetLoader = new JwkSetLoader(); // PKC JWKSet loader for HTTPS URL or local file path - } else { - this.jwkSetPathUri = null; // not configured - this.httpClient = null; - this.jwkSetLoader = null; - } - // Any exception during loading requires closing JwkSetLoader's HTTP client to avoid a thread pool leak try { this.contentAndJwksAlgsHmac = this.parseJwksAlgsHmac(); - this.contentAndJwksAlgsPkc = this.parseJwksAlgsPkc(); + if (this.isConfiguredJwkSetPkc) { + this.jwkSetLoader = new JwkSetLoader(sslService); // PKC JWKSet loader for HTTPS URL or local file path + } else { + this.jwkSetLoader = null; + } this.verifyAnyAvailableJwkAndAlgPair(); } catch (Throwable t) { - // ASSUME: Tests or startup only. Catch and rethrow Throwable here, in case some code throws an uncaught RuntimeException. - this.close(); + close(); throw t; } } @@ -255,14 +241,22 @@ private ContentAndJwksAlgs parseJwksAlgsHmac() { return new ContentAndJwksAlgs(hmacStringContentsSha256, jwksAlgsHmac); } - private ContentAndJwksAlgs parseJwksAlgsPkc() { + // Package private for test + URI getJwkSetPathUri() { + if (jwkSetLoader != null) { + return jwkSetLoader.jwkSetPathUri; + } else { + return null; + } + } + + // Package private for test + ContentAndJwksAlgs getJwksAlgsPkc() { if (this.isConfiguredJwkSetPkc == false) { - return new ContentAndJwksAlgs(null, new JwksAlgs(Collections.emptyList(), Collections.emptyList())); + return EMPTY_CONTENT_AND_JWKS_ALGS; } else { - // ASSUME: Blocking read operations are OK during startup - final PlainActionFuture future = new PlainActionFuture<>(); - this.jwkSetLoader.load(future); - return future.actionGet(); + assert jwkSetLoader != null; + return jwkSetLoader.contentAndJwksAlgsPkc; } } @@ -277,8 +271,8 @@ private Cache buildJwtCache() { private void verifyAnyAvailableJwkAndAlgPair() { assert this.contentAndJwksAlgsHmac != null : "HMAC not initialized"; - assert this.contentAndJwksAlgsPkc != null : "PKC not initialized"; - if (this.contentAndJwksAlgsHmac.jwksAlgs.isEmpty() && this.contentAndJwksAlgsPkc.jwksAlgs.isEmpty()) { + assert getJwksAlgsPkc() != null : "PKC not initialized"; + if (this.contentAndJwksAlgsHmac.jwksAlgs.isEmpty() && this.getJwksAlgsPkc().jwksAlgs.isEmpty()) { final String msg = "No available JWK and algorithm for HMAC or PKC. Realm authentication expected to fail until this is fixed."; throw new SettingsException(msg); } @@ -312,7 +306,9 @@ public void initialize(final Iterable allRealms, final XPackLicenseState @Override public void close() { this.invalidateJwtCache(); - this.closeHttpClient(); + if (jwkSetLoader != null) { + jwkSetLoader.close(); + } } /** @@ -332,19 +328,6 @@ private void invalidateJwtCache() { } } - /** - * Clean up HTTPS client cache (if enabled). - */ - private void closeHttpClient() { - if (this.httpClient != null) { - try { - this.httpClient.close(); - } catch (IOException e) { - LOGGER.warn(() -> "Exception closing HTTPS client for realm [" + super.name() + "]", e); - } - } - } - @Override public void lookupUser(final String username, final ActionListener listener) { this.ensureInitialized(); @@ -596,7 +579,7 @@ private void validateSignature( try { JwtValidateUtil.validateSignature( jwt, - isJwtAlgHmac ? this.contentAndJwksAlgsHmac.jwksAlgs.jwks : this.contentAndJwksAlgsPkc.jwksAlgs.jwks + isJwtAlgHmac ? this.contentAndJwksAlgsHmac.jwksAlgs.jwks : this.getJwksAlgsPkc().jwksAlgs.jwks ); listener.onResponse(null); } catch (Exception primaryException) { @@ -609,33 +592,32 @@ private void validateSignature( () -> org.elasticsearch.core.Strings.format( "Signature verification failed for [%s] reloading JWKSet (was: #[%s] JWKs, #[%s] algs, sha256=[%s])", tokenPrincipal, - this.contentAndJwksAlgsPkc.jwksAlgs.jwks().size(), - this.contentAndJwksAlgsPkc.jwksAlgs.algs().size(), - MessageDigests.toHexString(this.contentAndJwksAlgsPkc.sha256()) + this.getJwksAlgsPkc().jwksAlgs.jwks().size(), + this.getJwksAlgsPkc().jwksAlgs.algs().size(), + MessageDigests.toHexString(this.getJwksAlgsPkc().sha256()) ), primaryException ); - this.jwkSetLoader.load(ActionListener.wrap(newContentAndJwksAlgs -> { - if (Arrays.equals(this.contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256)) { + this.jwkSetLoader.reload(ActionListener.wrap(isUpdated -> { + if (false == isUpdated) { // No change in JWKSet logger.debug("Reloaded same PKC JWKs, can't retry verify JWT token=[{}]", tokenPrincipal); listener.onFailure(primaryException); return; } - this.contentAndJwksAlgsPkc = newContentAndJwksAlgs; // If all PKC JWKs were replaced, all PKC JWT cache entries need to be invalidated. // Enhancement idea: Use separate caches for PKC vs HMAC JWKs, so only PKC entries get invalidated. // Enhancement idea: When some JWKs are retained (ex: rotation), only invalidate for removed JWKs. this.invalidateJwtCache(); - if (this.contentAndJwksAlgsPkc.jwksAlgs.isEmpty()) { + if (this.getJwksAlgsPkc().jwksAlgs.isEmpty()) { logger.debug("Reloaded empty PKC JWKs, verification of JWT token will fail [{}]", tokenPrincipal); // fall through and let try/catch below handle empty JWKs failure log and response } try { - JwtValidateUtil.validateSignature(jwt, this.contentAndJwksAlgsPkc.jwksAlgs.jwks); + JwtValidateUtil.validateSignature(jwt, this.getJwksAlgsPkc().jwksAlgs.jwks); listener.onResponse(null); } catch (Exception secondaryException) { logger.debug( @@ -660,14 +642,73 @@ public void usageStats(final ActionListener> listener) { }, listener::onFailure)); } - private class JwkSetLoader { + private class JwkSetLoader implements Releasable { private final AtomicReference> reloadFutureRef = new AtomicReference<>(); + private final URI jwkSetPathUri; + private final CloseableHttpAsyncClient httpClient; + private volatile ContentAndJwksAlgs contentAndJwksAlgsPkc; + + JwkSetLoader(final SSLService sslService) { + assert JwtRealm.this.isConfiguredJwkSetPkc; + final URI jwkSetPathUri = JwtUtil.parseHttpsUri(jwkSetPath); + if (jwkSetPathUri == null) { + this.jwkSetPathUri = null; // local file path + this.httpClient = null; + } else { + this.jwkSetPathUri = jwkSetPathUri; // HTTPS URL + this.httpClient = JwtUtil.createHttpClient(JwtRealm.this.config, sslService); + } + // Any exception during loading requires closing JwkSetLoader's HTTP client to avoid a thread pool leak + try { + final PlainActionFuture future = new PlainActionFuture<>(); + load(future); + // ASSUME: Blocking read operations are OK during startup + contentAndJwksAlgsPkc = future.actionGet(); + } catch (Throwable t) { + close(); + throw t; + } + } + + /** + * Clean up HTTPS client cache (if enabled). + */ + @Override + public void close() { + if (httpClient != null) { + try { + httpClient.close(); + } catch (IOException e) { + LOGGER.warn(() -> "Exception closing HTTPS client for realm [" + JwtRealm.this.name() + "]", e); + } + } + } + /** + * Load the JWK sets and pass its content to the specified listener. + */ void load(final ActionListener listener) { final ListenableFuture future = this.getFuture(); future.addListener(listener); } + /** + * Reload the JWK sets, compare to existing JWK sets and update it to the reloaded value if + * they are different. The listener is called with false if the reloaded content is the same + * as the existing one or true if they are different. + */ + void reload(final ActionListener listener) { + load(ActionListener.wrap(newContentAndJwksAlgs -> { + if (Arrays.equals(contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256)) { + // No change in JWKSet + listener.onResponse(false); + } else { + contentAndJwksAlgsPkc = newContentAndJwksAlgs; + listener.onResponse(true); + } + }, listener::onFailure)); + } + private ListenableFuture getFuture() { for (;;) { final ListenableFuture existingFuture = this.reloadFutureRef.get(); @@ -677,7 +718,10 @@ private ListenableFuture getFuture() { final ListenableFuture newFuture = new ListenableFuture<>(); if (this.reloadFutureRef.compareAndSet(null, newFuture)) { - loadInternal(ActionListener.runAfter(newFuture, () -> this.reloadFutureRef.compareAndSet(newFuture, null))); + loadInternal(ActionListener.runAfter(newFuture, () -> { + final ListenableFuture oldValue = this.reloadFutureRef.getAndSet(null); + assert oldValue == newFuture : "future reference changed unexpectedly"; + })); return newFuture; } // else, Another thread set the future-ref before us, just try it all again @@ -686,7 +730,7 @@ private ListenableFuture getFuture() { private void loadInternal(final ActionListener listener) { // PKC JWKSet get contents from local file or remote HTTPS URL - if (JwtRealm.this.httpClient == null) { + if (httpClient == null) { LOGGER.trace("Loading PKC JWKs from path [{}]", JwtRealm.this.jwkSetPath); listener.onResponse( this.parseContent( @@ -698,13 +742,13 @@ private void loadInternal(final ActionListener listener) { ) ); } else { - LOGGER.trace("Loading PKC JWKs from https URI [{}]", JwtRealm.this.jwkSetPathUri); + LOGGER.trace("Loading PKC JWKs from https URI [{}]", jwkSetPathUri); JwtUtil.readUriContents( RealmSettings.getFullSettingKey(JwtRealm.this.config, JwtRealmSettings.PKC_JWKSET_PATH), - JwtRealm.this.jwkSetPathUri, - JwtRealm.this.httpClient, + jwkSetPathUri, + httpClient, listener.map(bytes -> { - LOGGER.trace("Loaded bytes [{}] from [{}]", bytes.length, JwtRealm.this.jwkSetPathUri); + LOGGER.trace("Loaded bytes [{}] from [{}]", bytes.length, jwkSetPathUri); return this.parseContent(bytes); }) ); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java index 8a8b955a81385..f948a4221499d 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java @@ -412,10 +412,10 @@ public void testJwtValidationFailures() throws Exception { { // Verify rejection of a tampered header (flip HMAC=>RSA or RSA/EC=>HMAC) final String mixupAlg; // Check if there are any algorithms available in the realm for attempting a flip test if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(validHeader.getAlgorithm().getName())) { - if (jwtIssuerAndRealm.realm().contentAndJwksAlgsPkc.jwksAlgs().algs().isEmpty()) { + if (jwtIssuerAndRealm.realm().getJwksAlgsPkc().jwksAlgs().algs().isEmpty()) { mixupAlg = null; // cannot flip HMAC to PKC (no PKC algs available) } else { - mixupAlg = randomFrom(jwtIssuerAndRealm.realm().contentAndJwksAlgsPkc.jwksAlgs().algs()); // flip HMAC to PKC + mixupAlg = randomFrom(jwtIssuerAndRealm.realm().getJwksAlgsPkc().jwksAlgs().algs()); // flip HMAC to PKC } } else { if (jwtIssuerAndRealm.realm().contentAndJwksAlgsHmac.jwksAlgs().algs().isEmpty()) { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java index ec64938a6e34c..8c8700d1b02c0 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java @@ -213,7 +213,7 @@ protected JwtIssuer createJwtIssuer( } protected void copyIssuerJwksToRealmConfig(final JwtIssuerAndRealm jwtIssuerAndRealm) throws Exception { - if ((jwtIssuerAndRealm.realm.isConfiguredJwkSetPkc) && (jwtIssuerAndRealm.realm.jwkSetPathUri == null)) { + if ((jwtIssuerAndRealm.realm.isConfiguredJwkSetPkc) && (jwtIssuerAndRealm.realm.getJwkSetPathUri() == null)) { LOGGER.trace("Updating JwtRealm PKC public JWKSet local file"); final Path path = PathUtils.get(jwtIssuerAndRealm.realm.jwkSetPath); Files.writeString(path, jwtIssuerAndRealm.issuer.encodedJwkSetPkcPublic); @@ -659,7 +659,7 @@ protected void printJwtRealm(final JwtRealm jwtRealm) { + ", algsPkc=" + jwtRealm.allowedJwksAlgsPkc + ", filteredPkc=" - + jwtRealm.contentAndJwksAlgsPkc.jwksAlgs().algs() + + jwtRealm.getJwksAlgsPkc().jwksAlgs().algs() + ", claimPrincipal=[" + jwtRealm.claimParserPrincipal.getClaimName() + "], claimGroups=[" @@ -675,7 +675,7 @@ protected void printJwtRealm(final JwtRealm jwtRealm) { for (final JWK jwk : jwtRealm.contentAndJwksAlgsHmac.jwksAlgs().jwks()) { LOGGER.info("REALM HMAC: jwk=[{}]", jwk); } - for (final JWK jwk : jwtRealm.contentAndJwksAlgsPkc.jwksAlgs().jwks()) { + for (final JWK jwk : jwtRealm.getJwksAlgsPkc().jwksAlgs().jwks()) { LOGGER.info("REALM PKC: jwk=[{}]", jwk); } }