diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java index 899c6f454fc4d..0c20982f42061 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcTenantConfig.java @@ -505,7 +505,6 @@ public void setCleanUpTimerInterval(Duration cleanUpTimerInterval) { @ConfigGroup public static class Jwks { - /** * If JWK verification keys should be fetched at the moment a connection to the OIDC provider * is initialized. @@ -539,6 +538,13 @@ public static class Jwks { @ConfigItem public Optional cleanUpTimerInterval = Optional.empty(); + /** + * In case there is no key identifier ('kid') or certificate thumbprints ('x5t', 'x5t#S256') specified in the JOSE + * header and no key could be determined, check all available keys matching the token algorithm ('alg') header value. + */ + @ConfigItem(defaultValue = "false") + public boolean tryAll = false; + public int getCacheSize() { return cacheSize; } @@ -570,6 +576,14 @@ public boolean isResolveEarly() { public void setResolveEarly(boolean resolveEarly) { this.resolveEarly = resolveEarly; } + + public boolean isTryAll() { + return tryAll; + } + + public void setTryAll(boolean fallbackToTryAll) { + this.tryAll = fallbackToTryAll; + } } @ConfigGroup diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DynamicVerificationKeyResolver.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DynamicVerificationKeyResolver.java index a2a2d85a2ab96..9e408dc3068bc 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DynamicVerificationKeyResolver.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/DynamicVerificationKeyResolver.java @@ -32,10 +32,12 @@ public class DynamicVerificationKeyResolver { private final OidcProviderClient client; private final MemoryCache cache; + private final boolean tryAll; final CertChainPublicKeyResolver chainResolverFallback; public DynamicVerificationKeyResolver(OidcProviderClient client, OidcTenantConfig config) { this.client = client; + this.tryAll = config.jwks.tryAll; this.cache = new MemoryCache(client.getVertx(), config.jwks.cleanUpTimerInterval, config.jwks.cacheTimeToLive, config.jwks.cacheSize); if (config.certificateChain.trustStoreFile.isPresent()) { @@ -115,6 +117,12 @@ public Uni apply(JsonWebKeySet jwks) { newKey = jwks.getKeyWithoutKeyIdAndThumbprint("RSA"); } + // if (newKey == null && tryAll && kid == null && thumbprint == null) { + // LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set," + // + " falling back to trying all available keys"); + // newKey = jwks.findKeyInAllKeys(jws); // there is nothing to check the signature for in this method + // } + if (newKey == null && chainResolverFallback != null) { return getChainResolver(); } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/JsonWebKeySet.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/JsonWebKeySet.java index dedfe32bf1156..cf6a8a0319ce3 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/JsonWebKeySet.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/JsonWebKeySet.java @@ -7,14 +7,18 @@ import java.util.Map; import java.util.Set; +import org.jboss.logging.Logger; import org.jose4j.jwk.JsonWebKey; import org.jose4j.jwk.PublicJsonWebKey; +import org.jose4j.jws.JsonWebSignature; +import org.jose4j.lang.InvalidAlgorithmException; import org.jose4j.lang.JoseException; import io.quarkus.oidc.OIDCException; public class JsonWebKeySet { + private static final Logger LOG = Logger.getLogger(JsonWebKeySet.class); private static final String RSA_KEY_TYPE = "RSA"; private static final String ELLIPTIC_CURVE_KEY_TYPE = "EC"; // This key type is used when EdDSA algorithm is used @@ -27,6 +31,7 @@ public class JsonWebKeySet { private Map keysWithThumbprints = new HashMap<>(); private Map keysWithS256Thumbprints = new HashMap<>(); private Map> keysWithoutKeyIdAndThumbprint = new HashMap<>(); + private Map> allKeys = new HashMap<>(); public JsonWebKeySet(String json) { initKeys(json); @@ -37,6 +42,8 @@ private void initKeys(String json) { org.jose4j.jwk.JsonWebKeySet jwkSet = new org.jose4j.jwk.JsonWebKeySet(json); for (JsonWebKey jwkKey : jwkSet.getJsonWebKeys()) { if (isSupportedJwkKey(jwkKey)) { + addKeyToListInMap(jwkKey, allKeys); + if (jwkKey.getKeyId() != null) { keysWithKeyId.put(jwkKey.getKeyId(), jwkKey.getKey()); } @@ -52,12 +59,7 @@ private void initKeys(String json) { keysWithS256Thumbprints.put(x5tS256, jwkKey.getKey()); } if (jwkKey.getKeyId() == null && x5t == null && x5tS256 == null && jwkKey.getKeyType() != null) { - List keys = keysWithoutKeyIdAndThumbprint.get(jwkKey.getKeyType()); - if (keys == null) { - keys = new ArrayList<>(); - keysWithoutKeyIdAndThumbprint.put(jwkKey.getKeyType(), keys); - } - keys.add(jwkKey.getKey()); + addKeyToListInMap(jwkKey, keysWithoutKeyIdAndThumbprint); } } } @@ -71,6 +73,48 @@ private static boolean isSupportedJwkKey(JsonWebKey jwkKey) { && (SIGNATURE_USE.equals(jwkKey.getUse()) || jwkKey.getUse() == null); } + private void addKeyToListInMap(JsonWebKey key, Map> map) { + List keys = map.get(key.getKeyType()); + + if (keys == null) { + keys = new ArrayList<>(); + map.put(key.getKeyType(), keys); + } + + keys.add(key.getKey()); + } + + public Key findKeyInAllKeys(JsonWebSignature jws) { + LOG.debug("Evaluating all keys to find a matching one"); + final Key initialKey = jws.getKey(); + final String keyType; + + try { + keyType = jws.getKeyType(); + } catch (InvalidAlgorithmException e) { + LOG.debug("No key type available, cannot determine keys to check", e); + return null; + } + + for (Key key : allKeys.getOrDefault(keyType, List.of())) { + jws.setKey(key); + + try { + if (jws.verifySignature()) { + jws.setKey(initialKey); + LOG.debugf("Found matching key %s", key.toString()); + return key; + } + } catch (JoseException e) { + LOG.debugf(e, "Verifying signature with key %s failed.", key.toString()); + } + } + + jws.setKey(initialKey); + LOG.debug("No matching key found"); + return null; + } + public Key getKeyWithId(String kid) { return keysWithKeyId.get(kid); } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java index dde4e3d77d34d..c49c8a6d87372 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java @@ -489,6 +489,12 @@ public Key resolveKey(JsonWebSignature jws, List nestingContex } } + if (key == null && oidcConfig.jwks.tryAll && kid == null && thumbprint == null) { + LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set," + + " falling back to trying all available keys"); + key = jwks.findKeyInAllKeys(jws); + } + if (key == null && chainResolverFallback != null) { LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set," + " falling back to the certificate chain resolver"); diff --git a/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/OidcProviderTest.java b/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/OidcProviderTest.java index fe13364f4e63b..773e8f8a01ac9 100644 --- a/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/OidcProviderTest.java +++ b/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/OidcProviderTest.java @@ -98,7 +98,45 @@ public void testTokenWithoutKidMultipleRSAJwkWithoutKid() throws Exception { } catch (InvalidJwtException ex) { assertTrue(ex.getCause() instanceof UnresolvableKeyException); } + } + } + + @Test + public void testTokenWithoutKidMultipleRSAJwkWithoutKidTryAll() throws Exception { + RsaJsonWebKey rsaJsonWebKey1 = RsaJwkGenerator.generateJwk(2048); + RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048); + JsonWebKeySet jwkSet = new JsonWebKeySet( + "{\"keys\": [" + rsaJsonWebKey1.toJson() + "," + rsaJsonWebKey2.toJson() + "]}"); + + final String token = Jwt.issuer("http://keycloak/realm").sign(rsaJsonWebKey2.getPrivateKey()); + final OidcTenantConfig config = new OidcTenantConfig(); + config.jwks.tryAll = true; + + try (OidcProvider provider = new OidcProvider(null, config, jwkSet, null)) { + TokenVerificationResult result = provider.verifyJwtToken(token, false, false, null); + assertEquals("http://keycloak/realm", result.localVerificationResult.getString("iss")); + } + } + @Test + public void testTokenWithoutKidMultipleRSAJwkWithoutKidTryAllNoMatching() throws Exception { + RsaJsonWebKey rsaJsonWebKey1 = RsaJwkGenerator.generateJwk(2048); + RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048); + RsaJsonWebKey rsaJsonWebKey3 = RsaJwkGenerator.generateJwk(2048); + JsonWebKeySet jwkSet = new JsonWebKeySet( + "{\"keys\": [" + rsaJsonWebKey1.toJson() + "," + rsaJsonWebKey2.toJson() + "]}"); + + final String token = Jwt.issuer("http://keycloak/realm").sign(rsaJsonWebKey3.getPrivateKey()); + final OidcTenantConfig config = new OidcTenantConfig(); + config.jwks.tryAll = true; + + try (OidcProvider provider = new OidcProvider(null, config, jwkSet, null)) { + try { + provider.verifyJwtToken(token, false, false, null); + fail("InvalidJwtException expected"); + } catch (InvalidJwtException ex) { + assertTrue(ex.getCause() instanceof UnresolvableKeyException); + } } }