Skip to content

Commit

Permalink
Merge pull request quarkusio#26566 from sberyozkin/encrypted_id_token
Browse files Browse the repository at this point in the history
Support IdTokens which are returned encrypted in the code exchange response
  • Loading branch information
sberyozkin authored Jul 10, 2022
2 parents 60fd34f + 3fa08ec commit 1d5e7bc
Show file tree
Hide file tree
Showing 14 changed files with 282 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,20 @@ public static Token fromAudience(String... audience) {
@ConfigItem
public Optional<String> header = Optional.empty();

/**
* Decryption key location.
* JWT tokens can be inner-signed and encrypted by OpenId Connect providers.
* However, it is not always possible to remotely introspect such tokens because
* the providers may not control the private decryption keys.
* In such cases set this property to point to the file containing the decryption private key in
* PEM or JSON Web Key (JWK) format.
* Note that if a 'private_key_jwt' client authentication method is used then the private key
* which is used to sign client authentication JWT tokens will be used to try to decrypt an encrypted ID token
* if this property is not set.
*/
@ConfigItem
public Optional<String> decryptionKeyLocation = Optional.empty();

/**
* Allow the remote introspection of JWT tokens when no matching JWK key is available.
*
Expand Down Expand Up @@ -1102,6 +1116,14 @@ public Optional<Duration> getAge() {
public void setAge(Duration age) {
this.age = Optional.of(age);
}

public Optional<String> getDecryptionKeyLocation() {
return decryptionKeyLocation;
}

public void setDecryptionKeyLocation(String decryptionKeyLocation) {
this.decryptionKeyLocation = Optional.of(decryptionKeyLocation);
}
}

public static enum ApplicationType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import org.jboss.logging.Logger;
import org.jose4j.jwt.consumer.ErrorCodes;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.lang.JoseException;

import io.netty.handler.codec.http.HttpResponseStatus;
import io.quarkus.logging.Log;
import io.quarkus.oidc.AuthorizationCodeTokens;
import io.quarkus.oidc.IdTokenCredential;
import io.quarkus.oidc.OidcTenantConfig;
Expand All @@ -39,6 +41,7 @@
import io.quarkus.security.identity.IdentityProviderManager;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.vertx.http.runtime.security.ChallengeData;
import io.smallrye.jwt.algorithm.KeyEncryptionAlgorithm;
import io.smallrye.jwt.build.Jwt;
import io.smallrye.jwt.build.JwtClaimsBuilder;
import io.smallrye.jwt.util.KeyUtils;
Expand Down Expand Up @@ -216,7 +219,7 @@ public Uni<SecurityIdentity> apply(Void t) {
context.put(OidcConstants.ACCESS_TOKEN_VALUE, session.getAccessToken());
context.put(AuthorizationCodeTokens.class.getName(), session);
return authenticate(identityProviderManager, context,
new IdTokenCredential(session.getIdToken(),
new IdTokenCredential(decryptIdTokenIfEncryptedByProvider(configContext, session.getIdToken()),
isInternalIdToken(session.getIdToken(), configContext)))
.call(new Function<SecurityIdentity, Uni<?>>() {
@Override
Expand Down Expand Up @@ -268,6 +271,21 @@ public Uni<? extends SecurityIdentity> apply(Throwable t) {
});
}

private static String decryptIdTokenIfEncryptedByProvider(TenantConfigContext resolvedContext, String token) {
if ((resolvedContext.provider.tokenDecryptionKey != null || resolvedContext.provider.client.getClientJwtKey() != null)
&& OidcUtils.isEncryptedToken(token)) {
try {
return OidcUtils.decryptString(token,
resolvedContext.provider.tokenDecryptionKey != null ? resolvedContext.provider.tokenDecryptionKey
: resolvedContext.provider.client.getClientJwtKey(),
KeyEncryptionAlgorithm.RSA_OAEP);
} catch (JoseException ex) {
Log.debugf("Failed to decrypt a token: %s, a token introspection will be attempted instead", ex.getMessage());
}
}
return token;
}

private boolean isBackChannelLogoutPendingAndValid(TenantConfigContext configContext, String idToken) {
TokenVerificationResult backChannelLogoutTokenResult = resolver.getBackChannelLogoutTokens()
.remove(configContext.oidcConfig.getTenantId().get());
Expand Down Expand Up @@ -523,8 +541,10 @@ public Uni<SecurityIdentity> apply(final AuthorizationCodeTokens tokens, final T
context.put(OidcConstants.ACCESS_TOKEN_VALUE, tokens.getAccessToken());
context.put(AuthorizationCodeTokens.class.getName(), tokens);

final String idToken = decryptIdTokenIfEncryptedByProvider(configContext, tokens.getIdToken());

return authenticate(identityProviderManager, context,
new IdTokenCredential(tokens.getIdToken(), internalIdToken))
new IdTokenCredential(idToken, internalIdToken))
.call(new Function<SecurityIdentity, Uni<?>>() {
@Override
public Uni<Void> apply(SecurityIdentity identity) {
Expand All @@ -534,7 +554,7 @@ public Uni<Void> apply(SecurityIdentity identity) {
identity.getAttribute(OidcUtils.USER_INFO_ATTRIBUTE)));
}
return processSuccessfulAuthentication(context, configContext,
tokens, identity);
tokens, idToken, identity);
}
})
.map(new Function<SecurityIdentity, SecurityIdentity>() {
Expand Down Expand Up @@ -619,19 +639,20 @@ private String generateInternalIdToken(OidcTenantConfig oidcConfig, UserInfo use
private Uni<Void> processSuccessfulAuthentication(RoutingContext context,
TenantConfigContext configContext,
AuthorizationCodeTokens tokens,
String idToken,
SecurityIdentity securityIdentity) {
return removeSessionCookie(context, configContext.oidcConfig)
.chain(new Function<Void, Uni<? extends Void>>() {

@Override
public Uni<? extends Void> apply(Void t) {
JsonObject idToken = OidcUtils.decodeJwtContent(tokens.getIdToken());
JsonObject idTokenJson = OidcUtils.decodeJwtContent(idToken);

if (!idToken.containsKey("exp") || !idToken.containsKey("iat")) {
if (!idTokenJson.containsKey("exp") || !idTokenJson.containsKey("iat")) {
LOG.debug("ID Token is required to contain 'exp' and 'iat' claims");
throw new AuthenticationCompletionException();
}
long maxAge = idToken.getLong("exp") - idToken.getLong("iat");
long maxAge = idTokenJson.getLong("exp") - idTokenJson.getLong("iat");
if (configContext.oidcConfig.token.lifespanGrace.isPresent()) {
maxAge += configContext.oidcConfig.token.lifespanGrace.getAsInt();
}
Expand Down Expand Up @@ -824,14 +845,16 @@ public Uni<SecurityIdentity> apply(final AuthorizationCodeTokens tokens, final T
context.put(AuthorizationCodeTokens.class.getName(), tokens);
context.put(REFRESH_TOKEN_GRANT_RESPONSE, Boolean.TRUE);

final String idToken = decryptIdTokenIfEncryptedByProvider(configContext, tokens.getIdToken());

return authenticate(identityProviderManager, context,
new IdTokenCredential(tokens.getIdToken()))
new IdTokenCredential(idToken))
.call(new Function<SecurityIdentity, Uni<?>>() {
@Override
public Uni<Void> apply(SecurityIdentity identity) {
// after a successful refresh, rebuild the identity and update the cookie
return processSuccessfulAuthentication(context, configContext,
tokens, identity);
tokens, idToken, identity);
}
})
.map(new Function<SecurityIdentity, SecurityIdentity>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,26 @@ public class OidcProvider implements Closeable {
final OidcTenantConfig oidcConfig;
final String issuer;
final String[] audience;
final Key tokenDecryptionKey;

public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, JsonWebKeySet jwks) {
public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, JsonWebKeySet jwks, Key tokenDecryptionKey) {
this.client = client;
this.oidcConfig = oidcConfig;
this.asymmetricKeyResolver = jwks == null ? null
: new JsonWebKeyResolver(jwks, oidcConfig.token.forcedJwkRefreshInterval);

this.issuer = checkIssuerProp();
this.audience = checkAudienceProp();
this.tokenDecryptionKey = tokenDecryptionKey;
}

public OidcProvider(String publicKeyEnc, OidcTenantConfig oidcConfig) {
public OidcProvider(String publicKeyEnc, OidcTenantConfig oidcConfig, Key tokenDecryptionKey) {
this.client = null;
this.oidcConfig = oidcConfig;
this.asymmetricKeyResolver = new LocalPublicKeyResolver(publicKeyEnc);
this.issuer = checkIssuerProp();
this.audience = checkAudienceProp();
this.tokenDecryptionKey = tokenDecryptionKey;
}

private String checkIssuerProp() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,8 @@ private static OIDCException responseException(HttpResponse<Buffer> resp) {
public void close() {
client.close();
}

public Key getClientJwtKey() {
return clientJwtKey;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.quarkus.oidc.runtime;

import java.security.Key;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand All @@ -10,6 +12,8 @@
import java.util.function.Supplier;

import org.jboss.logging.Logger;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.PublicJsonWebKey;

import io.quarkus.arc.Arc;
import io.quarkus.oidc.OIDCException;
Expand All @@ -25,6 +29,8 @@
import io.quarkus.runtime.TlsConfig;
import io.quarkus.runtime.annotations.Recorder;
import io.quarkus.runtime.configuration.ConfigurationException;
import io.smallrye.jwt.algorithm.KeyEncryptionAlgorithm;
import io.smallrye.jwt.util.KeyUtils;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Vertx;
import io.vertx.core.net.ProxyOptions;
Expand Down Expand Up @@ -141,7 +147,7 @@ private Uni<TenantConfigContext> createTenantContext(Vertx vertx, OidcTenantConf

if (!oidcConfig.tenantEnabled) {
LOG.debugf("'%s' tenant configuration is disabled", tenantId);
return Uni.createFrom().item(new TenantConfigContext(new OidcProvider(null, null, null), oidcConfig));
return Uni.createFrom().item(new TenantConfigContext(new OidcProvider(null, null, null, null), oidcConfig));
}

if (oidcConfig.getPublicKey().isPresent()) {
Expand Down Expand Up @@ -219,7 +225,8 @@ private static TenantConfigContext createTenantContextFromPublicKey(OidcTenantCo
LOG.debug("'public-key' property for the local token verification is set,"
+ " no connection to the OIDC server will be created");

return new TenantConfigContext(new OidcProvider(oidcConfig.publicKey.get(), oidcConfig), oidcConfig);
return new TenantConfigContext(
new OidcProvider(oidcConfig.publicKey.get(), oidcConfig, readTokenDecryptionKey(oidcConfig)), oidcConfig);
}

public void setSecurityEventObserved(boolean isSecurityEventObserved) {
Expand Down Expand Up @@ -248,17 +255,49 @@ public Uni<OidcProvider> apply(OidcProviderClient client) {

@Override
public OidcProvider apply(JsonWebKeySet jwks) {
return new OidcProvider(client, oidcConfig, jwks);
return new OidcProvider(client, oidcConfig, jwks,
readTokenDecryptionKey(oidcConfig));
}

});
} else {
return Uni.createFrom().item(new OidcProvider(client, oidcConfig, null));
return Uni.createFrom()
.item(new OidcProvider(client, oidcConfig, null, readTokenDecryptionKey(oidcConfig)));
}
}
});
}

private static Key readTokenDecryptionKey(OidcTenantConfig oidcConfig) {
if (oidcConfig.token.decryptionKeyLocation.isPresent()) {
try {
Key key = null;

String keyContent = KeyUtils.readKeyContent(oidcConfig.token.decryptionKeyLocation.get());
if (keyContent != null) {
List<JsonWebKey> keys = KeyUtils.loadJsonWebKeys(keyContent);
if (keys != null && keys.size() == 1 &&
(keys.get(0).getAlgorithm() == null
|| keys.get(0).getAlgorithm() == KeyEncryptionAlgorithm.RSA_OAEP.getAlgorithm())
&& ("enc".equals(keys.get(0).getUse()) || keys.get(0).getUse() == null)) {
key = PublicJsonWebKey.class.cast(keys.get(0)).getPrivateKey();
}
}
if (key == null) {
key = KeyUtils.decodeDecryptionPrivateKey(keyContent);
}
return key;
} catch (Exception ex) {
throw new ConfigurationException(
String.format("Token decryption key for tenant %s can not be read from %s",
oidcConfig.tenantId.get(), oidcConfig.token.decryptionKeyLocation.get()),
ex);
}
} else {
return null;
}
}

protected static Uni<JsonWebKeySet> getJsonWebSetUni(OidcProviderClient client, OidcTenantConfig oidcConfig) {
if (!oidcConfig.isDiscoveryEnabled().orElse(true)) {
final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkus.oidc.runtime;

import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
Expand All @@ -22,6 +23,7 @@
import org.jose4j.jwe.JsonWebEncryption;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.lang.JoseException;

import io.quarkus.oidc.AccessTokenCredential;
import io.quarkus.oidc.AuthorizationCodeTokens;
Expand Down Expand Up @@ -76,6 +78,10 @@ private OidcUtils() {

}

public static boolean isEncryptedToken(String token) {
return new StringTokenizer(token, ".").countTokens() == 5;
}

public static boolean isOpaqueToken(String token) {
return new StringTokenizer(token, ".").countTokens() != 3;
}
Expand Down Expand Up @@ -423,14 +429,18 @@ public static String encryptString(String jweString, SecretKey key) throws Excep
return jwe.getCompactSerialization();
}

public static JsonObject decryptJson(String jweString, SecretKey key) throws Exception {
public static JsonObject decryptJson(String jweString, Key key) throws Exception {
return new JsonObject(decryptString(jweString, key));
}

public static String decryptString(String jweString, SecretKey key) throws Exception {
public static String decryptString(String jweString, Key key) throws Exception {
return decryptString(jweString, key, KeyEncryptionAlgorithm.A256KW);
}

public static String decryptString(String jweString, Key key, KeyEncryptionAlgorithm algorithm) throws JoseException {
JsonWebEncryption jwe = new JsonWebEncryption();
jwe.setAlgorithmConstraints(new AlgorithmConstraints(AlgorithmConstraints.ConstraintType.PERMIT,
KeyEncryptionAlgorithm.A256KW.getAlgorithm()));
algorithm.getAlgorithm()));
jwe.setKey(key);
jwe.setCompactSerialization(jweString);
return jwe.getPlaintextString();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.quarkus.it.keycloak;

import javax.inject.Inject;
import javax.ws.rs.GET;
import javax.ws.rs.Path;

import org.eclipse.microprofile.jwt.JsonWebToken;

import io.quarkus.oidc.IdToken;
import io.quarkus.security.Authenticated;

@Path("/code-flow-encrypted-id-token")
public class CodeFlowEncryptedIdTokenResource {

@Inject
@IdToken
JsonWebToken idToken;

@GET
@Authenticated
@Path("/code-flow-encrypted-id-token-jwk")
public String accessJwk() {
return "user: " + idToken.getName();
}

@GET
@Authenticated
@Path("/code-flow-encrypted-id-token-pem")
public String accessPem() {
return "user: " + idToken.getName();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ public String resolve(RoutingContext context) {
if (path.endsWith("code-flow") || path.endsWith("code-flow/logout")) {
return "code-flow";
}
if (path.endsWith("code-flow-encrypted-id-token-jwk")) {
return "code-flow-encrypted-id-token-jwk";
}
if (path.endsWith("code-flow-encrypted-id-token-pem")) {
return "code-flow-encrypted-id-token-pem";
}
if (path.endsWith("code-flow-form-post")) {
return "code-flow-form-post";
}
Expand Down
Loading

0 comments on commit 1d5e7bc

Please sign in to comment.