Skip to content

Commit

Permalink
Add Redirect annotation for OidcRedirectFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
sberyozkin committed May 23, 2024
1 parent a18c9d7 commit dc82a98
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 34 deletions.
23 changes: 13 additions & 10 deletions docs/src/main/asciidoc/security-oidc-code-flow-authentication.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -475,34 +475,37 @@ import org.eclipse.microprofile.jwt.Claims;
import io.quarkus.arc.Unremovable;
import io.quarkus.oidc.AuthorizationCodeTokens;
import io.quarkus.oidc.OidcRedirectFilter;
import io.quarkus.oidc.Redirect;
import io.quarkus.oidc.TenantFeature;
import io.quarkus.oidc.runtime.OidcUtils;
import io.smallrye.jwt.build.Jwt;
@ApplicationScoped
@Unremovable
@TenantFeature("tenant-refresh")
@Redirect(Location.SESSION_EXPIRED_PAGE) <1>
public class SessionExpiredOidcRedirectFilter implements OidcRedirectFilter {
@Override
public void filter(OidcRedirectContext context) {
if (context.redirectUri().contains("/session-expired-page")) {
AuthorizationCodeTokens tokens = context.routingContext().get(AuthorizationCodeTokens.class.getName()); <1>
String userName = OidcUtils.decodeJwtContent(tokens.getIdToken()).getString(Claims.preferred_username.name()); <2>
String jwe = Jwt.preferredUserName(userName).jwe()
.encryptWithSecret(context.oidcTenantConfig().credentials.secret.get()); <3>
OidcUtils.createCookie(context.routingContext(), context.oidcTenantConfig(), "session_expired",
jwe + "|" + context.oidcTenantConfig().tenantId.get(), 10); <4>
AuthorizationCodeTokens tokens = context.routingContext().get(AuthorizationCodeTokens.class.getName()); <2>
String userName = OidcUtils.decodeJwtContent(tokens.getIdToken()).getString(Claims.preferred_username.name()); <3>
String jwe = Jwt.preferredUserName(userName).jwe()
.encryptWithSecret(context.oidcTenantConfig().credentials.secret.get()); <4>
OidcUtils.createCookie(context.routingContext(), context.oidcTenantConfig(), "session_expired",
jwe + "|" + context.oidcTenantConfig().tenantId.get(), 10); <5>
}
}
}
----
<1> Access `AuthorizationCodeTokens` tokens associated with the now expired session as a `RoutingContext` attribute.
<2> Decode ID token claims and get a user name.
<3> Save the user name in a JWT token encrypted with the current OIDC tenant's client secret.
<4> Create a custom `session_expired` cookie valid for 5 seconds which joins the encrypted token and a tenant id using a "|" separator. Recording a tenant id in a custom cookie can help to generate correct session expired pages in a multi-tenant OIDC setup.
<1> Make sure this redirect filter is only called during a redirect to the session expired page.
<2> Access `AuthorizationCodeTokens` tokens associated with the now expired session as a `RoutingContext` attribute.
<3> Decode ID token claims and get a user name.
<4> Save the user name in a JWT token encrypted with the current OIDC tenant's client secret.
<5> Create a custom `session_expired` cookie valid for 5 seconds which joins the encrypted token and a tenant id using a "|" separator. Recording a tenant id in a custom cookie can help to generate correct session expired pages in a multi-tenant OIDC setup.

Next, a public JAX-RS resource which generates session expired pages can use this cookie to create a page tailored for this user and the corresponding OIDC tenant, for example:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,6 @@ public static List<OidcRequestFilter> getMatchingOidcRequestFilters(Map<OidcEndp
combined.addAll(all);
return combined;
}

}

public static Uni<HttpResponse<Buffer>> sendRequest(io.vertx.core.Vertx vertx, HttpRequest<Buffer> request,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.quarkus.oidc;

import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

import java.lang.annotation.Retention;
import java.lang.annotation.Target;

/**
* Annotation that can be used to restrict {@link OidcRedirectFilter} to specific redirect locations
*/
@Target({ TYPE })
@Retention(RUNTIME)
public @interface Redirect {

enum Location {
ALL,

/**
* Applies to OIDC authorization endpoint
*/
OIDC_AUTHORIZATION,

/**
* Applies to OIDC logout endpoint
*/
OIDC_LOGOUT,

/**
* Applies to the local redirect to a custom error page resource when an authorization code flow
* redirect from OIDC provider to Quarkus returns an error instead of an authorization code
*/
ERROR_PAGE,

/**
* Applies to the local redirect to a custom session expired page resource when
* the current user's session has expired and no longer can be refreshed.
*/
SESSION_EXPIRED_PAGE,

/**
* Applies to the local redirect to the callback resource which is done after successful authorization
* code flow completion in order to drop the code and state parameters from the callback URL.
*/
LOCAL_ENDPOINT_CALLBACK
}

/**
* Identifies one or more redirect locations.
*/
Location[] value() default Location.ALL;
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.quarkus.oidc.OidcTenantConfig;
import io.quarkus.oidc.OidcTenantConfig.Authentication;
import io.quarkus.oidc.OidcTenantConfig.Authentication.ResponseMode;
import io.quarkus.oidc.Redirect;
import io.quarkus.oidc.SecurityEvent;
import io.quarkus.oidc.UserInfo;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
Expand Down Expand Up @@ -230,7 +231,7 @@ public Uni<SecurityIdentity> apply(TenantConfigContext tenantContext) {
String finalErrorUri = errorUri.toString();
LOG.debugf("Error URI: %s", finalErrorUri);
return Uni.createFrom().failure(new AuthenticationRedirectException(
filterRedirect(context, tenantContext, finalErrorUri)));
filterRedirect(context, tenantContext, finalErrorUri, Redirect.Location.ERROR_PAGE)));
}

});
Expand All @@ -247,11 +248,12 @@ public Uni<SecurityIdentity> apply(TenantConfigContext tenantContext) {
}

private static String filterRedirect(RoutingContext context,
TenantConfigContext tenantContext, String redirectUri) {
if (!tenantContext.getOidcRedirectFilters().isEmpty()) {
TenantConfigContext tenantContext, String redirectUri, Redirect.Location location) {
List<OidcRedirectFilter> redirectFilters = tenantContext.getOidcRedirectFilters(location);
if (!redirectFilters.isEmpty()) {
OidcRedirectContext redirectContext = new OidcRedirectContext(context, tenantContext.getOidcTenantConfig(),
redirectUri, MultiMap.caseInsensitiveMultiMap());
for (OidcRedirectFilter filter : tenantContext.getOidcRedirectFilters()) {
for (OidcRedirectFilter filter : redirectFilters) {
filter.filter(redirectContext);
}
MultiMap queries = redirectContext.additionalQueryParams();
Expand Down Expand Up @@ -455,7 +457,7 @@ private Uni<SecurityIdentity> redirectToSessionExpiredPage(RoutingContext contex
LOG.debugf("Session Expired URI: %s", sessionExpiredUri);
return removeSessionCookie(context, configContext.oidcConfig)
.chain(() -> Uni.createFrom().failure(new AuthenticationRedirectException(
filterRedirect(context, configContext, sessionExpiredUri))));
filterRedirect(context, configContext, sessionExpiredUri, Redirect.Location.SESSION_EXPIRED_PAGE))));
}

private static String decryptIdTokenIfEncryptedByProvider(TenantConfigContext resolvedContext, String token) {
Expand Down Expand Up @@ -715,7 +717,8 @@ && isRedirectFromProvider(context, configContext)) {
String authorizationURL = configContext.provider.getMetadata().getAuthorizationUri() + "?"
+ codeFlowParams.toString();

authorizationURL = filterRedirect(context, configContext, authorizationURL);
authorizationURL = filterRedirect(context, configContext, authorizationURL,
Redirect.Location.OIDC_AUTHORIZATION);
LOG.debugf("Code flow redirect to: %s", authorizationURL);

return Uni.createFrom().item(new ChallengeData(HttpResponseStatus.FOUND.code(), HttpHeaders.LOCATION,
Expand Down Expand Up @@ -873,7 +876,8 @@ public SecurityIdentity apply(SecurityIdentity identity) {
LOG.debugf("Removing code flow redirect parameters, final redirect URI: %s",
finalRedirectUri);
throw new AuthenticationRedirectException(
filterRedirect(context, configContext, finalRedirectUri));
filterRedirect(context, configContext, finalRedirectUri,
Redirect.Location.LOCAL_ENDPOINT_CALLBACK));
} else {
return identity;
}
Expand Down Expand Up @@ -1384,7 +1388,8 @@ private Uni<Void> buildLogoutRedirectUriUni(RoutingContext context, TenantConfig
public Void apply(Void t) {
String logoutUri = buildLogoutRedirectUri(configContext, idToken, context);
LOG.debugf("Logout uri: %s", logoutUri);
throw new AuthenticationRedirectException(filterRedirect(context, configContext, logoutUri));
throw new AuthenticationRedirectException(
filterRedirect(context, configContext, logoutUri, Redirect.Location.OIDC_LOGOUT));
}
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
package io.quarkus.oidc.runtime;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import org.jboss.logging.Logger;

import io.quarkus.arc.ClientProxy;
import io.quarkus.oidc.OIDCException;
import io.quarkus.oidc.OidcConfigurationMetadata;
import io.quarkus.oidc.OidcRedirectFilter;
import io.quarkus.oidc.OidcTenantConfig;
import io.quarkus.oidc.Redirect;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
import io.quarkus.runtime.configuration.ConfigurationException;

Expand All @@ -29,7 +34,7 @@ public class TenantConfigContext {
*/
final OidcTenantConfig oidcConfig;

final List<OidcRedirectFilter> redirectFilters;
final Map<Redirect.Location, List<OidcRedirectFilter>> redirectFilters;

/**
* PKCE Secret Key
Expand All @@ -50,7 +55,7 @@ public TenantConfigContext(OidcProvider client, OidcTenantConfig config) {
public TenantConfigContext(OidcProvider client, OidcTenantConfig config, boolean ready) {
this.provider = client;
this.oidcConfig = config;
this.redirectFilters = TenantFeatureFinder.find(config, OidcRedirectFilter.class);
this.redirectFilters = getRedirectFiltersMap(TenantFeatureFinder.find(config, OidcRedirectFilter.class));
this.ready = ready;

boolean isService = OidcUtils.isServiceApp(config);
Expand Down Expand Up @@ -164,10 +169,6 @@ public OidcTenantConfig getOidcTenantConfig() {
return oidcConfig;
}

public List<OidcRedirectFilter> getOidcRedirectFilters() {
return redirectFilters;
}

public OidcConfigurationMetadata getOidcMetadata() {
return provider != null ? provider.getMetadata() : null;
}
Expand All @@ -183,4 +184,37 @@ public SecretKey getStateEncryptionKey() {
public SecretKey getTokenEncSecretKey() {
return tokenEncSecretKey;
}

private static Map<Redirect.Location, List<OidcRedirectFilter>> getRedirectFiltersMap(List<OidcRedirectFilter> filters) {
Map<Redirect.Location, List<OidcRedirectFilter>> map = new HashMap<>();
for (OidcRedirectFilter filter : filters) {
Redirect redirect = ClientProxy.unwrap(filter).getClass().getAnnotation(Redirect.class);
if (redirect != null) {
for (Redirect.Location loc : redirect.value()) {
map.computeIfAbsent(loc, k -> new ArrayList<OidcRedirectFilter>()).add(filter);
}
} else {
map.computeIfAbsent(Redirect.Location.ALL, k -> new ArrayList<OidcRedirectFilter>()).add(filter);
}
}
return map;
}

List<OidcRedirectFilter> getOidcRedirectFilters(Redirect.Location loc) {
List<OidcRedirectFilter> typeSpecific = redirectFilters.get(loc);
List<OidcRedirectFilter> all = redirectFilters.get(Redirect.Location.ALL);
if (typeSpecific == null && all == null) {
return List.of();
}
if (typeSpecific != null && all == null) {
return typeSpecific;
} else if (typeSpecific == null && all != null) {
return all;
} else {
List<OidcRedirectFilter> combined = new ArrayList<>(typeSpecific.size() + all.size());
combined.addAll(typeSpecific);
combined.addAll(all);
return combined;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
import io.quarkus.arc.Unremovable;
import io.quarkus.oidc.AuthorizationCodeTokens;
import io.quarkus.oidc.OidcRedirectFilter;
import io.quarkus.oidc.Redirect;
import io.quarkus.oidc.Redirect.Location;
import io.quarkus.oidc.TenantFeature;
import io.quarkus.oidc.runtime.OidcUtils;
import io.smallrye.jwt.build.Jwt;

@ApplicationScoped
@Unremovable
@TenantFeature("tenant-refresh")
@Redirect(Location.SESSION_EXPIRED_PAGE)
public class SessionExpiredOidcRedirectFilter implements OidcRedirectFilter {

@Override
Expand All @@ -23,16 +26,18 @@ public void filter(OidcRedirectContext context) {
throw new RuntimeException("Invalid tenant id");
}

if (context.redirectUri().contains("/session-expired-page")) {
AuthorizationCodeTokens tokens = context.routingContext().get(AuthorizationCodeTokens.class.getName());
String userName = OidcUtils.decodeJwtContent(tokens.getIdToken()).getString(Claims.preferred_username.name());
String jwe = Jwt.preferredUserName(userName).jwe()
.encryptWithSecret(context.oidcTenantConfig().credentials.secret.get());
OidcUtils.createCookie(context.routingContext(), context.oidcTenantConfig(), "session_expired",
jwe + "|" + context.oidcTenantConfig().tenantId.get(), 10);

context.additionalQueryParams().add("session-expired", "true");
if (!context.redirectUri().contains("/session-expired-page")) {
throw new RuntimeException("Invalid redirect URI");
}

AuthorizationCodeTokens tokens = context.routingContext().get(AuthorizationCodeTokens.class.getName());
String userName = OidcUtils.decodeJwtContent(tokens.getIdToken()).getString(Claims.preferred_username.name());
String jwe = Jwt.preferredUserName(userName).jwe()
.encryptWithSecret(context.oidcTenantConfig().credentials.secret.get());
OidcUtils.createCookie(context.routingContext(), context.oidcTenantConfig(), "session_expired",
jwe + "|" + context.oidcTenantConfig().tenantId.get(), 10);

context.additionalQueryParams().add("session-expired", "true");
}

}

0 comments on commit dc82a98

Please sign in to comment.