Skip to content

Commit

Permalink
Merge pull request #34450 from sberyozkin/fix_backchannel_logout
Browse files Browse the repository at this point in the history
Support multiple backchannel logout tokens
  • Loading branch information
sberyozkin authored Jul 3, 2023
2 parents 8bcd0a9 + c0e59b5 commit ac62790
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,73 @@ public static class Backchannel {
@ConfigItem
public Optional<String> path = Optional.empty();

/**
* Maximum number of logout tokens that can be cached before they are matched against ID tokens stored in session
* cookies.
*/
@ConfigItem(defaultValue = "10")
public int tokenCacheSize = 10;

/**
* Number of minutes a logout token can be cached for.
*/
@ConfigItem(defaultValue = "10M")
public Duration tokenCacheTimeToLive = Duration.ofMinutes(10);

/**
* Token cache timer interval.
* If this property is set then a timer will check and remove the stale entries periodically.
*/
@ConfigItem
public Optional<Duration> cleanUpTimerInterval = Optional.empty();

/**
* Logout token claim whose value will be used as a key for caching the tokens.
* Only `sub` (subject) and `sid` (session id) claims can be used as keys.
* Set it to `sid` only if ID tokens issued by the OIDC provider have no `sub` but have `sid` claim.
*/
@ConfigItem(defaultValue = "sub")
public String logoutTokenKey = "sub";

public void setPath(Optional<String> path) {
this.path = path;
}

public Optional<String> getPath() {
return path;
}

public String getLogoutTokenKey() {
return logoutTokenKey;
}

public void setLogoutTokenKey(String logoutTokenKey) {
this.logoutTokenKey = logoutTokenKey;
}

public int getTokenCacheSize() {
return tokenCacheSize;
}

public void setTokenCacheSize(int tokenCacheSize) {
this.tokenCacheSize = tokenCacheSize;
}

public Duration getTokenCacheTimeToLive() {
return tokenCacheTimeToLive;
}

public void setTokenCacheTimeToLive(Duration tokenCacheTimeToLive) {
this.tokenCacheTimeToLive = tokenCacheTimeToLive;
}

public Optional<Duration> getCleanUpTimerInterval() {
return cleanUpTimerInterval;
}

public void setCleanUpTimerInterval(Duration cleanUpTimerInterval) {
this.cleanUpTimerInterval = Optional.of(cleanUpTimerInterval);
}
}

@ConfigGroup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,17 @@ public void accept(MultiMap form) {
.verifyLogoutJwtToken(encodedLogoutToken);

if (verifyLogoutTokenClaims(result)) {
resolver.getBackChannelLogoutTokens().put(oidcTenantConfig.tenantId.get(),
result);
String key = result.localVerificationResult
.getString(oidcTenantConfig.logout.backchannel.logoutTokenKey);
BackChannelLogoutTokenCache tokens = resolver
.getBackChannelLogoutTokens().get(oidcTenantConfig.tenantId.get());
if (tokens == null) {
tokens = new BackChannelLogoutTokenCache(oidcTenantConfig, context.vertx());
resolver.getBackChannelLogoutTokens().put(oidcTenantConfig.tenantId.get(),
tokens);
}
tokens.addTokenVerification(key, result);

if (resolver.isSecurityEventObserved()) {
resolver.getSecurityEvent().fire(
new SecurityEvent(Type.OIDC_BACKCHANNEL_LOGOUT_INITIATED,
Expand Down Expand Up @@ -122,15 +131,15 @@ private boolean verifyLogoutTokenClaims(TokenVerificationResult result) {
LOG.debug("Back channel logout token does not have a valid 'events' claim");
return false;
}
if (!result.localVerificationResult.containsKey(Claims.sub.name())
&& !result.localVerificationResult.containsKey(OidcConstants.BACK_CHANNEL_LOGOUT_SID_CLAIM)) {
LOG.debug("Back channel logout token does not have 'sub' or 'sid' claim");
if (!result.localVerificationResult.containsKey(oidcTenantConfig.logout.backchannel.logoutTokenKey)) {
LOG.debugf("Back channel logout token does not have %s", oidcTenantConfig.logout.backchannel.logoutTokenKey);
return false;
}
if (result.localVerificationResult.containsKey(Claims.nonce.name())) {
LOG.debug("Back channel logout token must not contain 'nonce' claim");
return false;
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package io.quarkus.oidc.runtime;

import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import io.quarkus.oidc.OidcTenantConfig;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;

public class BackChannelLogoutTokenCache {
private OidcTenantConfig oidcConfig;

private Map<String, CacheEntry> cacheMap = new ConcurrentHashMap<>();;
private AtomicInteger size = new AtomicInteger();

public BackChannelLogoutTokenCache(OidcTenantConfig oidcTenantConfig, Vertx vertx) {
this.oidcConfig = oidcTenantConfig;
init(vertx);
}

private void init(Vertx vertx) {
cacheMap = new ConcurrentHashMap<>();
if (oidcConfig.logout.backchannel.cleanUpTimerInterval.isPresent()) {
vertx.setPeriodic(oidcConfig.logout.backchannel.cleanUpTimerInterval.get().toMillis(), new Handler<Long>() {
@Override
public void handle(Long event) {
// Remove all the entries which have expired
removeInvalidEntries();
}
});
}
}

public void addTokenVerification(String token, TokenVerificationResult result) {
if (!prepareSpaceForNewCacheEntry()) {
clearCache();
}
cacheMap.put(token, new CacheEntry(result));
}

public TokenVerificationResult removeTokenVerification(String token) {
CacheEntry entry = removeCacheEntry(token);
return entry == null ? null : entry.result;
}

public void clearCache() {
cacheMap.clear();
size.set(0);
}

private void removeInvalidEntries() {
long now = now();
for (Iterator<Map.Entry<String, CacheEntry>> it = cacheMap.entrySet().iterator(); it.hasNext();) {
Map.Entry<String, CacheEntry> next = it.next();
if (isEntryExpired(next.getValue(), now)) {
it.remove();
size.decrementAndGet();
}
}
}

private boolean prepareSpaceForNewCacheEntry() {
int currentSize;
do {
currentSize = size.get();
if (currentSize == oidcConfig.logout.backchannel.tokenCacheSize) {
return false;
}
} while (!size.compareAndSet(currentSize, currentSize + 1));
return true;
}

private CacheEntry removeCacheEntry(String token) {
CacheEntry entry = cacheMap.remove(token);
if (entry != null) {
size.decrementAndGet();
}
return entry;
}

private boolean isEntryExpired(CacheEntry entry, long now) {
return entry.createdTime + oidcConfig.logout.backchannel.tokenCacheTimeToLive.toMillis() < now;
}

private static long now() {
return System.currentTimeMillis();
}

private static class CacheEntry {
volatile TokenVerificationResult result;
long createdTime = System.currentTimeMillis();

public CacheEntry(TokenVerificationResult result) {
this.result = result;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,19 @@ private static String decryptIdTokenIfEncryptedByProvider(TenantConfigContext re
}

private boolean isBackChannelLogoutPendingAndValid(TenantConfigContext configContext, SecurityIdentity identity) {
TokenVerificationResult backChannelLogoutTokenResult = resolver.getBackChannelLogoutTokens()
.remove(configContext.oidcConfig.getTenantId().get());
if (backChannelLogoutTokenResult != null) {

BackChannelLogoutTokenCache tokens = resolver.getBackChannelLogoutTokens()
.get(configContext.oidcConfig.getTenantId().get());
if (tokens != null) {
JsonObject idTokenJson = OidcUtils.decodeJwtContent(((JsonWebToken) (identity.getPrincipal())).getRawToken());

String logoutTokenKeyValue = idTokenJson.getString(configContext.oidcConfig.logout.backchannel.getLogoutTokenKey());

TokenVerificationResult backChannelLogoutTokenResult = tokens.removeTokenVerification(logoutTokenKeyValue);
if (backChannelLogoutTokenResult == null) {
return false;
}

String idTokenIss = idTokenJson.getString(Claims.iss.name());
String logoutTokenIss = backChannelLogoutTokenResult.localVerificationResult.getString(Claims.iss.name());
if (logoutTokenIss != null && !logoutTokenIss.equals(idTokenIss)) {
Expand All @@ -412,7 +421,7 @@ private boolean isBackChannelLogoutPendingAndValid(TenantConfigContext configCon
LOG.debugf("Logout token session id does not match the ID token session id");
return false;
}
LOG.debugf("Frontchannel logout request for the tenant %s has been completed",
LOG.debugf("Backchannel logout request for the tenant %s has been completed",
configContext.oidcConfig.tenantId.get());

fireEvent(SecurityEvent.Type.OIDC_BACKCHANNEL_LOGOUT_COMPLETED, identity);
Expand Down Expand Up @@ -495,7 +504,7 @@ public Uni<ChallengeData> apply(Void t) {

if (context.get(NO_OIDC_COOKIES_AVAILABLE) != null
&& isRedirectFromProvider(context, configContext)) {
LOG.debug(
LOG.warn(
"The state cookie is missing after the redirect from OpenId Connect Provider, authentication has failed");
return Uni.createFrom().item(new ChallengeData(401, "WWW-Authenticate", "OIDC"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public class DefaultTenantConfigResolver {

private volatile boolean securityEventObserved;

private ConcurrentHashMap<String, TokenVerificationResult> backChannelLogoutTokens = new ConcurrentHashMap<>();
private ConcurrentHashMap<String, BackChannelLogoutTokenCache> backChannelLogoutTokens = new ConcurrentHashMap<>();

@PostConstruct
public void verifyResolvers() {
Expand Down Expand Up @@ -232,7 +232,7 @@ boolean isEnableHttpForwardedPrefix() {
return enableHttpForwardedPrefix;
}

public Map<String, TokenVerificationResult> getBackChannelLogoutTokens() {
public Map<String, BackChannelLogoutTokenCache> getBackChannelLogoutTokens() {
return backChannelLogoutTokens;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntUnaryOperator;

import io.quarkus.oidc.OidcRequestContext;
import io.quarkus.oidc.OidcTenantConfig;
Expand Down Expand Up @@ -169,18 +168,4 @@ public CacheEntry(UserInfo userInfo) {
this.userInfo = userInfo;
}
}

private static class IncrementOperator implements IntUnaryOperator {
int maxSize;

IncrementOperator(int maxSize) {
this.maxSize = maxSize;
}

@Override
public int applyAsInt(int n) {
return n < maxSize ? n + 1 : n;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,25 @@ public void testCodeFlowFormPostAndBackChannelLogout() throws IOException {
// Session is still active
assertNotNull(getSessionCookie(webClient, "code-flow-form-post"));

// request a back channel logout
// ID token subject is `123456`
// request a back channel logout for some other subject
RestAssured.given()
.when().contentType(ContentType.URLENC).body("logout_token=" + OidcWiremockTestResource.getLogoutToken())
.when().contentType(ContentType.URLENC)
.body("logout_token=" + OidcWiremockTestResource.getLogoutToken("789"))
.post("/back-channel-logout")
.then()
.statusCode(200);

// No logout:
page = webClient.getPage("http://localhost:8081/code-flow-form-post");
assertEquals("alice", page.getBody().asNormalizedText());
// Session is still active
assertNotNull(getSessionCookie(webClient, "code-flow-form-post"));

// request a back channel logout for the same subject
RestAssured.given()
.when().contentType(ContentType.URLENC).body("logout_token="
+ OidcWiremockTestResource.getLogoutToken("123456"))
.post("/back-channel-logout")
.then()
.statusCode(200);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,17 @@ public static String getLogoutToken() {
.sign("privateKey.jwk");
}

public static String getLogoutToken(String sub) {
return Jwt.issuer(TOKEN_ISSUER)
.audience(TOKEN_AUDIENCE)
.subject(sub)
.claim("events", createEventsClaim())
.claim("sid", "session-id")
.jws()
.keyId("1")
.sign("privateKey.jwk");
}

private static JsonObject createEventsClaim() {
return Json.createObjectBuilder().add("http://schemas.openid.net/event/backchannel-logout",
Json.createObjectBuilder().build()).build();
Expand Down

0 comments on commit ac62790

Please sign in to comment.