Skip to content

Commit

Permalink
add option to try all OIDC JWKs as fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
c15yi committed Jul 25, 2024
1 parent f6b5380 commit bbdfb06
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,6 @@ public void setCleanUpTimerInterval(Duration cleanUpTimerInterval) {

@ConfigGroup
public static class Jwks {

/**
* If JWK verification keys should be fetched at the moment a connection to the OIDC provider
* is initialized.
Expand Down Expand Up @@ -539,6 +538,13 @@ public static class Jwks {
@ConfigItem
public Optional<Duration> cleanUpTimerInterval = Optional.empty();

/**
* In case there is no key identifier ('kid') or certificate thumbprints ('x5t', 'x5t#S256') specified in the JOSE
* header and no key could be determined, check all available keys matching the token algorithm ('alg') header value.
*/
@ConfigItem(defaultValue = "false")
public boolean tryAll = false;

public int getCacheSize() {
return cacheSize;
}
Expand Down Expand Up @@ -570,6 +576,14 @@ public boolean isResolveEarly() {
public void setResolveEarly(boolean resolveEarly) {
this.resolveEarly = resolveEarly;
}

public boolean isTryAll() {
return tryAll;
}

public void setTryAll(boolean fallbackToTryAll) {
this.tryAll = fallbackToTryAll;
}
}

@ConfigGroup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ public class DynamicVerificationKeyResolver {

private final OidcProviderClient client;
private final MemoryCache<Key> cache;
private final boolean tryAll;
final CertChainPublicKeyResolver chainResolverFallback;

public DynamicVerificationKeyResolver(OidcProviderClient client, OidcTenantConfig config) {
this.client = client;
this.tryAll = config.jwks.tryAll;
this.cache = new MemoryCache<Key>(client.getVertx(), config.jwks.cleanUpTimerInterval,
config.jwks.cacheTimeToLive, config.jwks.cacheSize);
if (config.certificateChain.trustStoreFile.isPresent()) {
Expand Down Expand Up @@ -115,6 +117,12 @@ public Uni<? extends VerificationKeyResolver> apply(JsonWebKeySet jwks) {
newKey = jwks.getKeyWithoutKeyIdAndThumbprint("RSA");
}

// if (newKey == null && tryAll && kid == null && thumbprint == null) {
// LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set,"
// + " falling back to trying all available keys");
// newKey = jwks.findKeyInAllKeys(jws); // there is nothing to check the signature for in this method
// }

if (newKey == null && chainResolverFallback != null) {
return getChainResolver();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@
import java.util.Map;
import java.util.Set;

import org.jboss.logging.Logger;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.PublicJsonWebKey;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.lang.InvalidAlgorithmException;
import org.jose4j.lang.JoseException;

import io.quarkus.oidc.OIDCException;

public class JsonWebKeySet {

private static final Logger LOG = Logger.getLogger(JsonWebKeySet.class);
private static final String RSA_KEY_TYPE = "RSA";
private static final String ELLIPTIC_CURVE_KEY_TYPE = "EC";
// This key type is used when EdDSA algorithm is used
Expand All @@ -27,6 +31,7 @@ public class JsonWebKeySet {
private Map<String, Key> keysWithThumbprints = new HashMap<>();
private Map<String, Key> keysWithS256Thumbprints = new HashMap<>();
private Map<String, List<Key>> keysWithoutKeyIdAndThumbprint = new HashMap<>();
private Map<String, List<Key>> allKeys = new HashMap<>();

public JsonWebKeySet(String json) {
initKeys(json);
Expand All @@ -37,6 +42,8 @@ private void initKeys(String json) {
org.jose4j.jwk.JsonWebKeySet jwkSet = new org.jose4j.jwk.JsonWebKeySet(json);
for (JsonWebKey jwkKey : jwkSet.getJsonWebKeys()) {
if (isSupportedJwkKey(jwkKey)) {
addKeyToListInMap(jwkKey, allKeys);

if (jwkKey.getKeyId() != null) {
keysWithKeyId.put(jwkKey.getKeyId(), jwkKey.getKey());
}
Expand All @@ -52,12 +59,7 @@ private void initKeys(String json) {
keysWithS256Thumbprints.put(x5tS256, jwkKey.getKey());
}
if (jwkKey.getKeyId() == null && x5t == null && x5tS256 == null && jwkKey.getKeyType() != null) {
List<Key> keys = keysWithoutKeyIdAndThumbprint.get(jwkKey.getKeyType());
if (keys == null) {
keys = new ArrayList<>();
keysWithoutKeyIdAndThumbprint.put(jwkKey.getKeyType(), keys);
}
keys.add(jwkKey.getKey());
addKeyToListInMap(jwkKey, keysWithoutKeyIdAndThumbprint);
}
}
}
Expand All @@ -71,6 +73,48 @@ private static boolean isSupportedJwkKey(JsonWebKey jwkKey) {
&& (SIGNATURE_USE.equals(jwkKey.getUse()) || jwkKey.getUse() == null);
}

private void addKeyToListInMap(JsonWebKey key, Map<String, List<Key>> map) {
List<Key> keys = map.get(key.getKeyType());

if (keys == null) {
keys = new ArrayList<>();
map.put(key.getKeyType(), keys);
}

keys.add(key.getKey());
}

public Key findKeyInAllKeys(JsonWebSignature jws) {
LOG.debug("Evaluating all keys to find a matching one");
final Key initialKey = jws.getKey();
final String keyType;

try {
keyType = jws.getKeyType();
} catch (InvalidAlgorithmException e) {
LOG.debug("No key type available, cannot determine keys to check", e);
return null;
}

for (Key key : allKeys.getOrDefault(keyType, List.of())) {
jws.setKey(key);

try {
if (jws.verifySignature()) {
jws.setKey(initialKey);
LOG.debugf("Found matching key %s", key.toString());
return key;
}
} catch (JoseException e) {
LOG.debugf(e, "Verifying signature with key %s failed.", key.toString());
}
}

jws.setKey(initialKey);
LOG.debug("No matching key found");
return null;
}

public Key getKeyWithId(String kid) {
return keysWithKeyId.get(kid);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.jose4j.jwx.JsonWebStructure;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import org.jose4j.lang.InvalidAlgorithmException;
import org.jose4j.lang.JoseException;
import org.jose4j.lang.UnresolvableKeyException;

import io.quarkus.logging.Log;
Expand Down Expand Up @@ -489,6 +490,12 @@ public Key resolveKey(JsonWebSignature jws, List<JsonWebStructure> nestingContex
}
}

if (key == null && oidcConfig.jwks.tryAll && kid == null && thumbprint == null) {
LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set,"
+ " falling back to trying all available keys");
key = jwks.findKeyInAllKeys(jws);
}

if (key == null && chainResolverFallback != null) {
LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set,"
+ " falling back to the certificate chain resolver");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,45 @@ public void testTokenWithoutKidMultipleRSAJwkWithoutKid() throws Exception {
} catch (InvalidJwtException ex) {
assertTrue(ex.getCause() instanceof UnresolvableKeyException);
}
}
}

@Test
public void testTokenWithoutKidMultipleRSAJwkWithoutKidTryAll() throws Exception {
RsaJsonWebKey rsaJsonWebKey1 = RsaJwkGenerator.generateJwk(2048);
RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048);
JsonWebKeySet jwkSet = new JsonWebKeySet(
"{\"keys\": [" + rsaJsonWebKey1.toJson() + "," + rsaJsonWebKey2.toJson() + "]}");

final String token = Jwt.issuer("http://keycloak/realm").sign(rsaJsonWebKey2.getPrivateKey());
final OidcTenantConfig config = new OidcTenantConfig();
config.jwks.tryAll = true;

try (OidcProvider provider = new OidcProvider(null, config, jwkSet, null)) {
TokenVerificationResult result = provider.verifyJwtToken(token, false, false, null);
assertEquals("http://keycloak/realm", result.localVerificationResult.getString("iss"));
}
}

@Test
public void testTokenWithoutKidMultipleRSAJwkWithoutKidTryAllNoMatching() throws Exception {
RsaJsonWebKey rsaJsonWebKey1 = RsaJwkGenerator.generateJwk(2048);
RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048);
RsaJsonWebKey rsaJsonWebKey3 = RsaJwkGenerator.generateJwk(2048);
JsonWebKeySet jwkSet = new JsonWebKeySet(
"{\"keys\": [" + rsaJsonWebKey1.toJson() + "," + rsaJsonWebKey2.toJson() + "]}");

final String token = Jwt.issuer("http://keycloak/realm").sign(rsaJsonWebKey3.getPrivateKey());
final OidcTenantConfig config = new OidcTenantConfig();
config.jwks.tryAll = true;

try (OidcProvider provider = new OidcProvider(null, config, jwkSet, null)) {
try {
provider.verifyJwtToken(token, false, false, null);
fail("InvalidJwtException expected");
} catch (InvalidJwtException ex) {
assertTrue(ex.getCause() instanceof UnresolvableKeyException);
}
}
}

Expand Down

0 comments on commit bbdfb06

Please sign in to comment.