Skip to content

Commit

Permalink
Update Token Exchange Auth Flow (#23946)
Browse files Browse the repository at this point in the history
  • Loading branch information
g2vinay authored Sep 13, 2021
1 parent 36f2e56 commit 3b57be8
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 35 deletions.
3 changes: 2 additions & 1 deletion sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Release History

## 1.4.0-beta.1 (2021-08-16)
## 1.4.0-beta.1 (2021-09-13)
### Features Added

- Added support to `ManagedIdentityCredential` for Bridge to Kubernetes local development authentication.
Expand All @@ -11,6 +11,7 @@
- A region can also be specified through the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable.
- Added `loginHint()` setter to `InteractiveBrowserCredentialBuilder` which allows a username to be pre-selected for interactive logins.
- Added support to consume `TenantId` challenges from `TokenRequestContext`.
- Added support for AKS Token Exchange support in `ManagedIdentityCredential`


## 1.3.6 (2021-09-08)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenRequestContext;
import com.azure.core.util.logging.ClientLogger;
import com.azure.identity.implementation.IdentityClient;

import reactor.core.publisher.Mono;
Expand All @@ -13,6 +14,7 @@
* Authenticates a service principal with AAD using a client assertion.
*/
class ClientAssertionCredential extends ManagedIdentityServiceCredential {
private final ClientLogger logger = new ClientLogger(ClientAssertionCredential.class);

/**
* Creates an instance of ClientAssertionCredential.
Expand All @@ -26,6 +28,11 @@ class ClientAssertionCredential extends ManagedIdentityServiceCredential {

@Override
public Mono<AccessToken> authenticate(TokenRequestContext request) {
if (this.getClientId() == null) {
return Mono.error(logger.logExceptionAsError(new IllegalStateException("The client id is not configured via"
+ " 'AZURE_CLIENT_ID' environment variable or through the credential builder."
+ " Please ensure client id is provided to authenticate via token exchange in AKS environment.")));
}
return identityClient.authenticatewithExchangeToken(request);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.azure.identity.implementation.util.LoggingUtil;
import reactor.core.publisher.Mono;

import java.time.Duration;

/**
* The base class for Managed Service Identity token based credentials.
*/
Expand All @@ -24,7 +26,7 @@ public final class ManagedIdentityCredential implements TokenCredential {

static final String PROPERTY_IMDS_ENDPOINT = "IMDS_ENDPOINT";
static final String PROPERTY_IDENTITY_SERVER_THUMBPRINT = "IDENTITY_SERVER_THUMBPRINT";
static final String TOKEN_FILE_PATH = "TOKEN_FILE_PATH";
static final String AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE";


/**
Expand Down Expand Up @@ -53,13 +55,15 @@ public final class ManagedIdentityCredential implements TokenCredential {
} else {
managedIdentityServiceCredential = new VirtualMachineMsiCredential(clientId, clientBuilder.build());
}
} else if (configuration.contains(Configuration.PROPERTY_AZURE_CLIENT_ID)
&& configuration.contains(Configuration.PROPERTY_AZURE_TENANT_ID)
&& configuration.get(TOKEN_FILE_PATH) != null) {
} else if (configuration.contains(Configuration.PROPERTY_AZURE_TENANT_ID)
&& configuration.get(AZURE_FEDERATED_TOKEN_FILE) != null) {
String clientIdentifier = clientId == null
? configuration.get(Configuration.PROPERTY_AZURE_CLIENT_ID) : clientId;
clientBuilder.clientId(clientIdentifier);
clientBuilder.tenantId(configuration.get(Configuration.PROPERTY_AZURE_TENANT_ID));
clientBuilder.clientAssertionPath(configuration.get(TOKEN_FILE_PATH));
managedIdentityServiceCredential = new ClientAssertionCredential(clientId, clientBuilder.build());

clientBuilder.clientAssertionPath(configuration.get(AZURE_FEDERATED_TOKEN_FILE));
clientBuilder.clientAssertionTimeout(Duration.ofMinutes(5));
managedIdentityServiceCredential = new ClientAssertionCredential(clientIdentifier, clientBuilder.build());
} else {
managedIdentityServiceCredential = new VirtualMachineMsiCredential(clientId, clientBuilder.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import com.azure.identity.implementation.util.LoggingUtil;
import reactor.core.publisher.Mono;

import java.time.Duration;

/**
* An AAD credential that acquires a token with a client secret and user assertion for an AAD application
* on behalf of a user principal.
Expand Down Expand Up @@ -43,7 +41,6 @@ public OnBehalfOfCredential(String clientId, String tenantId, String clientSecre
.certificatePath(certificatePath)
.certificatePassword(certificatePassword)
.identityClientOptions(identityClientOptions)
.confidentialClientCacheTimeout(Duration.ofMinutes(5))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.DataOutputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
Expand Down Expand Up @@ -130,6 +131,8 @@ public class IdentityClient {
private HttpPipelineAdapter httpPipelineAdapter;
private final SynchronizedAccessor<PublicClientApplication> publicClientApplicationAccessor;
private final SynchronizedAccessor<ConfidentialClientApplication> confidentialClientApplicationAccessor;
private final SynchronizedAccessor<String> clientAssertionAccessor;


/**
* Creates an IdentityClient with the given options.
Expand All @@ -142,12 +145,12 @@ public class IdentityClient {
* @param certificatePassword the password protecting the PFX certificate.
* @param isSharedTokenCacheCredential Indicate whether the credential is
* {@link com.azure.identity.SharedTokenCacheCredential} or not.
* @param confidentialClientCacheTimeout the cache time out to use for confidential client.
* @param clientAssertionTimeout the time out to use for the client assertion.
* @param options the options configuring the client.
*/
IdentityClient(String tenantId, String clientId, String clientSecret, String certificatePath,
String clientAssertionFilePath, InputStream certificate, String certificatePassword,
boolean isSharedTokenCacheCredential, Duration confidentialClientCacheTimeout,
boolean isSharedTokenCacheCredential, Duration clientAssertionTimeout,
IdentityClientOptions options) {
if (tenantId == null) {
tenantId = "organizations";
Expand All @@ -167,9 +170,12 @@ public class IdentityClient {
this.publicClientApplicationAccessor = new SynchronizedAccessor<>(() ->
getPublicClientApplication(isSharedTokenCacheCredential));

this.confidentialClientApplicationAccessor = confidentialClientCacheTimeout == null
? new SynchronizedAccessor<>(() -> getConfidentialClientApplication())
: new SynchronizedAccessor<>(() -> getConfidentialClientApplication(), confidentialClientCacheTimeout);
this.confidentialClientApplicationAccessor = new SynchronizedAccessor<>(() ->
getConfidentialClientApplication());

this.clientAssertionAccessor = clientAssertionTimeout == null
? new SynchronizedAccessor<>(() -> parseClientAssertion(), Duration.ofMinutes(5))
: new SynchronizedAccessor<>(() -> parseClientAssertion(), clientAssertionTimeout);
}

private Mono<ConfidentialClientApplication> getConfidentialClientApplication() {
Expand Down Expand Up @@ -211,15 +217,6 @@ private Mono<ConfidentialClientApplication> getConfidentialClientApplication() {
return Mono.error(logger.logExceptionAsError(new RuntimeException(
"Failed to parse the certificate for the credential: " + e.getMessage(), e)));
}
} else if (clientAssertionFilePath != null) {
try {
credential = ClientCredentialFactory
.createFromClientAssertion(parseClientAssertion(clientAssertionFilePath));
} catch (IOException e) {
return Mono.error(logger.logExceptionAsError(new RuntimeException(
"Failed to parse the client assertion from the provided file: " + clientAssertionFilePath
+ ". " + e.getMessage(), e)));
}
} else {
return Mono.error(logger.logExceptionAsError(
new IllegalArgumentException("Must provide client secret or client certificate path")));
Expand Down Expand Up @@ -271,9 +268,19 @@ private Mono<ConfidentialClientApplication> getConfidentialClientApplication() {
});
}

private String parseClientAssertion(String clientAssertionFilePath) throws IOException {
byte[] encoded = Files.readAllBytes(Paths.get(clientAssertionFilePath));
return new String(encoded, StandardCharsets.UTF_8);
private Mono<String> parseClientAssertion() {
return Mono.fromCallable(() -> {
if (clientAssertionFilePath != null) {
byte[] encoded = Files.readAllBytes(Paths.get(clientAssertionFilePath));
return new String(encoded, StandardCharsets.UTF_8);
} else {
throw logger.logExceptionAsError(new IllegalStateException(
"Client Assertion File Path is not provided."
+ " It should be provided to authenticate with client assertion."
));
}

});
}

private Mono<PublicClientApplication> getPublicClientApplication(boolean sharedTokenCacheCredential) {
Expand Down Expand Up @@ -1038,7 +1045,52 @@ public Mono<AccessToken> authenticateToArcManagedIdentityEndpoint(String identit
* @return a Publisher that emits an AccessToken
*/
public Mono<AccessToken> authenticatewithExchangeToken(TokenRequestContext request) {
return authenticateWithConfidentialClient(request);

return clientAssertionAccessor.getValue()
.flatMap(assertionToken -> Mono.fromCallable(() -> {
String authorityUrl = options.getAuthorityHost().replaceAll("/+$", "")
+ "/" + tenantId + "/oauth2/v2.0/token";

StringBuilder urlParametersBuilder = new StringBuilder();
urlParametersBuilder.append("client_assertion=");
urlParametersBuilder.append(assertionToken);
urlParametersBuilder.append("&client_assertion_type=urn:ietf:params:oauth:client-assertion-type"
+ ":jwt-bearer");
urlParametersBuilder.append("&client_id=");
urlParametersBuilder.append(clientId);
urlParametersBuilder.append("&grant_type=client_credentials");
urlParametersBuilder.append("&scope=");
urlParametersBuilder.append(URLEncoder.encode(request.getScopes().get(0), "UTF-8"));

String urlParams = urlParametersBuilder.toString();

byte[] postData = urlParams.getBytes(StandardCharsets.UTF_8);
int postDataLength = postData.length;

HttpURLConnection connection = null;

URL url = new URL(authorityUrl);

try {
connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("POST");
connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded");
connection.setRequestProperty("Content-Length", Integer.toString(postDataLength));
connection.setDoOutput(true);
try (DataOutputStream outputStream = new DataOutputStream(connection.getOutputStream())) {
outputStream.write(postData);
}
connection.connect();

Scanner s = new Scanner(connection.getInputStream(), "UTF-8").useDelimiter("\\A");
String result = s.hasNext() ? s.next() : "";
return SERIALIZER_ADAPTER.deserialize(result, MSIToken.class, SerializerEncoding.JSON);
} finally {
if (connection != null) {
connection.disconnect();
}
}
}));
}

/**
Expand All @@ -1054,7 +1106,6 @@ public Mono<AccessToken> authenticateToServiceFabricManagedIdentityEndpoint(Stri
String thumbprint,
TokenRequestContext request) {
return Mono.fromCallable(() -> {

HttpsURLConnection connection = null;
String endpoint = identityEndpoint;
String headerValue = identityHeader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public final class IdentityClientBuilder {
private InputStream certificate;
private String certificatePassword;
private boolean sharedTokenCacheCred;
private Duration confidentialClientCacheTimeout;
private Duration clientAssertionTimeout;

/**
* Sets the tenant ID for the client.
Expand Down Expand Up @@ -123,11 +123,12 @@ public IdentityClientBuilder sharedTokenCacheCredential(boolean isSharedTokenCac
/**
* Configure the time out to use re-use confidential client for. Post time out, a new instance of client is created.
*
* @param confidentialClientCacheTimeout the time out to use for confidential client cache.
* @param clientAssertionTimeout the time out to use for the client assertion configured via
* {@link IdentityClientBuilder#clientAssertionPath(String)}.
* @return the updated IdentityClientBuilder.
*/
public IdentityClientBuilder confidentialClientCacheTimeout(Duration confidentialClientCacheTimeout) {
this.confidentialClientCacheTimeout = confidentialClientCacheTimeout;
public IdentityClientBuilder clientAssertionTimeout(Duration clientAssertionTimeout) {
this.clientAssertionTimeout = clientAssertionTimeout;
return this;
}

Expand All @@ -136,6 +137,6 @@ public IdentityClientBuilder confidentialClientCacheTimeout(Duration confidentia
*/
public IdentityClient build() {
return new IdentityClient(tenantId, clientId, clientSecret, certificatePath, clientAssertionPath, certificate,
certificatePassword, sharedTokenCacheCred, confidentialClientCacheTimeout, identityClientOptions);
certificatePassword, sharedTokenCacheCred, clientAssertionTimeout, identityClientOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import com.azure.core.credential.AccessToken;
import com.azure.core.util.logging.ClientLogger;
import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

Expand All @@ -28,6 +29,7 @@ public final class MSIToken extends AccessToken {
private String accessToken;

@JsonProperty(value = "expires_on")
@JsonAlias("expires_in")
private String expiresOn;

/**
Expand Down

1 comment on commit 3b57be8

@rudolfbartel
Copy link

@rudolfbartel rudolfbartel commented on 3b57be8 Nov 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @g2vinay, the line 32 in the MSIToken class causes wrong mapping from the json token resulting token expiration in the class AccessTokenCache. Please check the issue #25598

Please sign in to comment.