Skip to content

Commit

Permalink
Added support for multi-tenant authentication for Key Vault clients. (#…
Browse files Browse the repository at this point in the history
…25300)

* Added support for multi-tenant authentication on Key Vault libraries.

* Updated CHANGELOG.

* Removed RuntimeException being thrown from KeyVaultCredentialPolicy.

* Added tests and fixed an issue that grabbed the wrong segment when parsing an authorization URI.

* Fixed test issues.

* Fixed test issues for good.
  • Loading branch information
vcolin7 authored Nov 12, 2021
1 parent b03f3da commit 60f92e8
Show file tree
Hide file tree
Showing 21 changed files with 1,001 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Added support for multi-tenant authentication in clients.

### Breaking Changes

### Bugs Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.Collections;
Expand All @@ -34,8 +36,8 @@ public class KeyVaultCredentialPolicy extends BearerTokenAuthenticationPolicy {
private static final String KEY_VAULT_STASHED_CONTENT_KEY = "KeyVaultCredentialPolicyStashedBody";
private static final String KEY_VAULT_STASHED_CONTENT_LENGTH_KEY = "KeyVaultCredentialPolicyStashedContentLength";
private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
private static final ConcurrentMap<String, String> SCOPE_CACHE = new ConcurrentHashMap<>();
private String scope;
private static final ConcurrentMap<String, ChallengeParameters> CHALLENGE_CACHE = new ConcurrentHashMap<>();
private ChallengeParameters challenge;

/**
* Creates a {@link KeyVaultCredentialPolicy}.
Expand Down Expand Up @@ -80,6 +82,7 @@ private static Map<String, String> extractChallengeAttributes(String authenticat
*
* @param authenticateHeader The authentication header containing all the challenges.
* @param authChallengePrefix The authentication challenge name.
*
* @return A boolean indicating if the challenge is a bearer challenge or not.
*/
private static boolean isBearerChallenge(String authenticateHeader, String authChallengePrefix) {
Expand All @@ -92,15 +95,17 @@ public Mono<Void> authorizeRequest(HttpPipelineCallContext context) {
return Mono.defer(() -> {
HttpRequest request = context.getHttpRequest();

// If this policy doesn't have an authorityScope cached try to get it from the static challenge cache.
if (this.scope == null) {
// If this policy doesn't have challenge parameters cached try to get it from the static challenge cache.
if (this.challenge == null) {
String authority = getRequestAuthority(request);
this.scope = SCOPE_CACHE.get(authority);
this.challenge = CHALLENGE_CACHE.get(authority);
}

if (this.scope != null) {
// We fetched the scope from the cache, but we have not initialized the scopes in the base yet.
TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);
if (this.challenge != null) {
// We fetched the challenge from the cache, but we have not initialized the scopes in the base yet.
TokenRequestContext tokenRequestContext = new TokenRequestContext()
.addScopes(this.challenge.getScopes())
.setTenantId(this.challenge.getTenantId());

return setAuthorizationHeader(context, tokenRequestContext);
}
Expand Down Expand Up @@ -150,33 +155,92 @@ public Mono<Boolean> authorizeRequestOnChallenge(HttpPipelineCallContext context
}

if (scope == null) {
this.scope = SCOPE_CACHE.get(authority);
this.challenge = CHALLENGE_CACHE.get(authority);

if (this.scope == null) {
if (this.challenge == null) {
return Mono.just(false);
}
} else {
this.scope = scope;
String authorization = challengeAttributes.get("authorization");

if (authorization == null) {
authorization = challengeAttributes.get("authorization_uri");
}

SCOPE_CACHE.put(authority, this.scope);
final URI authorizationUri;

try {
authorizationUri = new URI(authorization);
} catch (URISyntaxException e) {
// The challenge authorization URI is invalid.
return Mono.just(false);
}

this.challenge = new ChallengeParameters(authorizationUri, new String[] { scope });

CHALLENGE_CACHE.put(authority, this.challenge);
}

TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);
TokenRequestContext tokenRequestContext = new TokenRequestContext()
.addScopes(this.challenge.getScopes())
.setTenantId(this.challenge.getTenantId());

return setAuthorizationHeader(context, tokenRequestContext)
.then(Mono.just(true));
});
}

static void clearCache() {
SCOPE_CACHE.clear();
private static class ChallengeParameters {
private final URI authorizationUri;
private final String tenantId;
private final String[] scopes;

ChallengeParameters(URI authorizationUri, String[] scopes) {
this.authorizationUri = authorizationUri;
tenantId = authorizationUri.getPath().split("/")[1];
this.scopes = scopes;
}

/**
* Get the {@code authorization} or {@code authorization_uri} parameter from the challenge response.
*/
public URI getAuthorizationUri() {
return authorizationUri;
}

/**
* Get the {@code resource} or {@code scope} parameter from the challenge response. This should end with
* "/.default".
*/
public String[] getScopes() {
return scopes;
}

/**
* Get the tenant ID from {@code authorizationUri}.
*/
public String getTenantId() {
return tenantId;
}
}

public static void clearCache() {
CHALLENGE_CACHE.clear();
}

/**
* Gets the host name and port of the Key Vault or Managed HSM endpoint.
*
* @param request The {@link HttpRequest} to extract the host name and port from.
*
* @return The host name and port of the Key Vault or Managed HSM endpoint.
*/
private static String getRequestAuthority(HttpRequest request) {
URL url = request.getUrl();
String authority = url.getAuthority();
int port = url.getPort();

// Append port for complete authority.
if (!authority.contains(":") && port > 0) {
authority = authority + ":" + port;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Added support for multi-tenant authentication in clients.

### Breaking Changes

### Bugs Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.Collections;
Expand All @@ -34,8 +36,8 @@ public class KeyVaultCredentialPolicy extends BearerTokenAuthenticationPolicy {
private static final String KEY_VAULT_STASHED_CONTENT_KEY = "KeyVaultCredentialPolicyStashedBody";
private static final String KEY_VAULT_STASHED_CONTENT_LENGTH_KEY = "KeyVaultCredentialPolicyStashedContentLength";
private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
private static final ConcurrentMap<String, String> SCOPE_CACHE = new ConcurrentHashMap<>();
private String scope;
private static final ConcurrentMap<String, ChallengeParameters> CHALLENGE_CACHE = new ConcurrentHashMap<>();
private ChallengeParameters challenge;

/**
* Creates a {@link KeyVaultCredentialPolicy}.
Expand Down Expand Up @@ -80,6 +82,7 @@ private static Map<String, String> extractChallengeAttributes(String authenticat
*
* @param authenticateHeader The authentication header containing all the challenges.
* @param authChallengePrefix The authentication challenge name.
*
* @return A boolean indicating if the challenge is a bearer challenge or not.
*/
private static boolean isBearerChallenge(String authenticateHeader, String authChallengePrefix) {
Expand All @@ -92,15 +95,17 @@ public Mono<Void> authorizeRequest(HttpPipelineCallContext context) {
return Mono.defer(() -> {
HttpRequest request = context.getHttpRequest();

// If this policy doesn't have an authorityScope cached try to get it from the static challenge cache.
if (this.scope == null) {
// If this policy doesn't have challenge parameters cached try to get it from the static challenge cache.
if (this.challenge == null) {
String authority = getRequestAuthority(request);
this.scope = SCOPE_CACHE.get(authority);
this.challenge = CHALLENGE_CACHE.get(authority);
}

if (this.scope != null) {
// We fetched the scope from the cache, but we have not initialized the scopes in the base yet.
TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);
if (this.challenge != null) {
// We fetched the challenge from the cache, but we have not initialized the scopes in the base yet.
TokenRequestContext tokenRequestContext = new TokenRequestContext()
.addScopes(this.challenge.getScopes())
.setTenantId(this.challenge.getTenantId());

return setAuthorizationHeader(context, tokenRequestContext);
}
Expand Down Expand Up @@ -150,33 +155,92 @@ public Mono<Boolean> authorizeRequestOnChallenge(HttpPipelineCallContext context
}

if (scope == null) {
this.scope = SCOPE_CACHE.get(authority);
this.challenge = CHALLENGE_CACHE.get(authority);

if (this.scope == null) {
if (this.challenge == null) {
return Mono.just(false);
}
} else {
this.scope = scope;
String authorization = challengeAttributes.get("authorization");

if (authorization == null) {
authorization = challengeAttributes.get("authorization_uri");
}

SCOPE_CACHE.put(authority, this.scope);
final URI authorizationUri;

try {
authorizationUri = new URI(authorization);
} catch (URISyntaxException e) {
// The challenge authorization URI is invalid.
return Mono.just(false);
}

this.challenge = new ChallengeParameters(authorizationUri, new String[] { scope });

CHALLENGE_CACHE.put(authority, this.challenge);
}

TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);
TokenRequestContext tokenRequestContext = new TokenRequestContext()
.addScopes(this.challenge.getScopes())
.setTenantId(this.challenge.getTenantId());

return setAuthorizationHeader(context, tokenRequestContext)
.then(Mono.just(true));
});
}

static void clearCache() {
SCOPE_CACHE.clear();
private static class ChallengeParameters {
private final URI authorizationUri;
private final String tenantId;
private final String[] scopes;

ChallengeParameters(URI authorizationUri, String[] scopes) {
this.authorizationUri = authorizationUri;
tenantId = authorizationUri.getPath().split("/")[1];
this.scopes = scopes;
}

/**
* Get the {@code authorization} or {@code authorization_uri} parameter from the challenge response.
*/
public URI getAuthorizationUri() {
return authorizationUri;
}

/**
* Get the {@code resource} or {@code scope} parameter from the challenge response. This should end with
* "/.default".
*/
public String[] getScopes() {
return scopes;
}

/**
* Get the tenant ID from {@code authorizationUri}.
*/
public String getTenantId() {
return tenantId;
}
}

public static void clearCache() {
CHALLENGE_CACHE.clear();
}

/**
* Gets the host name and port of the Key Vault or Managed HSM endpoint.
*
* @param request The {@link HttpRequest} to extract the host name and port from.
*
* @return The host name and port of the Key Vault or Managed HSM endpoint.
*/
private static String getRequestAuthority(HttpRequest request) {
URL url = request.getUrl();
String authority = url.getAuthority();
int port = url.getPort();

// Append port for complete authority.
if (!authority.contains(":") && port > 0) {
authority = authority + ":" + port;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.azure.core.util.Context;
import com.azure.core.util.polling.PollResponse;
import com.azure.core.util.polling.SyncPoller;
import com.azure.security.keyvault.certificates.implementation.KeyVaultCredentialPolicy;
import com.azure.security.keyvault.certificates.models.CertificateContact;
import com.azure.security.keyvault.certificates.models.CertificateIssuer;
import com.azure.security.keyvault.certificates.models.CertificateContentType;
Expand All @@ -35,6 +36,8 @@
import java.util.HashMap;
import java.util.Arrays;
import java.util.HashSet;
import java.util.UUID;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand All @@ -52,7 +55,12 @@ protected void beforeTest() {
}

private void createCertificateClient(HttpClient httpClient, CertificateServiceVersion serviceVersion) {
HttpPipeline httpPipeline = getHttpPipeline(httpClient);
createCertificateClient(httpClient, serviceVersion, null);
}

private void createCertificateClient(HttpClient httpClient, CertificateServiceVersion serviceVersion,
String testTenantId) {
HttpPipeline httpPipeline = getHttpPipeline(httpClient, testTenantId);
CertificateAsyncClient asyncClient = spy(new CertificateClientBuilder()
.vaultUrl(getEndpoint())
.pipeline(httpPipeline)
Expand Down Expand Up @@ -81,6 +89,31 @@ public void createCertificate(HttpClient httpClient, CertificateServiceVersion s
});
}

@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("getTestParameters")
public void createCertificateWithMultipleTenants(HttpClient httpClient, CertificateServiceVersion serviceVersion) {
createCertificateClient(httpClient, serviceVersion, testResourceNamer.randomUuid());
createCertificateRunner((policy) -> {
String certName = generateResourceId("testCer");
SyncPoller<CertificateOperation, KeyVaultCertificateWithPolicy> certPoller =
client.beginCreateCertificate(certName, policy);
certPoller.waitForCompletion();
KeyVaultCertificateWithPolicy expected = certPoller.getFinalResult();
assertEquals(certName, expected.getName());
assertNotNull(expected.getProperties().getCreatedOn());
});
KeyVaultCredentialPolicy.clearCache(); // Ensure we don't have anything cached and try again.
createCertificateRunner((policy) -> {
String certName = generateResourceId("testCer2");
SyncPoller<CertificateOperation, KeyVaultCertificateWithPolicy> certPoller =
client.beginCreateCertificate(certName, policy);
certPoller.waitForCompletion();
KeyVaultCertificateWithPolicy expected = certPoller.getFinalResult();
assertEquals(certName, expected.getName());
assertNotNull(expected.getProperties().getCreatedOn());
});
}

private void deleteAndPurgeCertificate(String certName) {
SyncPoller<DeletedCertificate, Void> deletePoller = client.beginDeleteCertificate(certName);
deletePoller.poll();
Expand Down
Loading

0 comments on commit 60f92e8

Please sign in to comment.