Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass the client and tenant id to OIDC request filters #39328

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
public class OidcClientImpl implements OidcClient {

private static final Logger LOG = Logger.getLogger(OidcClientImpl.class);

private static final String CLIENT_ID_ATTRIBUTE = "client-id";
private static final String DEFAULT_OIDC_CLIENT_ID = "Default";
private static final String AUTHORIZATION_HEADER = String.valueOf(HttpHeaders.AUTHORIZATION);

private final WebClient client;
Expand Down Expand Up @@ -279,7 +280,8 @@ private void checkClosed() {

private HttpRequest<Buffer> filter(OidcEndpoint.Type endpointType, HttpRequest<Buffer> request, Buffer body) {
if (!filters.isEmpty()) {
OidcRequestContextProperties props = new OidcRequestContextProperties();
OidcRequestContextProperties props = new OidcRequestContextProperties(
Map.of(CLIENT_ID_ATTRIBUTE, oidcConfig.getId().orElse(DEFAULT_OIDC_CLIENT_ID)));
for (OidcRequestFilter filter : OidcCommonUtils.getMatchingOidcRequestFilters(filters, endpointType)) {
filter.filter(request, body, props);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.quarkus.oidc.client.OidcClients;
import io.quarkus.oidc.client.Tokens;
import io.quarkus.oidc.common.OidcEndpoint;
import io.quarkus.oidc.common.OidcRequestContextProperties;
import io.quarkus.oidc.common.OidcRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
import io.quarkus.oidc.common.runtime.OidcConstants;
Expand All @@ -35,6 +36,7 @@
public class OidcClientRecorder {

private static final Logger LOG = Logger.getLogger(OidcClientRecorder.class);
private static final String CLIENT_ID_ATTRIBUTE = "client-id";
private static final String DEFAULT_OIDC_CLIENT_ID = "Default";

public OidcClients setup(OidcClientsConfig oidcClientsConfig, TlsConfig tlsConfig, Supplier<Vertx> vertx) {
Expand Down Expand Up @@ -224,8 +226,10 @@ private static Uni<OidcConfigurationMetadata> discoverTokenUris(WebClient client
Map<OidcEndpoint.Type, List<OidcRequestFilter>> oidcRequestFilters,
String authServerUrl, OidcClientConfig oidcConfig, io.vertx.mutiny.core.Vertx vertx) {
final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig);
OidcRequestContextProperties contextProps = new OidcRequestContextProperties(
Map.of(CLIENT_ID_ATTRIBUTE, oidcConfig.getId().orElse(DEFAULT_OIDC_CLIENT_ID)));
return OidcCommonUtils
.discoverMetadata(client, oidcRequestFilters, authServerUrl, connectionDelayInMillisecs, vertx,
.discoverMetadata(client, oidcRequestFilters, contextProps, authServerUrl, connectionDelayInMillisecs, vertx,
oidcConfig.useBlockingDnsLookup)
.onItem().transform(json -> new OidcConfigurationMetadata(json.getString("token_endpoint"),
json.getString("revocation_endpoint")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,15 @@ public static Predicate<? super Throwable> oidcEndpointNotAvailable() {
}

public static Uni<JsonObject> discoverMetadata(WebClient client, Map<OidcEndpoint.Type, List<OidcRequestFilter>> filters,
String authServerUrl, long connectionDelayInMillisecs, Vertx vertx, boolean blockingDnsLookup) {
OidcRequestContextProperties contextProperties, String authServerUrl,
long connectionDelayInMillisecs, Vertx vertx, boolean blockingDnsLookup) {
final String discoveryUrl = getDiscoveryUri(authServerUrl);
HttpRequest<Buffer> request = client.getAbs(discoveryUrl);
if (!filters.isEmpty()) {
OidcRequestContextProperties requestProps = new OidcRequestContextProperties(
Map.of(OidcRequestContextProperties.DISCOVERY_ENDPOINT, discoveryUrl));
Map<String, Object> newProperties = contextProperties == null ? new HashMap<>()
: new HashMap<>(contextProperties.getAll());
newProperties.put(OidcRequestContextProperties.DISCOVERY_ENDPOINT, discoveryUrl);
OidcRequestContextProperties requestProps = new OidcRequestContextProperties(newProperties);
for (OidcRequestFilter filter : getMatchingOidcRequestFilters(filters, OidcEndpoint.Type.DISCOVERY)) {
filter.filter(request, null, requestProps);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
public class OidcProviderClient implements Closeable {
private static final Logger LOG = Logger.getLogger(OidcProviderClient.class);

private static final String TENANT_ID_ATTRIBUTE = "oidc-tenant-id";
private static final String AUTHORIZATION_HEADER = String.valueOf(HttpHeaders.AUTHORIZATION);
private static final String CONTENT_TYPE_HEADER = String.valueOf(HttpHeaders.CONTENT_TYPE);
private static final String ACCEPT_HEADER = String.valueOf(HttpHeaders.ACCEPT);
Expand Down Expand Up @@ -265,6 +266,7 @@ private HttpRequest<Buffer> filter(OidcEndpoint.Type endpointType, HttpRequest<B
if (!filters.isEmpty()) {
Map<String, Object> newProperties = contextProperties == null ? new HashMap<>()
: new HashMap<>(contextProperties.getAll());
newProperties.put(OidcUtils.TENANT_ID_ATTRIBUTE, oidcConfig.getTenantId().orElse(OidcUtils.DEFAULT_TENANT_ID));
newProperties.put(OidcConfigurationMetadata.class.getName(), metadata);
OidcRequestContextProperties newContextProperties = new OidcRequestContextProperties(newProperties);
for (OidcRequestFilter filter : OidcCommonUtils.getMatchingOidcRequestFilters(filters, endpointType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.quarkus.oidc.TenantConfigResolver;
import io.quarkus.oidc.TenantIdentityProvider;
import io.quarkus.oidc.common.OidcEndpoint;
import io.quarkus.oidc.common.OidcRequestContextProperties;
import io.quarkus.oidc.common.OidcRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonConfig;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
Expand Down Expand Up @@ -487,8 +488,11 @@ protected static Uni<OidcProviderClient> createOidcClientUni(OidcTenantConfig oi
metadataUni = Uni.createFrom().item(createLocalMetadata(oidcConfig, authServerUriString));
} else {
final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig);
OidcRequestContextProperties contextProps = new OidcRequestContextProperties(
Map.of(OidcUtils.TENANT_ID_ATTRIBUTE, oidcConfig.getTenantId().orElse(OidcUtils.DEFAULT_TENANT_ID)));
metadataUni = OidcCommonUtils
.discoverMetadata(client, oidcRequestFilters, authServerUriString, connectionDelayInMillisecs, mutinyVertx,
.discoverMetadata(client, oidcRequestFilters, contextProps, authServerUriString, connectionDelayInMillisecs,
mutinyVertx,
oidcConfig.useBlockingDnsLookup)
.onItem()
.transform(new Function<JsonObject, OidcConfigurationMetadata>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class OidcRequestCustomizer implements OidcRequestFilter {
public void filter(HttpRequest<Buffer> request, Buffer buffer, OidcRequestContextProperties contextProps) {
String uri = request.uri();
if (uri.endsWith("/non-standard-tokens")) {
request.putHeader("client-id", contextProps.getString("client-id"));
request.putHeader("GrantType", getGrantType(buffer.toString()));
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkus.it.keycloak;

import static com.github.tomakehurst.wiremock.client.WireMock.containing;
import static com.github.tomakehurst.wiremock.client.WireMock.matching;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;

Expand Down Expand Up @@ -54,6 +55,7 @@ public Map<String, String> start() {
server.stubFor(WireMock.post("/non-standard-tokens")
.withHeader("X-Custom", matching("XCustomHeaderValue"))
.withHeader("GrantType", matching("password"))
.withHeader("client-id", containing("non-standard-response"))
.withRequestBody(matching("grant_type=password&username=alice&password=alice&extra_param=extra_param_value"))
.willReturn(WireMock
.aResponse()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.quarkus.oidc.common.OidcEndpoint.Type;
import io.quarkus.oidc.common.OidcRequestContextProperties;
import io.quarkus.oidc.common.OidcRequestFilter;
import io.quarkus.oidc.runtime.OidcUtils;
import io.vertx.mutiny.core.buffer.Buffer;
import io.vertx.mutiny.ext.web.client.HttpRequest;

Expand All @@ -23,6 +24,7 @@ public void filter(HttpRequest<Buffer> request, Buffer buffer, OidcRequestContex
throw new OIDCException("Filter is applied to the wrong endpoint: " + request.uri());
}
request.putHeader("Filter", "OK");
request.putHeader(OidcUtils.TENANT_ID_ATTRIBUTE, contextProps.getString(OidcUtils.TENANT_ID_ATTRIBUTE));
}

private boolean isJwksRequest(HttpRequest<Buffer> request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public void testAccessResourceAzure() throws Exception {
String azureJwk = readFile("jwks.json");
wireMockServer.stubFor(WireMock.get("/auth/azure/jwk")
.withHeader("Authorization", matching("Access token: " + azureToken))
.withHeader("Filter", matching("OK"))
.withHeader("tenant-id", matching("bearer-azure"))
.willReturn(WireMock.aResponse().withBody(azureJwk)));
RestAssured.given().auth().oauth2(azureToken)
.when().get("/api/admin/bearer-azure")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package io.quarkus.it.keycloak;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.absent;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.client.WireMock.not;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;

Expand All @@ -26,6 +28,7 @@ public void start() {
server.stubFor(
get(urlEqualTo("/auth/realms/quarkus2/.well-known/openid-configuration"))
.withHeader("Filter", equalTo("OK"))
.withHeader("tenant-id", not(absent()))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("{\n" +
Expand All @@ -36,6 +39,7 @@ public void start() {
server.stubFor(
get(urlEqualTo("/auth/realms/quarkus2/protocol/openid-connect/certs"))
.withHeader("Filter", equalTo("OK"))
.withHeader("tenant-id", not(absent()))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("{\n" +
Expand Down
Loading