diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java index 7115673115e3c..767d16d910957 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java @@ -15,7 +15,6 @@ import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.Strings; @@ -45,6 +44,7 @@ import org.elasticsearch.xpack.security.authc.support.ClaimParser; import org.elasticsearch.xpack.security.authc.support.DelegatedAuthorizationSupport; +import java.io.Closeable; import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; @@ -111,7 +111,7 @@ boolean isEmpty() { final boolean isConfiguredJwkSetPkc; final boolean isConfiguredJwkSetHmac; final boolean isConfiguredJwkOidcHmac; - final CloseableHttpAsyncClient httpClient; + final JwkSetLoader jwkSetLoader; final TimeValue allowedClockSkew; final Boolean populateUserMetadata; final ClaimParser claimParserPrincipal; @@ -128,7 +128,6 @@ boolean isEmpty() { DelegatedAuthorizationSupport delegatedAuthorizationSupport = null; ContentAndJwksAlgs contentAndJwksAlgsPkc; ContentAndJwksAlgs contentAndJwksAlgsHmac; - private final AtomicReference> reloadActionRef = new AtomicReference<>(); JwtRealm( final RealmConfig realmConfig, @@ -181,14 +180,9 @@ boolean isEmpty() { } if (this.isConfiguredJwkSetPkc) { - final URI jwkSetPathPkcUri = JwtUtil.parseHttpsUri(this.jwkSetPath); - if (jwkSetPathPkcUri == null) { - this.httpClient = null; // local file means no HTTP client - } else { - this.httpClient = JwtUtil.createHttpClient(super.config, sslService); - } + this.jwkSetLoader = new JwkSetLoader(sslService); } else { - this.httpClient = null; // no setting means no HTTP client + this.jwkSetLoader = null; // no setting means nothing to load } // Split configured signature algorithms by PKC and HMAC. Useful during validation, error logging, and JWK vs Alg filtering. @@ -205,6 +199,7 @@ boolean isEmpty() { this.close(); throw t; } + } private Cache buildJwtCache() { @@ -260,41 +255,11 @@ private ContentAndJwksAlgs parseJwksAlgsHmac() { } private ContentAndJwksAlgs parseJwksAlgsPkc() { - final JwtRealm.JwksAlgs jwksAlgsPkc; - byte[] jwkSetContentsPkcSha256 = null; if (this.isConfiguredJwkSetPkc == false) { - jwksAlgsPkc = new JwtRealm.JwksAlgs(Collections.emptyList(), Collections.emptyList()); + return new ContentAndJwksAlgs(null, new JwksAlgs(Collections.emptyList(), Collections.emptyList())); } else { - // PKC JWKSet get contents from local file or remote HTTPS URL - final byte[] jwkSetContentBytesPkc; - if (this.httpClient == null) { - jwkSetContentBytesPkc = JwtUtil.readFileContents( - RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.PKC_JWKSET_PATH), - this.jwkSetPath, - super.config.env() - ); - } else { - final URI jwkSetPathPkcUri = JwtUtil.parseHttpsUri(this.jwkSetPath); - jwkSetContentBytesPkc = JwtUtil.readUriContents( - RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.PKC_JWKSET_PATH), - jwkSetPathPkcUri, - this.httpClient - ); - } - final String jwkSetContentsPkc = new String(jwkSetContentBytesPkc, StandardCharsets.UTF_8); - jwkSetContentsPkcSha256 = sha256(jwkSetContentsPkc); - - // PKC JWKSet parse contents - final List jwksPkc = JwkValidateUtil.loadJwksFromJwkSetString( - RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.PKC_JWKSET_PATH), - jwkSetContentsPkc - ); - - // Filter JWK(s) vs signature algorithms. Only keep JWKs with a matching alg. Only keep algs with a matching JWK. - jwksAlgsPkc = JwkValidateUtil.filterJwksAndAlgorithms(jwksPkc, this.allowedJwksAlgsPkc); + return this.jwkSetLoader.loadBlocking(); } - LOGGER.info("Usable PKC: JWKs [{}]. Algorithms [{}].", jwksAlgsPkc.jwks().size(), String.join(",", jwksAlgsPkc.algs())); - return new ContentAndJwksAlgs(jwkSetContentsPkcSha256, jwksAlgsPkc); } private void verifyAnyAvailableJwkAndAlgPair() { @@ -334,7 +299,7 @@ public void initialize(final Iterable allRealms, final XPackLicenseState @Override public void close() { this.invalidateJwtCache(); - this.closeHttpClient(); + this.closeJwkSetLoader(); } /** @@ -357,13 +322,9 @@ private void invalidateJwtCache() { /** * Clean up HTTPS client cache (if enabled). */ - private void closeHttpClient() { - if (this.httpClient != null) { - try { - this.httpClient.close(); - } catch (IOException e) { - LOGGER.warn(() -> "Exception closing HTTPS client for realm [" + super.name() + "]", e); - } + private void closeJwkSetLoader() { + if (this.jwkSetLoader != null) { + this.jwkSetLoader.close(); } } @@ -475,217 +436,196 @@ public void authenticate(final AuthenticationToken authenticationToken, final Ac } // Validate JWT: Extract JWT and claims set, and validate JWT. - final SignedJWT jwt; - final JWSHeader header; - final JWTClaimsSet claimsSet; - try { - jwt = SignedJWT.parse(serializedJwt.toString()); - header = jwt.getHeader(); - claimsSet = jwt.getJWTClaimsSet(); - final Date now = new Date(); - if (LOGGER.isDebugEnabled()) { - LOGGER.debug( - "Realm [{}] JWT parse succeeded for token=[{}]." - + "Validating JWT, now [{}], alg [{}], issuer [{}], audiences [{}], kty [{}]," - + " auth_time [{}], iat [{}], nbf [{}], exp [{}], kid [{}], jti [{}]", - super.name(), - tokenPrincipal, - now, - header.getAlgorithm(), - claimsSet.getIssuer(), - claimsSet.getAudience(), - header.getType(), - claimsSet.getDateClaim("auth_time"), - claimsSet.getIssueTime(), - claimsSet.getNotBeforeTime(), - claimsSet.getExpirationTime(), - header.getKeyID(), - claimsSet.getJWTID() - ); - } - // Validate all else before signature, because these checks are more helpful diagnostics than rejected signatures. - final boolean isJwtSigHmac = JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(header.getAlgorithm().getName()); - JwtValidateUtil.validateType(jwt); - JwtValidateUtil.validateIssuer(jwt, allowedIssuer); - JwtValidateUtil.validateAudiences(jwt, allowedAudiences); - JwtValidateUtil.validateSignatureAlgorithm(jwt, isJwtSigHmac ? this.allowedJwksAlgsHmac : this.allowedJwksAlgsPkc); - JwtValidateUtil.validateAuthTime(jwt, now, this.allowedClockSkew.seconds()); - JwtValidateUtil.validateIssuedAtTime(jwt, now, this.allowedClockSkew.seconds()); - JwtValidateUtil.validateNotBeforeTime(jwt, now, this.allowedClockSkew.seconds()); - JwtValidateUtil.validateExpiredTime(jwt, now, this.allowedClockSkew.seconds()); - - // At this point, client authc and JWT kty+alg+iss+aud+time filters passed. Do sig last, in case JWK reload is expensive. - try { - JwtValidateUtil.validateSignature( - jwt, - isJwtSigHmac ? this.contentAndJwksAlgsHmac.jwksAlgs.jwks : this.contentAndJwksAlgsPkc.jwksAlgs.jwks - ); - } catch (Exception originalValidateSignatureException) { - if (isJwtSigHmac) { - throw originalValidateSignatureException; // HMAC reload not supported at this time - } - final String sigErr = originalValidateSignatureException.getMessage() + " "; - - ListenableFuture reloadAction = this.reloadActionRef.get(); // shared by threads using this realm - final PlainActionFuture reloadListener = PlainActionFuture.newFuture(); // local thread - final PlainActionFuture reloadListener2 = PlainActionFuture.newFuture(); // local thread - while (reloadAction == null) { - reloadAction = new ListenableFuture<>(); // current thread will try to take charge of reload - boolean isCurrentThread = this.reloadActionRef.compareAndSet(null, reloadAction); - if (isCurrentThread == false) { - reloadAction = this.reloadActionRef.get(); // different thread took charge, get the shared reference - } - reloadAction.addListener(reloadListener); - reloadAction.addListener(reloadListener2); - if (isCurrentThread) { - try { - LOGGER.trace(sigErr + "Reloading PKC JWKs to retry verify JWT token=[" + tokenPrincipal + "]"); - final ContentAndJwksAlgs newContentAndJwksAlgs; - try { - newContentAndJwksAlgs = this.parseJwksAlgsPkc(); - } catch (Exception reloadException) { - final String msg = sigErr - + "Failed to reload PKC JWKs, can't retry verify JWT token=[" - + tokenPrincipal - + "]"; - reloadException.addSuppressed(originalValidateSignatureException); - LOGGER.error(msg, reloadException); - listener.onResponse(AuthenticationResult.unsuccessful(msg, reloadException)); - return; - } - final boolean isSame = Arrays.equals(this.contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256); - if (isSame) { - LOGGER.debug(sigErr + "Reloaded same PKC JWKs to verify JWT token=[" + tokenPrincipal + "]"); - } else { - LOGGER.debug(sigErr + "Reloaded different PKC JWKs to verify JWT token=[" + tokenPrincipal + "]"); - this.contentAndJwksAlgsPkc = newContentAndJwksAlgs; - - // If all PKC JWKs were replaced, all PKC JWT cache entries need to be invalidated. - // Enhancement idea: Use separate caches for PKC vs HMAC JWKs, so only PKC entries get invalidated. - // Enhancement idea: When some JWKs are retained (ex: rotation), only invalidate for removed JWKs. - this.invalidateJwtCache(); - } - reloadAction.onResponse(isSame); - } catch (Exception e) { - final String msg = sigErr - + "Failed to reload PKC JWKs, can't retry verify JWT token=[" - + tokenPrincipal - + "]"; - e.addSuppressed(originalValidateSignatureException); - LOGGER.error(msg, e); - reloadAction.onFailure(e); - } finally { - this.reloadActionRef.set(null); - } - } - } + validateJwt( + serializedJwt, + tokenPrincipal, + ActionListener.wrap(claimsSet -> processValidatedJwt(tokenPrincipal, jwtCacheKey, claimsSet, listener), ex -> { + final String msg = "Realm [" + super.name() + "] JWT validation failed for token=[" + tokenPrincipal + "]."; + LOGGER.debug(msg, ex); + listener.onResponse(AuthenticationResult.unsuccessful(msg, ex)); + }) + ); + } else { + final String className = (authenticationToken == null) ? "null" : authenticationToken.getClass().getCanonicalName(); + final String msg = "Realm [" + super.name() + "] does not support AuthenticationToken [" + className + "]."; + LOGGER.trace(msg); + listener.onResponse(AuthenticationResult.unsuccessful(msg, null)); + } + } - LOGGER.trace(sigErr + "Waiting for reload of PKC JWKs to retry verify JWT token=[" + tokenPrincipal + "]"); - final boolean isSame = reloadListener.actionGet(); // wait for action to complete - final boolean isSame2 = reloadListener2.actionGet(); // wait for action to complete - assert isSame == isSame2; - if (isSame) { - final String msg = sigErr + "Reloaded same PKC JWKs, can't retry verify JWT token=[" + tokenPrincipal + "]"; - final ElasticsearchException reloadPkcException = new ElasticsearchException(msg); - reloadPkcException.addSuppressed(originalValidateSignatureException); - LOGGER.debug(msg, reloadPkcException); - listener.onResponse(AuthenticationResult.unsuccessful(msg, reloadPkcException)); - return; - } else if (this.contentAndJwksAlgsPkc.jwksAlgs.isEmpty()) { - LOGGER.error(sigErr + "Reloaded empty PKC JWKs, can't retry verify JWT token=[" + tokenPrincipal + "]"); - // allow empty, filtered PKC JWKs to fall through to try/catch below, to reuse that error handling - } - // different PKC JWKs detected so retry signature - try { - JwtValidateUtil.validateSignature(jwt, this.contentAndJwksAlgsPkc.jwksAlgs.jwks); - } catch (Exception e) { - final String msg = sigErr - + "Realm [" - + super.name() - + "] JWT validation retry failed for token=[" - + tokenPrincipal - + "]."; - final ElasticsearchException reloadException = new ElasticsearchException(msg); - reloadException.addSuppressed(originalValidateSignatureException); - LOGGER.debug(msg, e); - listener.onResponse(AuthenticationResult.unsuccessful(msg, e)); - return; - } - } - } catch (Exception e) { - final String msg = "Realm [" + super.name() + "] JWT validation failed for token=[" + tokenPrincipal + "]."; - LOGGER.debug(msg, e); - listener.onResponse(AuthenticationResult.unsuccessful(msg, e)); - return; + private void validateJwt(SecureString serializedJwt, String tokenPrincipal, ActionListener listener) { + final SignedJWT jwt; + final JWSHeader header; + final JWTClaimsSet claimsSet; + try { + jwt = SignedJWT.parse(serializedJwt.toString()); + header = jwt.getHeader(); + claimsSet = jwt.getJWTClaimsSet(); + final Date now = new Date(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Realm [{}] JWT parse succeeded for token=[{}]." + + "Validating JWT, now [{}], alg [{}], issuer [{}], audiences [{}], kty [{}]," + + " auth_time [{}], iat [{}], nbf [{}], exp [{}], kid [{}], jti [{}]", + super.name(), + tokenPrincipal, + now, + header.getAlgorithm(), + claimsSet.getIssuer(), + claimsSet.getAudience(), + header.getType(), + claimsSet.getDateClaim("auth_time"), + claimsSet.getIssueTime(), + claimsSet.getNotBeforeTime(), + claimsSet.getExpirationTime(), + header.getKeyID(), + claimsSet.getJWTID() + ); } + // Validate all else before signature, because these checks are more helpful diagnostics than rejected signatures. + JwtValidateUtil.validateType(jwt); + JwtValidateUtil.validateIssuer(jwt, allowedIssuer); + JwtValidateUtil.validateAudiences(jwt, allowedAudiences); + JwtValidateUtil.validateSignatureAlgorithm(jwt, isHmacSignature(header) ? this.allowedJwksAlgsHmac : this.allowedJwksAlgsPkc); + JwtValidateUtil.validateAuthTime(jwt, now, this.allowedClockSkew.seconds()); + JwtValidateUtil.validateIssuedAtTime(jwt, now, this.allowedClockSkew.seconds()); + JwtValidateUtil.validateNotBeforeTime(jwt, now, this.allowedClockSkew.seconds()); + JwtValidateUtil.validateExpiredTime(jwt, now, this.allowedClockSkew.seconds()); + + // At this point, client authc and JWT kty+alg+iss+aud+time filters passed. Do sig last, in case JWK reload is expensive. + validateSignature(tokenPrincipal, jwt, listener.map(ignored -> claimsSet)); + + } catch (Exception e) { + listener.onFailure(e); + } + } - // At this point, JWT is validated. Parse the JWT claims using realm settings. - - final String principal = this.claimParserPrincipal.getClaimValue(claimsSet); - if (Strings.hasText(principal) == false) { - final String msg = "Realm [" - + super.name() - + "] no principal for token=[" - + tokenPrincipal - + "] parser=[" - + this.claimParserPrincipal - + "] claims=[" - + claimsSet - + "]."; - LOGGER.debug(msg); - listener.onResponse(AuthenticationResult.unsuccessful(msg, null)); - return; - } + private boolean isHmacSignature(JWSHeader header) { + return JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(header.getAlgorithm().getName()); + } - // Roles listener: Log roles from delegated authz lookup or role mapping, and cache User if JWT cache is enabled. - final ActionListener> logAndCacheListener = ActionListener.wrap(result -> { - if (result.isAuthenticated()) { - final User user = result.getValue(); - LOGGER.debug( - () -> format("Realm [%s] roles [%s] for principal=[%s].", super.name(), join(",", user.roles()), principal) - ); - if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) { - try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) { - final long expWallClockMillis = claimsSet.getExpirationTime().getTime() + this.allowedClockSkew.getMillis(); - this.jwtCache.put(jwtCacheKey, new ExpiringUser(result.getValue(), new Date(expWallClockMillis))); - } + private void processValidatedJwt( + String tokenPrincipal, + BytesKey jwtCacheKey, + JWTClaimsSet claimsSet, + ActionListener> listener + ) { + // At this point, JWT is validated. Parse the JWT claims using realm settings. + final String principal = this.claimParserPrincipal.getClaimValue(claimsSet); + if (Strings.hasText(principal) == false) { + final String msg = "Realm [" + + super.name() + + "] no principal for token=[" + + tokenPrincipal + + "] parser=[" + + this.claimParserPrincipal + + "] claims=[" + + claimsSet + + "]."; + LOGGER.debug(msg); + listener.onResponse(AuthenticationResult.unsuccessful(msg, null)); + return; + } + + // Roles listener: Log roles from delegated authz lookup or role mapping, and cache User if JWT cache is enabled. + final ActionListener> logAndCacheListener = ActionListener.wrap(result -> { + if (result.isAuthenticated()) { + final User user = result.getValue(); + LOGGER.debug(() -> format("Realm [%s] roles [%s] for principal=[%s].", super.name(), join(",", user.roles()), principal)); + if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) { + try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) { + final long expWallClockMillis = claimsSet.getExpirationTime().getTime() + this.allowedClockSkew.getMillis(); + this.jwtCache.put(jwtCacheKey, new ExpiringUser(result.getValue(), new Date(expWallClockMillis))); } } - listener.onResponse(result); - }, listener::onFailure); - - // Delegated role lookup or Role mapping: Use the above listener to log roles and cache User. - if (this.delegatedAuthorizationSupport.hasDelegation()) { - this.delegatedAuthorizationSupport.resolve(principal, logAndCacheListener); - return; } + listener.onResponse(result); + }, listener::onFailure); - // User metadata: If enabled, extract metadata from JWT claims set. Use it in UserRoleMapper.UserData and User constructors. - final Map userMetadata; - try { - userMetadata = this.populateUserMetadata ? JwtUtil.toUserMetadata(jwt) : Map.of(); - } catch (Exception e) { - final String msg = "Realm [" + super.name() + "] parse metadata failed for principal=[" + principal + "]."; - LOGGER.debug(msg, e); - listener.onResponse(AuthenticationResult.unsuccessful(msg, e)); + // Delegated role lookup or Role mapping: Use the above listener to log roles and cache User. + if (this.delegatedAuthorizationSupport.hasDelegation()) { + this.delegatedAuthorizationSupport.resolve(principal, logAndCacheListener); + return; + } + + // User metadata: If enabled, extract metadata from JWT claims set. Use it in UserRoleMapper.UserData and User constructors. + final Map userMetadata; + try { + userMetadata = this.populateUserMetadata ? JwtUtil.toUserMetadata(claimsSet) : Map.of(); + } catch (Exception e) { + final String msg = "Realm [" + super.name() + "] parse metadata failed for principal=[" + principal + "]."; + LOGGER.debug(msg, e); + listener.onResponse(AuthenticationResult.unsuccessful(msg, e)); + return; + } + + // Role resolution: Handle role mapping in JWT Realm. + final List groups = this.claimParserGroups.getClaimValues(claimsSet); + final String dn = this.claimParserDn.getClaimValue(claimsSet); + final String mail = this.claimParserMail.getClaimValue(claimsSet); + final String name = this.claimParserName.getClaimValue(claimsSet); + final UserRoleMapper.UserData userData = new UserRoleMapper.UserData(principal, dn, groups, userMetadata, super.config); + this.userRoleMapper.resolveRoles(userData, ActionListener.wrap(rolesSet -> { + final User user = new User(principal, rolesSet.toArray(Strings.EMPTY_ARRAY), name, mail, userData.getMetadata(), true); + logAndCacheListener.onResponse(AuthenticationResult.success(user)); + }, logAndCacheListener::onFailure)); + } + + private void validateSignature(String tokenPrincipal, SignedJWT jwt, ActionListener listener) throws Exception { + final boolean isJwtSigHmac = isHmacSignature(jwt.getHeader()); + try { + JwtValidateUtil.validateSignature( + jwt, + isJwtSigHmac ? this.contentAndJwksAlgsHmac.jwksAlgs.jwks : this.contentAndJwksAlgsPkc.jwksAlgs.jwks + ); + listener.onResponse(null); + } catch (Exception originalValidateSignatureException) { + if (isJwtSigHmac || this.jwkSetLoader == null) { + listener.onFailure(originalValidateSignatureException);// HMAC reload not supported at this time return; } - // Role resolution: Handle role mapping in JWT Realm. - final List groups = this.claimParserGroups.getClaimValues(claimsSet); - final String dn = this.claimParserDn.getClaimValue(claimsSet); - final String mail = this.claimParserMail.getClaimValue(claimsSet); - final String name = this.claimParserName.getClaimValue(claimsSet); - final UserRoleMapper.UserData userData = new UserRoleMapper.UserData(principal, dn, groups, userMetadata, super.config); - this.userRoleMapper.resolveRoles(userData, ActionListener.wrap(rolesSet -> { - final User user = new User(principal, rolesSet.toArray(Strings.EMPTY_ARRAY), name, mail, userData.getMetadata(), true); - logAndCacheListener.onResponse(AuthenticationResult.success(user)); - }, logAndCacheListener::onFailure)); - } else { - final String className = (authenticationToken == null) ? "null" : authenticationToken.getClass().getCanonicalName(); - final String msg = "Realm [" + super.name() + "] does not support AuthenticationToken [" + className + "]."; - LOGGER.trace(msg); - listener.onResponse(AuthenticationResult.unsuccessful(msg, null)); + LOGGER.debug( + () -> org.elasticsearch.core.Strings.format( + "Signature verification failed for [%s] reloading JWKSet (was: #[%s] JWKs, #[%s] algs, sha256=[%s])", + tokenPrincipal, + contentAndJwksAlgsPkc.jwksAlgs.jwks().size(), + contentAndJwksAlgsPkc.jwksAlgs.algs().size(), + MessageDigests.toHexString(contentAndJwksAlgsPkc.sha256()) + ), + originalValidateSignatureException + ); + + this.jwkSetLoader.load(ActionListener.wrap(newContentAndJwksAlgs -> { + if (Arrays.equals(this.contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256)) { + // No change in JWKSet + logger.debug("Reloaded same PKC JWKs, can't retry verify JWT token=[{}]", tokenPrincipal); + listener.onFailure(originalValidateSignatureException); + return; + } + this.contentAndJwksAlgsPkc = newContentAndJwksAlgs; + // If all PKC JWKs were replaced, all PKC JWT cache entries need to be invalidated. + // Enhancement idea: Use separate caches for PKC vs HMAC JWKs, so only PKC entries get invalidated. + // Enhancement idea: When some JWKs are retained (ex: rotation), only invalidate for removed JWKs. + this.invalidateJwtCache(); + + if (this.contentAndJwksAlgsPkc.jwksAlgs.isEmpty()) { + logger.debug("Reloaded empty PKC JWKs, verification of JWT token will fail [{}]", tokenPrincipal); + } + + try { + JwtValidateUtil.validateSignature(jwt, this.contentAndJwksAlgsPkc.jwksAlgs.jwks); + listener.onResponse(null); + } catch (Exception secondaryException) { + logger.debug( + "Verification of JWT token for [{}] failed - original failure=[{}], failure after reload=[{}]", + tokenPrincipal, + originalValidateSignatureException.getMessage(), + secondaryException.getMessage() + ); + listener.onFailure(secondaryException); + } + }, listener::onFailure)); } } @@ -703,4 +643,111 @@ static byte[] sha256(final CharSequence charSequence) { messageDigest.update(charSequence.toString().getBytes(StandardCharsets.UTF_8)); return messageDigest.digest(); } + + class JwkSetLoader implements Closeable { + private final CloseableHttpAsyncClient httpClient; + private final URI uri; + + private final AtomicReference> reloadFutureRef; + + private JwkSetLoader(SSLService sslService) { + final URI uri = JwtUtil.parseHttpsUri(jwkSetPath); + if (uri == null) { + this.uri = null; + this.httpClient = null; // local file means no HTTP client + } else { + this.uri = uri; + this.httpClient = JwtUtil.createHttpClient(config, sslService); + } + reloadFutureRef = new AtomicReference<>(); + } + + void load(ActionListener listener) { + ListenableFuture future = getFuture(); + future.addListener(listener); + } + + ContentAndJwksAlgs loadBlocking() { + var future = new PlainActionFuture(); + load(future); + return future.actionGet(); + } + + private ListenableFuture getFuture() { + for (;;) { + ListenableFuture existingFuture = reloadFutureRef.get(); + if (existingFuture != null) { + return existingFuture; + } + + ListenableFuture newFuture = new ListenableFuture<>(); + if (reloadFutureRef.compareAndSet(null, newFuture)) { + loadInternal(ActionListener.runAfter(newFuture, () -> reloadFutureRef.compareAndSet(newFuture, null))); + return newFuture; + } + // else, Another thread set the future-ref before us, just try it all again + } + } + + private void loadInternal(final ActionListener listener) { + // PKC JWKSet get contents from local file or remote HTTPS URL + if (this.httpClient == null) { + LOGGER.trace("Loading PKC JWKs from path [{}]", jwkSetPath); + listener.onResponse( + parseContent( + JwtUtil.readFileContents( + RealmSettings.getFullSettingKey(JwtRealm.this.config, JwtRealmSettings.PKC_JWKSET_PATH), + JwtRealm.this.jwkSetPath, + JwtRealm.this.config.env() + ) + ) + ); + } else { + LOGGER.trace("Loading PKC JWKs from https URI [{}]", uri); + JwtUtil.readUriContents( + RealmSettings.getFullSettingKey(JwtRealm.this.config, JwtRealmSettings.PKC_JWKSET_PATH), + uri, + this.httpClient, + listener.map(bytes -> { + LOGGER.trace("Loaded bytes [{}] from [{}]", bytes.length, uri); + return parseContent(bytes); + }) + ); + } + } + + private ContentAndJwksAlgs parseContent(byte[] jwkSetContentBytesPkc) { + final String jwkSetContentsPkc = new String(jwkSetContentBytesPkc, StandardCharsets.UTF_8); + final byte[] jwkSetContentsPkcSha256 = sha256(jwkSetContentsPkc); + + // PKC JWKSet parse contents + final List jwksPkc = JwkValidateUtil.loadJwksFromJwkSetString( + RealmSettings.getFullSettingKey(config, JwtRealmSettings.PKC_JWKSET_PATH), + jwkSetContentsPkc + ); + // Filter JWK(s) vs signature algorithms. Only keep JWKs with a matching alg. Only keep algs with a matching JWK. + final JwksAlgs jwksAlgsPkc = JwkValidateUtil.filterJwksAndAlgorithms(jwksPkc, allowedJwksAlgsPkc); + LOGGER.info( + "Usable PKC: JWKs=[{}] algorithms=[{}] sha256=[{}]", + jwksAlgsPkc.jwks().size(), + String.join(",", jwksAlgsPkc.algs()), + MessageDigests.toHexString(jwkSetContentsPkcSha256) + ); + return new ContentAndJwksAlgs(jwkSetContentsPkcSha256, jwksAlgsPkc); + } + + public void close() { + if (this.httpClient != null) { + try { + this.httpClient.close(); + } catch (IOException e) { + LOGGER.warn(() -> "Exception closing HTTPS client for realm [" + JwtRealm.this.name() + "]", e); + } + } + } + + public boolean isFile() { + return this.uri == null; + } + } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java index d358e41401df7..95e2a9b533d1c 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java @@ -11,7 +11,6 @@ import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.util.JSONObjectUtils; import com.nimbusds.jwt.JWTClaimsSet; -import com.nimbusds.jwt.SignedJWT; import org.apache.http.HttpEntity; import org.apache.http.HttpResponse; @@ -33,7 +32,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.SpecialPermission; -import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SettingsException; @@ -185,16 +184,25 @@ public static URI parseHttpsUri(final String uriString) { return null; } - public static byte[] readUriContents( + public static void readUriContents( final String jwkSetConfigKeyPkc, final URI jwkSetPathPkcUri, - final CloseableHttpAsyncClient httpClient - ) throws SettingsException { - try { - return JwtUtil.readBytes(httpClient, jwkSetPathPkcUri); - } catch (Exception e) { - throw new SettingsException("Can't get contents for setting [" + jwkSetConfigKeyPkc + "] value [" + jwkSetPathPkcUri + "].", e); - } + final CloseableHttpAsyncClient httpClient, + final ActionListener listener + ) { + JwtUtil.readBytes( + httpClient, + jwkSetPathPkcUri, + ActionListener.wrap( + listener::onResponse, + ex -> listener.onFailure( + new SettingsException( + "Can't get contents for setting [" + jwkSetConfigKeyPkc + "] value [" + jwkSetPathPkcUri + "].", + ex + ) + ) + ) + ); } public static byte[] readFileContents(final String jwkSetConfigKeyPkc, final String jwkSetPathPkc, final Environment environment) @@ -262,13 +270,11 @@ public static CloseableHttpAsyncClient createHttpClient(final RealmConfig realmC } /** - * Use the HTTP Client to get URL content bytes up to N max bytes. + * Use the HTTP Client to get URL content bytes. * @param httpClient Configured HTTP/HTTPS client. * @param uri URI to download. - * @return Byte array of the URI contents up to N max bytes. */ - public static byte[] readBytes(final CloseableHttpAsyncClient httpClient, final URI uri) { - final PlainActionFuture plainActionFuture = PlainActionFuture.newFuture(); + public static void readBytes(final CloseableHttpAsyncClient httpClient, final URI uri, ActionListener listener) { AccessController.doPrivileged((PrivilegedAction) () -> { httpClient.execute(new HttpGet(uri), new FutureCallback<>() { @Override @@ -278,12 +284,12 @@ public void completed(final HttpResponse result) { if (statusCode == 200) { final HttpEntity entity = result.getEntity(); try (InputStream inputStream = entity.getContent()) { - plainActionFuture.onResponse(inputStream.readAllBytes()); + listener.onResponse(inputStream.readAllBytes()); } catch (Exception e) { - plainActionFuture.onFailure(e); + listener.onFailure(e); } } else { - plainActionFuture.onFailure( + listener.onFailure( new ElasticsearchSecurityException( "Get [" + uri + "] failed, status [" + statusCode + "], reason [" + statusLine.getReasonPhrase() + "]." ) @@ -293,17 +299,16 @@ public void completed(final HttpResponse result) { @Override public void failed(Exception e) { - plainActionFuture.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] failed.", e)); + listener.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] failed.", e)); } @Override public void cancelled() { - plainActionFuture.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] was cancelled.")); + listener.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] was cancelled.")); } }); return null; }); - return plainActionFuture.actionGet(); } public static Path resolvePath(final Environment environment, final String jwkSetPath) { @@ -335,14 +340,10 @@ public static SecureString join(final CharSequence delimiter, final CharSequence * JWSHeader: Header are not support. * JWTClaimsSet: Claims are supported. Claim keys are prefixed by "jwt_claim_". * Base64URL: Signature is not supported. - * @param jwt SignedJWT object. * @return Map of formatted and filtered values to be used as user metadata. - * @throws Exception Parse error. */ - // // Values will be filtered by type using isAllowedTypeForClaim(). - public static Map toUserMetadata(final SignedJWT jwt) throws Exception { - final JWTClaimsSet claimsSet = jwt.getJWTClaimsSet(); + public static Map toUserMetadata(JWTClaimsSet claimsSet) { return claimsSet.getClaims() .entrySet() .stream() diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java index 415f3cca147de..da7e9b8ec34ac 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xpack.core.security.authc.AuthenticationResult; import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmSettings; @@ -97,6 +98,7 @@ public void testJwtAuthcRealmAuthcAuthzWithoutAuthzRealms() throws Exception { * Test with updated/removed/restored JWKs. * @throws Exception Unexpected test failure */ + @TestLogging(value = "org.elasticsearch.xpack.security.authc.jwt:TRACE", reason = "debug") public void testJwkSetUpdates() throws Exception { this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs( this.createJwtRealmsSettingsBuilder(), diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java index 4225b2b7da7d1..3740ea088f0d7 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java @@ -213,7 +213,7 @@ protected JwtIssuer createJwtIssuer( } protected void copyIssuerJwksToRealmConfig(final JwtIssuerAndRealm jwtIssuerAndRealm) throws Exception { - if ((jwtIssuerAndRealm.realm.isConfiguredJwkSetPkc) && (jwtIssuerAndRealm.realm.httpClient == null)) { + if ((jwtIssuerAndRealm.realm.isConfiguredJwkSetPkc) && (jwtIssuerAndRealm.realm.jwkSetLoader.isFile())) { LOGGER.trace("Updating JwtRealm PKC public JWKSet local file"); final Path path = PathUtils.get(jwtIssuerAndRealm.realm.jwkSetPath); Files.writeString(path, jwtIssuerAndRealm.issuer.encodedJwkSetPkcPublic);