Skip to content

Commit

Permalink
Fix racing when loading new JWKs from multiple threads (#88753)
Browse files Browse the repository at this point in the history
This PR ensures the mutation of JWKs is done in a single thread and
visible to all other threads, which in turn ensures validation to be
correctly performed concurrently.

Relates: #88023
  • Loading branch information
ywangd authored Jul 26, 2022
1 parent 631a705 commit d39836d
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
public class JwtRealm extends Realm implements CachingRealm, Releasable {
private static final Logger LOGGER = LogManager.getLogger(JwtRealm.class);

private static final ContentAndJwksAlgs EMPTY_CONTENT_AND_JWKS_ALGS = new ContentAndJwksAlgs(null, new JwksAlgs(List.of(), List.of()));

// Cached authenticated users, and adjusted JWT expiration date (=exp+skew) for checking if the JWT expired before the cache entry
record ExpiringUser(User user, Date exp) {
ExpiringUser {
Expand Down Expand Up @@ -109,7 +111,6 @@ boolean isEmpty() {
final boolean isConfiguredJwkSetPkc;
final boolean isConfiguredJwkSetHmac;
final boolean isConfiguredJwkOidcHmac;
private final CloseableHttpAsyncClient httpClient;
final JwkSetLoader jwkSetLoader;
final TimeValue allowedClockSkew;
final Boolean populateUserMetadata;
Expand All @@ -125,9 +126,7 @@ boolean isEmpty() {
final List<String> allowedJwksAlgsPkc;
final List<String> allowedJwksAlgsHmac;
DelegatedAuthorizationSupport delegatedAuthorizationSupport = null;
ContentAndJwksAlgs contentAndJwksAlgsPkc;
ContentAndJwksAlgs contentAndJwksAlgsHmac;
final URI jwkSetPathUri;

JwtRealm(
final RealmConfig realmConfig,
Expand Down Expand Up @@ -184,30 +183,17 @@ boolean isEmpty() {
);
}

if (this.isConfiguredJwkSetPkc) {
final URI jwkSetPathUri = JwtUtil.parseHttpsUri(jwkSetPath);
if (jwkSetPathUri == null) {
this.jwkSetPathUri = null; // local file path
this.httpClient = null;
} else {
this.jwkSetPathUri = jwkSetPathUri; // HTTPS URL
this.httpClient = JwtUtil.createHttpClient(this.config, sslService);
}
this.jwkSetLoader = new JwkSetLoader(); // PKC JWKSet loader for HTTPS URL or local file path
} else {
this.jwkSetPathUri = null; // not configured
this.httpClient = null;
this.jwkSetLoader = null;
}

// Any exception during loading requires closing JwkSetLoader's HTTP client to avoid a thread pool leak
try {
this.contentAndJwksAlgsHmac = this.parseJwksAlgsHmac();
this.contentAndJwksAlgsPkc = this.parseJwksAlgsPkc();
if (this.isConfiguredJwkSetPkc) {
this.jwkSetLoader = new JwkSetLoader(sslService); // PKC JWKSet loader for HTTPS URL or local file path
} else {
this.jwkSetLoader = null;
}
this.verifyAnyAvailableJwkAndAlgPair();
} catch (Throwable t) {
// ASSUME: Tests or startup only. Catch and rethrow Throwable here, in case some code throws an uncaught RuntimeException.
this.close();
close();
throw t;
}
}
Expand Down Expand Up @@ -255,14 +241,22 @@ private ContentAndJwksAlgs parseJwksAlgsHmac() {
return new ContentAndJwksAlgs(hmacStringContentsSha256, jwksAlgsHmac);
}

private ContentAndJwksAlgs parseJwksAlgsPkc() {
// Package private for test
URI getJwkSetPathUri() {
if (jwkSetLoader != null) {
return jwkSetLoader.jwkSetPathUri;
} else {
return null;
}
}

// Package private for test
ContentAndJwksAlgs getJwksAlgsPkc() {
if (this.isConfiguredJwkSetPkc == false) {
return new ContentAndJwksAlgs(null, new JwksAlgs(Collections.emptyList(), Collections.emptyList()));
return EMPTY_CONTENT_AND_JWKS_ALGS;
} else {
// ASSUME: Blocking read operations are OK during startup
final PlainActionFuture<ContentAndJwksAlgs> future = new PlainActionFuture<>();
this.jwkSetLoader.load(future);
return future.actionGet();
assert jwkSetLoader != null;
return jwkSetLoader.contentAndJwksAlgsPkc;
}
}

Expand All @@ -277,8 +271,8 @@ private Cache<BytesArray, ExpiringUser> buildJwtCache() {

private void verifyAnyAvailableJwkAndAlgPair() {
assert this.contentAndJwksAlgsHmac != null : "HMAC not initialized";
assert this.contentAndJwksAlgsPkc != null : "PKC not initialized";
if (this.contentAndJwksAlgsHmac.jwksAlgs.isEmpty() && this.contentAndJwksAlgsPkc.jwksAlgs.isEmpty()) {
assert getJwksAlgsPkc() != null : "PKC not initialized";
if (this.contentAndJwksAlgsHmac.jwksAlgs.isEmpty() && this.getJwksAlgsPkc().jwksAlgs.isEmpty()) {
final String msg = "No available JWK and algorithm for HMAC or PKC. Realm authentication expected to fail until this is fixed.";
throw new SettingsException(msg);
}
Expand Down Expand Up @@ -312,7 +306,9 @@ public void initialize(final Iterable<Realm> allRealms, final XPackLicenseState
@Override
public void close() {
this.invalidateJwtCache();
this.closeHttpClient();
if (jwkSetLoader != null) {
jwkSetLoader.close();
}
}

/**
Expand All @@ -332,19 +328,6 @@ private void invalidateJwtCache() {
}
}

/**
* Clean up HTTPS client cache (if enabled).
*/
private void closeHttpClient() {
if (this.httpClient != null) {
try {
this.httpClient.close();
} catch (IOException e) {
LOGGER.warn(() -> "Exception closing HTTPS client for realm [" + super.name() + "]", e);
}
}
}

@Override
public void lookupUser(final String username, final ActionListener<User> listener) {
this.ensureInitialized();
Expand Down Expand Up @@ -596,7 +579,7 @@ private void validateSignature(
try {
JwtValidateUtil.validateSignature(
jwt,
isJwtAlgHmac ? this.contentAndJwksAlgsHmac.jwksAlgs.jwks : this.contentAndJwksAlgsPkc.jwksAlgs.jwks
isJwtAlgHmac ? this.contentAndJwksAlgsHmac.jwksAlgs.jwks : this.getJwksAlgsPkc().jwksAlgs.jwks
);
listener.onResponse(null);
} catch (Exception primaryException) {
Expand All @@ -609,33 +592,32 @@ private void validateSignature(
() -> org.elasticsearch.core.Strings.format(
"Signature verification failed for [%s] reloading JWKSet (was: #[%s] JWKs, #[%s] algs, sha256=[%s])",
tokenPrincipal,
this.contentAndJwksAlgsPkc.jwksAlgs.jwks().size(),
this.contentAndJwksAlgsPkc.jwksAlgs.algs().size(),
MessageDigests.toHexString(this.contentAndJwksAlgsPkc.sha256())
this.getJwksAlgsPkc().jwksAlgs.jwks().size(),
this.getJwksAlgsPkc().jwksAlgs.algs().size(),
MessageDigests.toHexString(this.getJwksAlgsPkc().sha256())
),
primaryException
);

this.jwkSetLoader.load(ActionListener.wrap(newContentAndJwksAlgs -> {
if (Arrays.equals(this.contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256)) {
this.jwkSetLoader.reload(ActionListener.wrap(isUpdated -> {
if (false == isUpdated) {
// No change in JWKSet
logger.debug("Reloaded same PKC JWKs, can't retry verify JWT token=[{}]", tokenPrincipal);
listener.onFailure(primaryException);
return;
}
this.contentAndJwksAlgsPkc = newContentAndJwksAlgs;
// If all PKC JWKs were replaced, all PKC JWT cache entries need to be invalidated.
// Enhancement idea: Use separate caches for PKC vs HMAC JWKs, so only PKC entries get invalidated.
// Enhancement idea: When some JWKs are retained (ex: rotation), only invalidate for removed JWKs.
this.invalidateJwtCache();

if (this.contentAndJwksAlgsPkc.jwksAlgs.isEmpty()) {
if (this.getJwksAlgsPkc().jwksAlgs.isEmpty()) {
logger.debug("Reloaded empty PKC JWKs, verification of JWT token will fail [{}]", tokenPrincipal);
// fall through and let try/catch below handle empty JWKs failure log and response
}

try {
JwtValidateUtil.validateSignature(jwt, this.contentAndJwksAlgsPkc.jwksAlgs.jwks);
JwtValidateUtil.validateSignature(jwt, this.getJwksAlgsPkc().jwksAlgs.jwks);
listener.onResponse(null);
} catch (Exception secondaryException) {
logger.debug(
Expand All @@ -660,14 +642,73 @@ public void usageStats(final ActionListener<Map<String, Object>> listener) {
}, listener::onFailure));
}

private class JwkSetLoader {
private class JwkSetLoader implements Releasable {
private final AtomicReference<ListenableFuture<ContentAndJwksAlgs>> reloadFutureRef = new AtomicReference<>();
private final URI jwkSetPathUri;
private final CloseableHttpAsyncClient httpClient;
private volatile ContentAndJwksAlgs contentAndJwksAlgsPkc;

JwkSetLoader(final SSLService sslService) {
assert JwtRealm.this.isConfiguredJwkSetPkc;
final URI jwkSetPathUri = JwtUtil.parseHttpsUri(jwkSetPath);
if (jwkSetPathUri == null) {
this.jwkSetPathUri = null; // local file path
this.httpClient = null;
} else {
this.jwkSetPathUri = jwkSetPathUri; // HTTPS URL
this.httpClient = JwtUtil.createHttpClient(JwtRealm.this.config, sslService);
}
// Any exception during loading requires closing JwkSetLoader's HTTP client to avoid a thread pool leak
try {
final PlainActionFuture<ContentAndJwksAlgs> future = new PlainActionFuture<>();
load(future);
// ASSUME: Blocking read operations are OK during startup
contentAndJwksAlgsPkc = future.actionGet();
} catch (Throwable t) {
close();
throw t;
}
}

/**
* Clean up HTTPS client cache (if enabled).
*/
@Override
public void close() {
if (httpClient != null) {
try {
httpClient.close();
} catch (IOException e) {
LOGGER.warn(() -> "Exception closing HTTPS client for realm [" + JwtRealm.this.name() + "]", e);
}
}
}

/**
* Load the JWK sets and pass its content to the specified listener.
*/
void load(final ActionListener<ContentAndJwksAlgs> listener) {
final ListenableFuture<ContentAndJwksAlgs> future = this.getFuture();
future.addListener(listener);
}

/**
* Reload the JWK sets, compare to existing JWK sets and update it to the reloaded value if
* they are different. The listener is called with false if the reloaded content is the same
* as the existing one or true if they are different.
*/
void reload(final ActionListener<Boolean> listener) {
load(ActionListener.wrap(newContentAndJwksAlgs -> {
if (Arrays.equals(contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256)) {
// No change in JWKSet
listener.onResponse(false);
} else {
contentAndJwksAlgsPkc = newContentAndJwksAlgs;
listener.onResponse(true);
}
}, listener::onFailure));
}

private ListenableFuture<ContentAndJwksAlgs> getFuture() {
for (;;) {
final ListenableFuture<ContentAndJwksAlgs> existingFuture = this.reloadFutureRef.get();
Expand All @@ -677,7 +718,10 @@ private ListenableFuture<ContentAndJwksAlgs> getFuture() {

final ListenableFuture<ContentAndJwksAlgs> newFuture = new ListenableFuture<>();
if (this.reloadFutureRef.compareAndSet(null, newFuture)) {
loadInternal(ActionListener.runAfter(newFuture, () -> this.reloadFutureRef.compareAndSet(newFuture, null)));
loadInternal(ActionListener.runAfter(newFuture, () -> {
final ListenableFuture<ContentAndJwksAlgs> oldValue = this.reloadFutureRef.getAndSet(null);
assert oldValue == newFuture : "future reference changed unexpectedly";
}));
return newFuture;
}
// else, Another thread set the future-ref before us, just try it all again
Expand All @@ -686,7 +730,7 @@ private ListenableFuture<ContentAndJwksAlgs> getFuture() {

private void loadInternal(final ActionListener<ContentAndJwksAlgs> listener) {
// PKC JWKSet get contents from local file or remote HTTPS URL
if (JwtRealm.this.httpClient == null) {
if (httpClient == null) {
LOGGER.trace("Loading PKC JWKs from path [{}]", JwtRealm.this.jwkSetPath);
listener.onResponse(
this.parseContent(
Expand All @@ -698,13 +742,13 @@ private void loadInternal(final ActionListener<ContentAndJwksAlgs> listener) {
)
);
} else {
LOGGER.trace("Loading PKC JWKs from https URI [{}]", JwtRealm.this.jwkSetPathUri);
LOGGER.trace("Loading PKC JWKs from https URI [{}]", jwkSetPathUri);
JwtUtil.readUriContents(
RealmSettings.getFullSettingKey(JwtRealm.this.config, JwtRealmSettings.PKC_JWKSET_PATH),
JwtRealm.this.jwkSetPathUri,
JwtRealm.this.httpClient,
jwkSetPathUri,
httpClient,
listener.map(bytes -> {
LOGGER.trace("Loaded bytes [{}] from [{}]", bytes.length, JwtRealm.this.jwkSetPathUri);
LOGGER.trace("Loaded bytes [{}] from [{}]", bytes.length, jwkSetPathUri);
return this.parseContent(bytes);
})
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,10 @@ public void testJwtValidationFailures() throws Exception {
{ // Verify rejection of a tampered header (flip HMAC=>RSA or RSA/EC=>HMAC)
final String mixupAlg; // Check if there are any algorithms available in the realm for attempting a flip test
if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(validHeader.getAlgorithm().getName())) {
if (jwtIssuerAndRealm.realm().contentAndJwksAlgsPkc.jwksAlgs().algs().isEmpty()) {
if (jwtIssuerAndRealm.realm().getJwksAlgsPkc().jwksAlgs().algs().isEmpty()) {
mixupAlg = null; // cannot flip HMAC to PKC (no PKC algs available)
} else {
mixupAlg = randomFrom(jwtIssuerAndRealm.realm().contentAndJwksAlgsPkc.jwksAlgs().algs()); // flip HMAC to PKC
mixupAlg = randomFrom(jwtIssuerAndRealm.realm().getJwksAlgsPkc().jwksAlgs().algs()); // flip HMAC to PKC
}
} else {
if (jwtIssuerAndRealm.realm().contentAndJwksAlgsHmac.jwksAlgs().algs().isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ protected JwtIssuer createJwtIssuer(
}

protected void copyIssuerJwksToRealmConfig(final JwtIssuerAndRealm jwtIssuerAndRealm) throws Exception {
if ((jwtIssuerAndRealm.realm.isConfiguredJwkSetPkc) && (jwtIssuerAndRealm.realm.jwkSetPathUri == null)) {
if ((jwtIssuerAndRealm.realm.isConfiguredJwkSetPkc) && (jwtIssuerAndRealm.realm.getJwkSetPathUri() == null)) {
LOGGER.trace("Updating JwtRealm PKC public JWKSet local file");
final Path path = PathUtils.get(jwtIssuerAndRealm.realm.jwkSetPath);
Files.writeString(path, jwtIssuerAndRealm.issuer.encodedJwkSetPkcPublic);
Expand Down Expand Up @@ -659,7 +659,7 @@ protected void printJwtRealm(final JwtRealm jwtRealm) {
+ ", algsPkc="
+ jwtRealm.allowedJwksAlgsPkc
+ ", filteredPkc="
+ jwtRealm.contentAndJwksAlgsPkc.jwksAlgs().algs()
+ jwtRealm.getJwksAlgsPkc().jwksAlgs().algs()
+ ", claimPrincipal=["
+ jwtRealm.claimParserPrincipal.getClaimName()
+ "], claimGroups=["
Expand All @@ -675,7 +675,7 @@ protected void printJwtRealm(final JwtRealm jwtRealm) {
for (final JWK jwk : jwtRealm.contentAndJwksAlgsHmac.jwksAlgs().jwks()) {
LOGGER.info("REALM HMAC: jwk=[{}]", jwk);
}
for (final JWK jwk : jwtRealm.contentAndJwksAlgsPkc.jwksAlgs().jwks()) {
for (final JWK jwk : jwtRealm.getJwksAlgsPkc().jwksAlgs().jwks()) {
LOGGER.info("REALM PKC: jwk=[{}]", jwk);
}
}
Expand Down

0 comments on commit d39836d

Please sign in to comment.