Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added token cache map to fix use of unintended auth token for subsequent connections #2341

Merged
merged 10 commits into from
Mar 20, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
* @see <a href="https://aka.ms/msal4j-token-cache">https://aka.ms/msal4j-token-cache</a>
*/
public class PersistentTokenCacheAccessAspect implements ITokenCacheAccessAspect {
private static PersistentTokenCacheAccessAspect instance = new PersistentTokenCacheAccessAspect();

private static PersistentTokenCacheAccessAspect instance;
private final Lock lock = new ReentrantLock();

private PersistentTokenCacheAccessAspect() {}
static final long TIME_TO_LIVE = 86400000L; // Token cache time to live (24 hrs).
tkyc marked this conversation as resolved.
Show resolved Hide resolved
private long expiryTime;

static PersistentTokenCacheAccessAspect getInstance() {
if (instance == null) {
instance = new PersistentTokenCacheAccessAspect();
}
return instance;
}

Expand Down Expand Up @@ -62,6 +65,14 @@ public void afterCacheAccess(ITokenCacheAccessContext iTokenCacheAccessContext)

}

public long getExpiryTime() {
return this.expiryTime;
}

public void setExpiryTime(long expiryTime) {
this.expiryTime = expiryTime;
}

/**
* Clears User token cache. This will clear all account info so interactive login will be required on the next
* request to acquire an access token.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import java.net.URI;
import java.net.URISyntaxException;

import java.security.MessageDigest;
import java.text.MessageFormat;

import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -53,6 +55,7 @@

import com.microsoft.sqlserver.jdbc.SQLServerConnection.ActiveDirectoryAuthentication;
import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo;
import mssql.googlecode.concurrentlinkedhashmap.ConcurrentLinkedHashMap;
tkyc marked this conversation as resolved.
Show resolved Hide resolved


class SQLServerMSAL4JUtils {
Expand All @@ -61,6 +64,8 @@ class SQLServerMSAL4JUtils {
static final String SLASH_DEFAULT = "/.default";
static final String ACCESS_TOKEN_EXPIRE = "access token expires: ";

private static final TokenCacheMap tokenCacheMap = new TokenCacheMap();
tkyc marked this conversation as resolved.
Show resolved Hide resolved

private final static String LOGCONTEXT = "MSAL version "
+ com.microsoft.aad.msal4j.PublicClientApplication.class.getPackage().getImplementationVersion() + ": ";

Expand All @@ -84,10 +89,17 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str
lock.lock();

try {
String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password});
PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect;

if (null == (persistentTokenCacheAccessAspect = tokenCacheMap.getEntry(hashedSecret))) {
persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect();
tokenCacheMap.addEntry(hashedSecret, persistentTokenCacheAccessAspect);
}
tkyc marked this conversation as resolved.
Show resolved Hide resolved

final PublicClientApplication pca = PublicClientApplication
.builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService)
.setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
.authority(fedAuthInfo.stsurl).build();
.setTokenCacheAccessAspect(persistentTokenCacheAccessAspect).authority(fedAuthInfo.stsurl).build();

final CompletableFuture<IAuthenticationResult> future = pca.acquireToken(UserNamePasswordParameters
.builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user, password.toCharArray())
Expand Down Expand Up @@ -132,11 +144,19 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth
lock.lock();

try {
String hashedSecret = getHashedSecret(
new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret});
PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect;

if (null == (persistentTokenCacheAccessAspect = tokenCacheMap.getEntry(hashedSecret))) {
persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect();
tokenCacheMap.addEntry(hashedSecret, persistentTokenCacheAccessAspect);
}
tkyc marked this conversation as resolved.
Show resolved Hide resolved

IClientCredential credential = ClientCredentialFactory.createFromSecret(aadPrincipalSecret);
ConfidentialClientApplication clientApplication = ConfidentialClientApplication
.builder(aadPrincipalID, credential).executorService(executorService)
.setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
.authority(fedAuthInfo.stsurl).build();
.setTokenCacheAccessAspect(persistentTokenCacheAccessAspect).authority(fedAuthInfo.stsurl).build();

final CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());
Expand Down Expand Up @@ -181,6 +201,15 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI
lock.lock();

try {
String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile,
certPassword, certKey, certKeyPassword});
PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect;

if (null == (persistentTokenCacheAccessAspect = tokenCacheMap.getEntry(hashedSecret))) {
persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect();
tokenCacheMap.addEntry(hashedSecret, persistentTokenCacheAccessAspect);
}
tkyc marked this conversation as resolved.
Show resolved Hide resolved

ConfidentialClientApplication clientApplication = null;

// check if cert is PKCS12 first
Expand All @@ -202,8 +231,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI

IClientCredential credential = ClientCredentialFactory.createFromCertificate(is, certPassword);
clientApplication = ConfidentialClientApplication.builder(aadPrincipalID, credential)
.executorService(executorService)
.setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
.executorService(executorService).setTokenCacheAccessAspect(persistentTokenCacheAccessAspect)
.authority(fedAuthInfo.stsurl).build();
} catch (FileNotFoundException e) {
// re-throw if file not there no point to try another format
Expand Down Expand Up @@ -232,8 +260,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI

IClientCredential credential = ClientCredentialFactory.createFromCertificate(privateKey, cert);
clientApplication = ConfidentialClientApplication.builder(aadPrincipalID, credential)
.executorService(executorService)
.setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
.executorService(executorService).setTokenCacheAccessAspect(persistentTokenCacheAccessAspect)
.authority(fedAuthInfo.stsurl).build();
}

Expand Down Expand Up @@ -449,4 +476,47 @@ private static SQLServerException getCorrectedException(Exception e, String user
return new SQLServerException(form.format(msgArgs), null, 0, correctedExecutionException);
}
}

private static String getHashedSecret(String[] secrets) throws SQLServerException {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
for (String secret : secrets) {
if (null != secret) {
md.update(secret.getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
}
}
return new String(md.digest());
} catch (NoSuchAlgorithmException e) {
throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
}
}

private static class TokenCacheMap {
private ConcurrentHashMap<String, PersistentTokenCacheAccessAspect> tokenCacheMap = new ConcurrentHashMap<>();

PersistentTokenCacheAccessAspect getEntry(String key) {
PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = tokenCacheMap.get(key);

if (null != persistentTokenCacheAccessAspect) {
if (System.currentTimeMillis() > persistentTokenCacheAccessAspect.getExpiryTime()) {
tokenCacheMap.remove(key);

persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect();
persistentTokenCacheAccessAspect
.setExpiryTime(System.currentTimeMillis() + PersistentTokenCacheAccessAspect.TIME_TO_LIVE);

tokenCacheMap.put(key, persistentTokenCacheAccessAspect);

return persistentTokenCacheAccessAspect;
tkyc marked this conversation as resolved.
Show resolved Hide resolved
}
}

return persistentTokenCacheAccessAspect;
}

void addEntry(String key, PersistentTokenCacheAccessAspect value) {
value.setExpiryTime(System.currentTimeMillis() + PersistentTokenCacheAccessAspect.TIME_TO_LIVE);
tokenCacheMap.put(key, value);
}
}
}
5 changes: 5 additions & 0 deletions src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ protected Object[][] getContents() {
{"R_ConnectionURLNull", "The connection URL is null."},
{"R_connectionIsNotClosed", "The connection is not closed."},
{"R_invalidExceptionMessage", "Invalid exception message"},
{"R_invalidClientSecret", "AADSTS7000215: Invalid client secret provided"},
{"R_invalidCertFields",
"Error reading certificate, please verify the location of the certificate.signed fields invalid"},
tkyc marked this conversation as resolved.
Show resolved Hide resolved
{"R_invalidAADPasswordAuth",
"Failed to authenticate the user {0} in Active Directory (Authentication=ActiveDirectoryPassword)"},
tkyc marked this conversation as resolved.
Show resolved Hide resolved
{"R_failedValidate", "failed to validate values in $0} "}, {"R_tableNotDropped", "table not dropped. "},
{"R_connectionReset", "Connection reset"}, {"R_unknownException", "Unknown exception"},
{"R_deadConnection", "Dead connection should be invalid"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import java.security.cert.CertificateFactory;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
tkyc marked this conversation as resolved.
Show resolved Hide resolved
import java.sql.SQLException;
import java.sql.Statement;
import java.text.MessageFormat;
import java.util.Collections;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -323,6 +325,66 @@ public void testAADServicePrincipalAuth() {
}
}

@Test
public void testAADServicePrincipalAuthFailureOnSubsequentConnectionsWithInvalidatedTokenCacheWithInvalidSecret() throws Exception {
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipal + ";Username=" + applicationClientID + ";Password="
+ applicationKey;

String invalidSecretUrl = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipal + ";Username=" + applicationClientID + ";Password="
+ "invalidSecret";

// Should succeed on valid secret
try (Connection connection = DriverManager.getConnection(url)) {}

// Should fail on invalid secret
try (Connection connection = DriverManager.getConnection(invalidSecretUrl)) {
fail(TestResource.getResource("R_expectedFailPassed"));
} catch (Exception e) {
assertTrue(e.getMessage().contains(TestResource.getResource("R_invalidClientSecret")),
"Expected R_invalidClientSecret error.");
}
}

@Test
public void testActiveDirectoryPasswordFailureOnSubsequentConnectionsWithInvalidatedTokenCacheWithInvalidPassword() throws Exception {

// Should succeed on valid password
try (Connection conn = DriverManager.getConnection(adPasswordConnectionStr)) {}

// Should fail on invalid password
try (Connection conn = DriverManager.getConnection(adPasswordConnectionStr + ";password=invalidPassword;")) {
fail(TestResource.getResource("R_expectedFailPassed"));
} catch (Exception e) {
MessageFormat form = new MessageFormat(TestResource.getResource("R_invalidAADPasswordAuth"));
Object[] msgArgs = {azureUserName};
assertTrue(e.getMessage().contains(form.format(msgArgs)), "Expected R_invalidAADPasswordAuth error.");
}
}

@Test
public void testAADServicePrincipalCertAuthFailureOnSubsequentConnectionsWithInvalidatedTokenCacheWithInvalidPassword() throws Exception {
// Should succeed on valid cert field values
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipalCertificate + ";Username=" + applicationClientID
+ ";password=" + certificatePassword + ";clientCertificate=" + clientCertificate;

// Should fail on invalid cert field values
String invalidPasswordUrl = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase
+ ";authentication=" + SqlAuthentication.ActiveDirectoryServicePrincipalCertificate + ";Username="
+ applicationClientID + ";password=invalidPassword;clientCertificate=" + clientCertificate;

try (Connection conn = DriverManager.getConnection(url)) {}

try (Connection conn = DriverManager.getConnection(invalidPasswordUrl)) {
fail(TestResource.getResource("R_expectedFailPassed"));
} catch (Exception e) {
assertTrue(e.getMessage().contains(TestResource.getResource("R_invalidCertFields")),
"Expected R_invalidCertFields error.");
}
}

/**
* Test invalid connection property combinations when using AAD Service Principal Authentication.
*/
Expand Down