Skip to content

Commit

Permalink
Merge pull request #886 from AzureAD/avdunn/tenant-override-fix
Browse files Browse the repository at this point in the history
Pass optional tenant override to internal silent call
  • Loading branch information
Avery-Dunn authored Dec 19, 2024
2 parents 95b5efc + a1a394a commit 920d5d2
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ void acquireTokenWithOBO_Managed(String environment) throws Exception {
new UserAssertion(accessToken)).build()).
get();

assertNotNull(result);
assertNotNull(result.accessToken());
assertResultNotNull(result);
}

@ParameterizedTest
Expand All @@ -63,8 +62,7 @@ void acquireTokenWithOBO_testCache(String environment) throws Exception {
new UserAssertion(accessToken)).build()).
get();

assertNotNull(result1);
assertNotNull(result1.accessToken());
assertResultNotNull(result1);

// Same scope and userAssertion, should return cached tokens
IAuthenticationResult result2 =
Expand All @@ -82,8 +80,7 @@ void acquireTokenWithOBO_testCache(String environment) throws Exception {
new UserAssertion(accessToken)).build()).
get();

assertNotNull(result3);
assertNotNull(result3.accessToken());
assertResultNotNull(result3);
assertNotEquals(result2.accessToken(), result3.accessToken());

// Scope 2, should return cached token
Expand All @@ -105,8 +102,7 @@ void acquireTokenWithOBO_testCache(String environment) throws Exception {
.build()).
get();

assertNotNull(result5);
assertNotNull(result5.accessToken());
assertResultNotNull(result5);
assertNotEquals(result5.accessToken(), result4.accessToken());
assertNotEquals(result5.accessToken(), result2.accessToken());

Expand All @@ -121,13 +117,17 @@ void acquireTokenWithOBO_testCache(String environment) throws Exception {
.build()).
get();

assertNotNull(result6);
assertNotNull(result6.accessToken());
assertResultNotNull(result6);
assertNotEquals(result6.accessToken(), result5.accessToken());
assertNotEquals(result6.accessToken(), result4.accessToken());
assertNotEquals(result6.accessToken(), result2.accessToken());
}

private void assertResultNotNull(IAuthenticationResult result) {
assertNotNull(result);
assertNotNull(result.accessToken());
}

private String getAccessToken() throws Exception {

LabUserProvider labUserProvider = LabUserProvider.getInstance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ AuthenticationResult execute() throws Exception {
SilentParameters parameters = SilentParameters
.builder(this.clientCredentialRequest.parameters.scopes())
.claims(this.clientCredentialRequest.parameters.claims())
.tenant(this.clientCredentialRequest.parameters.tenant())
.build();

RequestContext context = new RequestContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ AuthenticationResult execute() throws Exception {
SilentParameters parameters = SilentParameters
.builder(this.onBehalfOfRequest.parameters.scopes())
.claims(this.onBehalfOfRequest.parameters.claims())
.tenant(this.onBehalfOfRequest.parameters.tenant())
.build();

RequestContext context = new RequestContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@

package com.microsoft.aad.msal4j;

import java.util.Collections;
import java.util.HashMap;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class ClientCredentialTest {
Expand All @@ -32,4 +42,69 @@ void testSecretNullAndEmpty() {

assertTrue(ex.getMessage().contains("clientSecret is null or empty"));
}

@Test
void OnBehalfOf_InternalCacheLookup_Success() throws Exception {
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(any(HttpRequest.class))).thenReturn(TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(new HashMap<>())));

ConfidentialClientApplication cca =
ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromSecret("password"))
.authority("https://login.microsoftonline.com/tenant/")
.instanceDiscovery(false)
.validateAuthority(false)
.httpClient(httpClientMock)
.build();

ClientCredentialParameters parameters = ClientCredentialParameters.builder(Collections.singleton("scopes")).build();

IAuthenticationResult result = cca.acquireToken(parameters).get();
IAuthenticationResult result2 = cca.acquireToken(parameters).get();

//OBO flow should perform an internal cache lookup, so similar parameters should only cause one HTTP client call
assertEquals(result.accessToken(), result2.accessToken());
verify(httpClientMock, times(1)).send(any());
}

@Test
void OnBehalfOf_TenantOverride() throws Exception {
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

ConfidentialClientApplication cca =
ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromSecret("password"))
.authority("https://login.microsoftonline.com/tenant")
.instanceDiscovery(false)
.validateAuthority(false)
.httpClient(httpClientMock)
.build();

HashMap<String, String> tokenResponseValues = new HashMap<>();
tokenResponseValues.put("access_token", "accessTokenFirstCall");

when(httpClientMock.send(any(HttpRequest.class))).thenReturn(TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)));
ClientCredentialParameters parameters = ClientCredentialParameters.builder(Collections.singleton("scopes")).build();

//The two acquireToken calls have the same parameters...
IAuthenticationResult resultAppLevelTenant = cca.acquireToken(parameters).get();
IAuthenticationResult resultAppLevelTenantCached = cca.acquireToken(parameters).get();
//...so only one token should be added to the cache, and the mocked HTTP client's "send" method should only have been called once
assertEquals(1, cca.tokenCache.accessTokens.size());
assertEquals(resultAppLevelTenant.accessToken(), resultAppLevelTenantCached.accessToken());
verify(httpClientMock, times(1)).send(any());

tokenResponseValues.put("access_token", "accessTokenSecondCall");

when(httpClientMock.send(any(HttpRequest.class))).thenReturn(TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)));
parameters = ClientCredentialParameters.builder(Collections.singleton("scopes")).tenant("otherTenant").build();

//Overriding the tenant parameter in the request should lead to a new token call being made...
IAuthenticationResult resultRequestLevelTenant = cca.acquireToken(parameters).get();
IAuthenticationResult resultRequestLevelTenantCached = cca.acquireToken(parameters).get();
//...which should be different from the original token, and thus the cache should have two tokens created from two HTTP calls
assertEquals(2, cca.tokenCache.accessTokens.size());
assertEquals(resultRequestLevelTenant.accessToken(), resultRequestLevelTenantCached.accessToken());
assertNotEquals(resultAppLevelTenant.accessToken(), resultRequestLevelTenant.accessToken());
verify(httpClientMock, times(2)).send(any());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import java.util.Collections;
import java.util.HashMap;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class OnBehalfOfTests {

@Test
void OnBehalfOf_InternalCacheLookup_Success() throws Exception {
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(any(HttpRequest.class))).thenReturn(TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(new HashMap<>())));

ConfidentialClientApplication cca =
ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromSecret("password"))
.authority("https://login.microsoftonline.com/tenant/")
.instanceDiscovery(false)
.validateAuthority(false)
.httpClient(httpClientMock)
.build();

OnBehalfOfParameters parameters = OnBehalfOfParameters.builder(Collections.singleton("scopes"), new UserAssertion(TestHelper.signedAssertion)).build();

IAuthenticationResult result = cca.acquireToken(parameters).get();
IAuthenticationResult result2 = cca.acquireToken(parameters).get();

//OBO flow should perform an internal cache lookup, so similar parameters should only cause one HTTP client call
assertEquals(result.accessToken(), result2.accessToken());
verify(httpClientMock, times(1)).send(any());
}

@Test
void OnBehalfOf_TenantOverride() throws Exception {
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

ConfidentialClientApplication cca =
ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromSecret("password"))
.authority("https://login.microsoftonline.com/tenant")
.instanceDiscovery(false)
.validateAuthority(false)
.httpClient(httpClientMock)
.build();

HashMap<String, String> tokenResponseValues = new HashMap<>();
tokenResponseValues.put("access_token", "accessTokenFirstCall");

when(httpClientMock.send(any(HttpRequest.class))).thenReturn(TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)));
OnBehalfOfParameters parameters = OnBehalfOfParameters.builder(Collections.singleton("scopes"), new UserAssertion(TestHelper.signedAssertion)).build();

//The two acquireToken calls have the same parameters...
IAuthenticationResult resultAppLevelTenant = cca.acquireToken(parameters).get();
IAuthenticationResult resultAppLevelTenantCached = cca.acquireToken(parameters).get();
//...so only one token should be added to the cache, and the mocked HTTP client's "send" method should only have been called once
assertEquals(1, cca.tokenCache.accessTokens.size());
assertEquals(resultAppLevelTenant.accessToken(), resultAppLevelTenantCached.accessToken());
verify(httpClientMock, times(1)).send(any());

tokenResponseValues.put("access_token", "accessTokenSecondCall");

when(httpClientMock.send(any(HttpRequest.class))).thenReturn(TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues)));
parameters = OnBehalfOfParameters.builder(Collections.singleton("scopes"), new UserAssertion(TestHelper.signedAssertion)).tenant("otherTenant").build();

//Overriding the tenant parameter in the request should lead to a new token call being made...
IAuthenticationResult resultRequestLevelTenant = cca.acquireToken(parameters).get();
IAuthenticationResult resultRequestLevelTenantCached = cca.acquireToken(parameters).get();
//...which should be different from the original token, and thus the cache should have two tokens created from two HTTP calls
assertEquals(2, cca.tokenCache.accessTokens.size());
assertEquals(resultRequestLevelTenant.accessToken(), resultRequestLevelTenantCached.accessToken());
assertNotEquals(resultAppLevelTenant.accessToken(), resultRequestLevelTenant.accessToken());
verify(httpClientMock, times(2)).send(any());
}
}
78 changes: 74 additions & 4 deletions msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,38 @@

package com.microsoft.aad.msal4j;

import com.nimbusds.jose.*;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class TestHelper {

static String readResource(Class<?> classInstance, String resource) throws IOException, URISyntaxException {
return new String(
Files.readAllBytes(
Paths.get(classInstance.getResource(resource).toURI())));
//Signed JWT which should be enough to pass the parsing/validation in the library, useful if a unit test needs an
// assertion but that is not the focus of the test
static String signedAssertion = generateToken();
private static final String successfulResponseFormat = "{\"access_token\":\"%s\",\"id_token\":\"%s\",\"refresh_token\":\"%s\"," +
"\"client_id\":\"%s\",\"client_info\":\"%s\"," +
"\"expires_on\": %d ,\"expires_in\": %d," +
"\"token_type\":\"Bearer\"}";

static String readResource(Class<?> classInstance, String resource) {
try {
return new String(Files.readAllBytes(Paths.get(classInstance.getResource(resource).toURI())));
} catch (IOException | URISyntaxException e) {
throw new RuntimeException(e);
}
}

static void deleteFileContent(Class<?> classInstance, String resource)
Expand All @@ -27,4 +46,55 @@ static void deleteFileContent(Class<?> classInstance, String resource)
fileWriter.write("");
fileWriter.close();
}

static String generateToken() {
try {
RSAKey rsaJWK = new RSAKeyGenerator(2048)
.keyID("kid")
.generate();
JWSObject jwsObject = new JWSObject(
new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(),
new Payload("payload"));

jwsObject.sign(new RSASSASigner(rsaJWK));

return jwsObject.serialize();
} catch (JOSEException e) {
throw new RuntimeException(e);
}
}

//Maps various values to the successfulResponseFormat string to create a valid token response
static String getSuccessfulTokenResponse(HashMap<String, String> responseValues) {
//Will default to expiring in one hour if expiry time values are not set
long expiresIn = responseValues.containsKey("expires_in") ?
Long.parseLong(responseValues.get("expires_in")) :
3600;
long expiresOn = responseValues.containsKey("expires_on")
? Long.parseLong(responseValues.get("expires_0n")) :
(System.currentTimeMillis() / 1000) + expiresIn;

return String.format(successfulResponseFormat,
responseValues.getOrDefault("access_token", "access_token"),
responseValues.getOrDefault("id_token", "id_token"),
responseValues.getOrDefault("refresh_token", "refresh_token"),
responseValues.getOrDefault("client_id", "client_id"),
responseValues.getOrDefault("client_info", "client_info"),
expiresOn,
expiresIn
);
}

//Creates a valid HttpResponse that can be used when mocking HttpClient.send()
static HttpResponse expectedResponse(int statusCode, String response) {
Map<String, List<String>> headers = new HashMap<>();
headers.put("Content-Type", Collections.singletonList("application/json"));

HttpResponse httpResponse = new HttpResponse();
httpResponse.statusCode(statusCode);
httpResponse.body(response);
httpResponse.addHeaders(headers);

return httpResponse;
}
}

0 comments on commit 920d5d2

Please sign in to comment.