Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to try all OIDC JWK keys as fallback #42008

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,6 @@ public void setCleanUpTimerInterval(Duration cleanUpTimerInterval) {

@ConfigGroup
public static class Jwks {

/**
* If JWK verification keys should be fetched at the moment a connection to the OIDC provider
* is initialized.
Expand Down Expand Up @@ -539,6 +538,13 @@ public static class Jwks {
@ConfigItem
public Optional<Duration> cleanUpTimerInterval = Optional.empty();

/**
* In case there is no key identifier ('kid') or certificate thumbprints ('x5t', 'x5t#S256') specified in the JOSE
* header and no key could be determined, check all available keys matching the token algorithm ('alg') header value.
*/
@ConfigItem(defaultValue = "false")
public boolean tryAll = false;

public int getCacheSize() {
return cacheSize;
}
Expand Down Expand Up @@ -570,6 +576,14 @@ public boolean isResolveEarly() {
public void setResolveEarly(boolean resolveEarly) {
this.resolveEarly = resolveEarly;
}

public boolean isTryAll() {
return tryAll;
}

public void setTryAll(boolean fallbackToTryAll) {
this.tryAll = fallbackToTryAll;
}
}

@ConfigGroup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ public class DynamicVerificationKeyResolver {

private final OidcProviderClient client;
private final MemoryCache<Key> cache;
private final boolean tryAll;
final CertChainPublicKeyResolver chainResolverFallback;

public DynamicVerificationKeyResolver(OidcProviderClient client, OidcTenantConfig config) {
this.client = client;
this.tryAll = config.jwks.tryAll;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this variable is never used

this.cache = new MemoryCache<Key>(client.getVertx(), config.jwks.cleanUpTimerInterval,
config.jwks.cacheTimeToLive, config.jwks.cacheSize);
if (config.certificateChain.trustStoreFile.isPresent()) {
Expand Down Expand Up @@ -115,6 +117,12 @@ public Uni<? extends VerificationKeyResolver> apply(JsonWebKeySet jwks) {
newKey = jwks.getKeyWithoutKeyIdAndThumbprint("RSA");
}

// if (newKey == null && tryAll && kid == null && thumbprint == null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not useful to have this commented out code.

// LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set,"
// + " falling back to trying all available keys");
// newKey = jwks.findKeyInAllKeys(jws); // there is nothing to check the signature for in this method
// }

if (newKey == null && chainResolverFallback != null) {
return getChainResolver();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@
import java.util.Map;
import java.util.Set;

import org.jboss.logging.Logger;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.PublicJsonWebKey;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.lang.InvalidAlgorithmException;
import org.jose4j.lang.JoseException;

import io.quarkus.oidc.OIDCException;

public class JsonWebKeySet {

private static final Logger LOG = Logger.getLogger(JsonWebKeySet.class);
private static final String RSA_KEY_TYPE = "RSA";
private static final String ELLIPTIC_CURVE_KEY_TYPE = "EC";
// This key type is used when EdDSA algorithm is used
Expand All @@ -27,6 +31,7 @@ public class JsonWebKeySet {
private Map<String, Key> keysWithThumbprints = new HashMap<>();
private Map<String, Key> keysWithS256Thumbprints = new HashMap<>();
private Map<String, List<Key>> keysWithoutKeyIdAndThumbprint = new HashMap<>();
private Map<String, List<Key>> allKeys = new HashMap<>();

public JsonWebKeySet(String json) {
initKeys(json);
Expand All @@ -37,6 +42,8 @@ private void initKeys(String json) {
org.jose4j.jwk.JsonWebKeySet jwkSet = new org.jose4j.jwk.JsonWebKeySet(json);
for (JsonWebKey jwkKey : jwkSet.getJsonWebKeys()) {
if (isSupportedJwkKey(jwkKey)) {
addKeyToListInMap(jwkKey, allKeys);

if (jwkKey.getKeyId() != null) {
keysWithKeyId.put(jwkKey.getKeyId(), jwkKey.getKey());
}
Expand All @@ -52,12 +59,7 @@ private void initKeys(String json) {
keysWithS256Thumbprints.put(x5tS256, jwkKey.getKey());
}
if (jwkKey.getKeyId() == null && x5t == null && x5tS256 == null && jwkKey.getKeyType() != null) {
List<Key> keys = keysWithoutKeyIdAndThumbprint.get(jwkKey.getKeyType());
if (keys == null) {
keys = new ArrayList<>();
keysWithoutKeyIdAndThumbprint.put(jwkKey.getKeyType(), keys);
}
keys.add(jwkKey.getKey());
addKeyToListInMap(jwkKey, keysWithoutKeyIdAndThumbprint);
}
}
}
Expand All @@ -71,6 +73,48 @@ private static boolean isSupportedJwkKey(JsonWebKey jwkKey) {
&& (SIGNATURE_USE.equals(jwkKey.getUse()) || jwkKey.getUse() == null);
}

private void addKeyToListInMap(JsonWebKey key, Map<String, List<Key>> map) {
List<Key> keys = map.get(key.getKeyType());

if (keys == null) {
keys = new ArrayList<>();
map.put(key.getKeyType(), keys);
}

keys.add(key.getKey());
}

public Key findKeyInAllKeys(JsonWebSignature jws) {
LOG.debug("Evaluating all keys to find a matching one");
final Key initialKey = jws.getKey();
final String keyType;

try {
keyType = jws.getKeyType();
} catch (InvalidAlgorithmException e) {
LOG.debug("No key type available, cannot determine keys to check", e);
return null;
}

for (Key key : allKeys.getOrDefault(keyType, List.of())) {
jws.setKey(key);

try {
if (jws.verifySignature()) {
jws.setKey(initialKey);
LOG.debugf("Found matching key %s", key.toString());
return key;
}
} catch (JoseException e) {
LOG.debugf(e, "Verifying signature with key %s failed.", key.toString());
}
}

jws.setKey(initialKey);
LOG.debug("No matching key found");
return null;
}

public Key getKeyWithId(String kid) {
return keysWithKeyId.get(kid);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,12 @@ public Key resolveKey(JsonWebSignature jws, List<JsonWebStructure> nestingContex
}
}

if (key == null && oidcConfig.jwks.tryAll && kid == null && thumbprint == null) {
LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set,"
+ " falling back to trying all available keys");
key = jwks.findKeyInAllKeys(jws);
}

if (key == null && chainResolverFallback != null) {
LOG.debug("JWK is not available, neither 'kid' nor 'x5t#S256' nor 'x5t' token headers are set,"
+ " falling back to the certificate chain resolver");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,45 @@ public void testTokenWithoutKidMultipleRSAJwkWithoutKid() throws Exception {
} catch (InvalidJwtException ex) {
assertTrue(ex.getCause() instanceof UnresolvableKeyException);
}
}
}

@Test
public void testTokenWithoutKidMultipleRSAJwkWithoutKidTryAll() throws Exception {
RsaJsonWebKey rsaJsonWebKey1 = RsaJwkGenerator.generateJwk(2048);
RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048);
JsonWebKeySet jwkSet = new JsonWebKeySet(
"{\"keys\": [" + rsaJsonWebKey1.toJson() + "," + rsaJsonWebKey2.toJson() + "]}");

final String token = Jwt.issuer("http://keycloak/realm").sign(rsaJsonWebKey2.getPrivateKey());
final OidcTenantConfig config = new OidcTenantConfig();
config.jwks.tryAll = true;

try (OidcProvider provider = new OidcProvider(null, config, jwkSet, null)) {
TokenVerificationResult result = provider.verifyJwtToken(token, false, false, null);
assertEquals("http://keycloak/realm", result.localVerificationResult.getString("iss"));
}
}

@Test
public void testTokenWithoutKidMultipleRSAJwkWithoutKidTryAllNoMatching() throws Exception {
RsaJsonWebKey rsaJsonWebKey1 = RsaJwkGenerator.generateJwk(2048);
RsaJsonWebKey rsaJsonWebKey2 = RsaJwkGenerator.generateJwk(2048);
RsaJsonWebKey rsaJsonWebKey3 = RsaJwkGenerator.generateJwk(2048);
JsonWebKeySet jwkSet = new JsonWebKeySet(
"{\"keys\": [" + rsaJsonWebKey1.toJson() + "," + rsaJsonWebKey2.toJson() + "]}");

final String token = Jwt.issuer("http://keycloak/realm").sign(rsaJsonWebKey3.getPrivateKey());
final OidcTenantConfig config = new OidcTenantConfig();
config.jwks.tryAll = true;

try (OidcProvider provider = new OidcProvider(null, config, jwkSet, null)) {
try {
provider.verifyJwtToken(token, false, false, null);
fail("InvalidJwtException expected");
} catch (InvalidJwtException ex) {
assertTrue(ex.getCause() instanceof UnresolvableKeyException);
}
}
}

Expand Down