From 96e162c965814e4a400d407aa541d7d8202fbf75 Mon Sep 17 00:00:00 2001 From: Sergey Beryozkin Date: Mon, 11 Mar 2024 14:41:18 +0000 Subject: [PATCH] Pass the client and tenant id to OIDC request filters --- .../io/quarkus/oidc/client/runtime/OidcClientImpl.java | 6 ++++-- .../quarkus/oidc/client/runtime/OidcClientRecorder.java | 6 +++++- .../io/quarkus/oidc/common/runtime/OidcCommonUtils.java | 9 ++++++--- .../java/io/quarkus/oidc/runtime/OidcProviderClient.java | 2 ++ .../main/java/io/quarkus/oidc/runtime/OidcRecorder.java | 6 +++++- .../io/quarkus/it/keycloak/OidcRequestCustomizer.java | 1 + .../it/keycloak/KeycloakRealmResourceManager.java | 2 ++ .../it/keycloak/OidcDiscoveryJwksRequestCustomizer.java | 2 ++ .../it/keycloak/BearerTokenAuthorizationTest.java | 2 ++ .../io/quarkus/it/keycloak/WiremockTestResource.java | 4 ++++ 10 files changed, 33 insertions(+), 7 deletions(-) diff --git a/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientImpl.java b/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientImpl.java index 3683eae39d305..665dba33a208e 100644 --- a/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientImpl.java +++ b/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientImpl.java @@ -37,7 +37,8 @@ public class OidcClientImpl implements OidcClient { private static final Logger LOG = Logger.getLogger(OidcClientImpl.class); - + private static final String CLIENT_ID_ATTRIBUTE = "client-id"; + private static final String DEFAULT_OIDC_CLIENT_ID = "Default"; private static final String AUTHORIZATION_HEADER = String.valueOf(HttpHeaders.AUTHORIZATION); private final WebClient client; @@ -279,7 +280,8 @@ private void checkClosed() { private HttpRequest filter(OidcEndpoint.Type endpointType, HttpRequest request, Buffer body) { if (!filters.isEmpty()) { - OidcRequestContextProperties props = new OidcRequestContextProperties(); + OidcRequestContextProperties props = new OidcRequestContextProperties( + Map.of(CLIENT_ID_ATTRIBUTE, oidcConfig.getId().orElse(DEFAULT_OIDC_CLIENT_ID))); for (OidcRequestFilter filter : OidcCommonUtils.getMatchingOidcRequestFilters(filters, endpointType)) { filter.filter(request, body, props); } diff --git a/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientRecorder.java b/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientRecorder.java index 101e88aab1ae8..23a3399214086 100644 --- a/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientRecorder.java +++ b/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientRecorder.java @@ -19,6 +19,7 @@ import io.quarkus.oidc.client.OidcClients; import io.quarkus.oidc.client.Tokens; import io.quarkus.oidc.common.OidcEndpoint; +import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; import io.quarkus.oidc.common.runtime.OidcCommonUtils; import io.quarkus.oidc.common.runtime.OidcConstants; @@ -35,6 +36,7 @@ public class OidcClientRecorder { private static final Logger LOG = Logger.getLogger(OidcClientRecorder.class); + private static final String CLIENT_ID_ATTRIBUTE = "client-id"; private static final String DEFAULT_OIDC_CLIENT_ID = "Default"; public OidcClients setup(OidcClientsConfig oidcClientsConfig, TlsConfig tlsConfig, Supplier vertx) { @@ -224,8 +226,10 @@ private static Uni discoverTokenUris(WebClient client Map> oidcRequestFilters, String authServerUrl, OidcClientConfig oidcConfig, io.vertx.mutiny.core.Vertx vertx) { final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig); + OidcRequestContextProperties contextProps = new OidcRequestContextProperties( + Map.of(CLIENT_ID_ATTRIBUTE, oidcConfig.getId().orElse(DEFAULT_OIDC_CLIENT_ID))); return OidcCommonUtils - .discoverMetadata(client, oidcRequestFilters, authServerUrl, connectionDelayInMillisecs, vertx, + .discoverMetadata(client, oidcRequestFilters, contextProps, authServerUrl, connectionDelayInMillisecs, vertx, oidcConfig.useBlockingDnsLookup) .onItem().transform(json -> new OidcConfigurationMetadata(json.getString("token_endpoint"), json.getString("revocation_endpoint"))); diff --git a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/runtime/OidcCommonUtils.java b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/runtime/OidcCommonUtils.java index e4336d067be90..6997a29ec767c 100644 --- a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/runtime/OidcCommonUtils.java +++ b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/runtime/OidcCommonUtils.java @@ -439,12 +439,15 @@ public static Predicate oidcEndpointNotAvailable() { } public static Uni discoverMetadata(WebClient client, Map> filters, - String authServerUrl, long connectionDelayInMillisecs, Vertx vertx, boolean blockingDnsLookup) { + OidcRequestContextProperties contextProperties, String authServerUrl, + long connectionDelayInMillisecs, Vertx vertx, boolean blockingDnsLookup) { final String discoveryUrl = getDiscoveryUri(authServerUrl); HttpRequest request = client.getAbs(discoveryUrl); if (!filters.isEmpty()) { - OidcRequestContextProperties requestProps = new OidcRequestContextProperties( - Map.of(OidcRequestContextProperties.DISCOVERY_ENDPOINT, discoveryUrl)); + Map newProperties = contextProperties == null ? new HashMap<>() + : new HashMap<>(contextProperties.getAll()); + newProperties.put(OidcRequestContextProperties.DISCOVERY_ENDPOINT, discoveryUrl); + OidcRequestContextProperties requestProps = new OidcRequestContextProperties(newProperties); for (OidcRequestFilter filter : getMatchingOidcRequestFilters(filters, OidcEndpoint.Type.DISCOVERY)) { filter.filter(request, null, requestProps); } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java index 3090c67ade9f2..0879f649943bd 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java @@ -36,6 +36,7 @@ public class OidcProviderClient implements Closeable { private static final Logger LOG = Logger.getLogger(OidcProviderClient.class); + private static final String TENANT_ID_ATTRIBUTE = "oidc-tenant-id"; private static final String AUTHORIZATION_HEADER = String.valueOf(HttpHeaders.AUTHORIZATION); private static final String CONTENT_TYPE_HEADER = String.valueOf(HttpHeaders.CONTENT_TYPE); private static final String ACCEPT_HEADER = String.valueOf(HttpHeaders.ACCEPT); @@ -265,6 +266,7 @@ private HttpRequest filter(OidcEndpoint.Type endpointType, HttpRequest newProperties = contextProperties == null ? new HashMap<>() : new HashMap<>(contextProperties.getAll()); + newProperties.put(OidcUtils.TENANT_ID_ATTRIBUTE, oidcConfig.getTenantId().orElse(OidcUtils.DEFAULT_TENANT_ID)); newProperties.put(OidcConfigurationMetadata.class.getName(), metadata); OidcRequestContextProperties newContextProperties = new OidcRequestContextProperties(newProperties); for (OidcRequestFilter filter : OidcCommonUtils.getMatchingOidcRequestFilters(filters, endpointType)) { diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java index fd11423fdd54e..ee86e20ccb967 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java @@ -36,6 +36,7 @@ import io.quarkus.oidc.TenantConfigResolver; import io.quarkus.oidc.TenantIdentityProvider; import io.quarkus.oidc.common.OidcEndpoint; +import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; import io.quarkus.oidc.common.runtime.OidcCommonConfig; import io.quarkus.oidc.common.runtime.OidcCommonUtils; @@ -487,8 +488,11 @@ protected static Uni createOidcClientUni(OidcTenantConfig oi metadataUni = Uni.createFrom().item(createLocalMetadata(oidcConfig, authServerUriString)); } else { final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig); + OidcRequestContextProperties contextProps = new OidcRequestContextProperties( + Map.of(OidcUtils.TENANT_ID_ATTRIBUTE, oidcConfig.getTenantId().orElse(OidcUtils.DEFAULT_TENANT_ID))); metadataUni = OidcCommonUtils - .discoverMetadata(client, oidcRequestFilters, authServerUriString, connectionDelayInMillisecs, mutinyVertx, + .discoverMetadata(client, oidcRequestFilters, contextProps, authServerUriString, connectionDelayInMillisecs, + mutinyVertx, oidcConfig.useBlockingDnsLookup) .onItem() .transform(new Function() { diff --git a/integration-tests/oidc-client-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java b/integration-tests/oidc-client-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java index 513b1584b3a91..14b439b510934 100644 --- a/integration-tests/oidc-client-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java +++ b/integration-tests/oidc-client-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java @@ -19,6 +19,7 @@ public class OidcRequestCustomizer implements OidcRequestFilter { public void filter(HttpRequest request, Buffer buffer, OidcRequestContextProperties contextProps) { String uri = request.uri(); if (uri.endsWith("/non-standard-tokens")) { + request.putHeader("client-id", contextProps.getString("client-id")); request.putHeader("GrantType", getGrantType(buffer.toString())); } } diff --git a/integration-tests/oidc-client-wiremock/src/test/java/io/quarkus/it/keycloak/KeycloakRealmResourceManager.java b/integration-tests/oidc-client-wiremock/src/test/java/io/quarkus/it/keycloak/KeycloakRealmResourceManager.java index 1d7d2927fbf25..627e365ed55ef 100644 --- a/integration-tests/oidc-client-wiremock/src/test/java/io/quarkus/it/keycloak/KeycloakRealmResourceManager.java +++ b/integration-tests/oidc-client-wiremock/src/test/java/io/quarkus/it/keycloak/KeycloakRealmResourceManager.java @@ -1,5 +1,6 @@ package io.quarkus.it.keycloak; +import static com.github.tomakehurst.wiremock.client.WireMock.containing; import static com.github.tomakehurst.wiremock.client.WireMock.matching; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; @@ -54,6 +55,7 @@ public Map start() { server.stubFor(WireMock.post("/non-standard-tokens") .withHeader("X-Custom", matching("XCustomHeaderValue")) .withHeader("GrantType", matching("password")) + .withHeader("client-id", containing("non-standard-response")) .withRequestBody(matching("grant_type=password&username=alice&password=alice&extra_param=extra_param_value")) .willReturn(WireMock .aResponse() diff --git a/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcDiscoveryJwksRequestCustomizer.java b/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcDiscoveryJwksRequestCustomizer.java index 80cc1d4580cee..d5c1464191ca3 100644 --- a/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcDiscoveryJwksRequestCustomizer.java +++ b/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcDiscoveryJwksRequestCustomizer.java @@ -8,6 +8,7 @@ import io.quarkus.oidc.common.OidcEndpoint.Type; import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; +import io.quarkus.oidc.runtime.OidcUtils; import io.vertx.mutiny.core.buffer.Buffer; import io.vertx.mutiny.ext.web.client.HttpRequest; @@ -23,6 +24,7 @@ public void filter(HttpRequest request, Buffer buffer, OidcRequestContex throw new OIDCException("Filter is applied to the wrong endpoint: " + request.uri()); } request.putHeader("Filter", "OK"); + request.putHeader(OidcUtils.TENANT_ID_ATTRIBUTE, contextProps.getString(OidcUtils.TENANT_ID_ATTRIBUTE)); } private boolean isJwksRequest(HttpRequest request) { diff --git a/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/BearerTokenAuthorizationTest.java b/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/BearerTokenAuthorizationTest.java index bee2e7e42f83f..d95361d301e6c 100644 --- a/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/BearerTokenAuthorizationTest.java +++ b/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/BearerTokenAuthorizationTest.java @@ -64,6 +64,8 @@ public void testAccessResourceAzure() throws Exception { String azureJwk = readFile("jwks.json"); wireMockServer.stubFor(WireMock.get("/auth/azure/jwk") .withHeader("Authorization", matching("Access token: " + azureToken)) + .withHeader("Filter", matching("OK")) + .withHeader("tenant-id", matching("bearer-azure")) .willReturn(WireMock.aResponse().withBody(azureJwk))); RestAssured.given().auth().oauth2(azureToken) .when().get("/api/admin/bearer-azure") diff --git a/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/WiremockTestResource.java b/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/WiremockTestResource.java index 5f9d4898dccee..0ea4f558370de 100644 --- a/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/WiremockTestResource.java +++ b/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/WiremockTestResource.java @@ -1,8 +1,10 @@ package io.quarkus.it.keycloak; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.absent; import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.not; import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; @@ -26,6 +28,7 @@ public void start() { server.stubFor( get(urlEqualTo("/auth/realms/quarkus2/.well-known/openid-configuration")) .withHeader("Filter", equalTo("OK")) + .withHeader("tenant-id", not(absent())) .willReturn(aResponse() .withHeader("Content-Type", "application/json") .withBody("{\n" + @@ -36,6 +39,7 @@ public void start() { server.stubFor( get(urlEqualTo("/auth/realms/quarkus2/protocol/openid-connect/certs")) .withHeader("Filter", equalTo("OK")) + .withHeader("tenant-id", not(absent())) .willReturn(aResponse() .withHeader("Content-Type", "application/json") .withBody("{\n" +