diff --git a/azure-pipelines.yml b/azure-pipelines.yml index a2f9116b2..ab090c4f6 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -38,8 +38,8 @@ jobs: displayName: 'Maven build jre14' inputs: mavenPomFile: 'pom.xml' - goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre14 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath) --DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer)' + goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre14 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath) +-DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer) -DtenantID=$(tenantID)' testResultsFiles: '**/TEST-*.xml' testRunTitle: 'Maven build jre14' javaHomeOption: Path @@ -49,7 +49,7 @@ jobs: inputs: mavenPomFile: 'pom.xml' goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre11 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath) --DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer)' +-DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer) -DtenantID=$(tenantID)' testResultsFiles: '**/TEST-*.xml' testRunTitle: 'Maven build jre11' javaHomeOption: Path @@ -58,8 +58,8 @@ jobs: displayName: 'Maven build jre8' inputs: mavenPomFile: 'pom.xml' - goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre8 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath) --DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer)' + goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre8 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath) +-DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer) -DtenantID=$(tenantID)' testResultsFiles: '**/TEST-*.xml' testRunTitle: 'Maven build jre8' javaHomeOption: Path diff --git a/build.gradle b/build.gradle index 7c0c08bbb..aef00225f 100644 --- a/build.gradle +++ b/build.gradle @@ -110,14 +110,12 @@ repositories { dependencies { compile 'org.osgi:org.osgi.core:6.0.0', 'org.osgi:org.osgi.compendium:5.0.0' - compileOnly 'com.microsoft.azure:azure-keyvault:1.2.2', - 'com.microsoft.azure:azure-keyvault-webkey:1.2.1', - 'com.microsoft.rest:client-runtime:1.7.0', - 'com.microsoft.azure:adal4j:1.6.4', + compileOnly 'com.azure:azure-security-keyvault-keys:4.2.0', + 'com.azure:azure-identity:1.1.0', 'org.antlr:antlr4-runtime:4.7.2', 'com.google.code.gson:gson:2.8.6', - 'org.bouncycastle:bcprov-jdk15on:1.64', - 'org.bouncycastle:bcpkix-jdk15on:1.64' + 'org.bouncycastle:bcprov-jdk15on:1.65', + 'org.bouncycastle:bcpkix-jdk15on:1.65' testCompile 'org.junit.platform:junit-platform-console:1.5.2', 'org.junit.platform:junit-platform-commons:1.5.2', 'org.junit.platform:junit-platform-engine:1.5.2', @@ -127,15 +125,14 @@ dependencies { 'org.junit.jupiter:junit-jupiter-api:5.6.0', 'org.junit.jupiter:junit-jupiter-engine:5.6.0', 'org.junit.jupiter:junit-jupiter-params:5.6.0', - 'com.zaxxer:HikariCP:3.4.1', - 'org.apache.commons:commons-dbcp2:2.6.0', + 'com.zaxxer:HikariCP:3.4.2', + 'org.apache.commons:commons-dbcp2:2.7.0', 'org.slf4j:slf4j-nop:1.7.29', 'org.antlr:antlr4-runtime:4.7.2', 'org.eclipse.gemini.blueprint:gemini-blueprint-mock:2.1.0.RELEASE', 'com.google.code.gson:gson:2.8.6', - 'org.bouncycastle:bcprov-jdk15on:1.64', - 'com.microsoft.azure:adal4j:1.6.4', - 'com.microsoft.azure:azure-keyvault:1.2.2', - 'com.microsoft.azure:azure-keyvault-webkey:1.2.1', + 'org.bouncycastle:bcprov-jdk15on:1.65', + 'com.azure:azure-security-keyvault-keys:4.2.0', + 'com.azure:azure-identity:1.1.0', 'com.h2database:h2:1.4.200' } diff --git a/pom.xml b/pom.xml index feb5cbc75..fef1f5b8d 100644 --- a/pom.xml +++ b/pom.xml @@ -60,9 +60,8 @@ -preview - 1.2.4 - 1.6.5 - 1.7.4 + 4.1.4 + 1.0.7 6.0.0 5.0.0 4.7.2 @@ -86,22 +85,14 @@ - com.microsoft.azure - azure-keyvault + com.azure + azure-security-keyvault-keys ${azure.keyvault.version} - true - com.microsoft.azure - adal4j - ${azure.adal4j.version} - true - - - com.microsoft.rest - client-runtime - ${rest.client.version} - true + com.azure + azure-identity + ${azure.identity.version} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java index ca4f21f68..0e8068780 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java @@ -5,86 +5,117 @@ package com.microsoft.sqlserver.jdbc; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - -import com.microsoft.aad.adal4j.AuthenticationContext; -import com.microsoft.aad.adal4j.AuthenticationResult; -import com.microsoft.aad.adal4j.ClientCredential; -import com.microsoft.azure.keyvault.authentication.KeyVaultCredentials; - +import com.azure.core.annotation.Immutable; +import com.azure.core.credential.AccessToken; +import com.azure.core.credential.TokenCredential; +import com.azure.core.credential.TokenRequestContext; +import com.azure.core.util.logging.ClientLogger; +import com.microsoft.aad.msal4j.ClientCredentialFactory; +import com.microsoft.aad.msal4j.ClientCredentialParameters; +import com.microsoft.aad.msal4j.ConfidentialClientApplication; +import com.microsoft.aad.msal4j.IAuthenticationResult; +import com.microsoft.aad.msal4j.IClientCredential; +import com.microsoft.aad.msal4j.SilentParameters; +import java.net.MalformedURLException; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.HashSet; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import reactor.core.publisher.Mono; /** - * - * An implementation of ServiceClientCredentials that supports automatic bearer token refresh. - * + * An AAD credential that acquires a token with a client secret for an AAD application. */ -class KeyVaultCredential extends KeyVaultCredentials { - - SQLServerKeyVaultAuthenticationCallback authenticationCallback = null; - String clientId = null; - String clientKey = null; - String accessToken = null; +@Immutable +class KeyVaultCredential implements TokenCredential { + private final ClientLogger logger = new ClientLogger(KeyVaultCredential.class); + private final String clientId; + private final String clientSecret; + private String authorization; + private ConfidentialClientApplication confidentialClientApplication; - KeyVaultCredential(String clientId) throws SQLServerException { + /** + * Creates a KeyVaultCredential with the given identity client options. + * + * @param clientId the client ID of the application + * @param clientSecret the secret value of the AAD application. + */ + KeyVaultCredential(String clientId, String clientSecret) { + Objects.requireNonNull(clientSecret, "'clientSecret' cannot be null."); + Objects.requireNonNull(clientSecret, "'clientId' cannot be null."); this.clientId = clientId; + this.clientSecret = clientSecret; } - KeyVaultCredential() {} - - KeyVaultCredential(String clientId, String clientKey) { - this.clientId = clientId; - this.clientKey = clientKey; + @Override + public Mono getToken(TokenRequestContext request) { + return authenticateWithConfidentialClientCache(request) + .onErrorResume(t -> Mono.empty()) + .switchIfEmpty(Mono.defer(() -> authenticateWithConfidentialClient(request))); } - KeyVaultCredential(SQLServerKeyVaultAuthenticationCallback authenticationCallback) { - this.authenticationCallback = authenticationCallback; + public KeyVaultCredential setAuthorization(String authorization) { + if (this.authorization != null && this.authorization.equals(authorization)) { + return this; + } + this.authorization = authorization; + confidentialClientApplication = getConfidentialClientApplication(); + return this; } - public String doAuthenticate(String authorization, String resource, String scope) { - String accessToken = null; - if (null == authenticationCallback) { - if (null == clientKey) { - try { - SqlFedAuthToken token = SQLServerSecurityUtility.getMSIAuthToken(resource, clientId); - accessToken = (null != token) ? token.accessToken : null; - } catch (Exception e) { - throw new RuntimeException(e); - } - } else { - AuthenticationResult token = getAccessTokenFromClientCredentials(authorization, resource, clientId, - clientKey); - accessToken = token.getAccessToken(); - } + private ConfidentialClientApplication getConfidentialClientApplication() { + if (clientId == null) { + throw logger.logExceptionAsError(new IllegalArgumentException( + "A non-null value for client ID must be provided for user authentication.")); + } + + if (authorization == null) { + throw logger.logExceptionAsError(new IllegalArgumentException( + "A non-null value for authorization must be provided for user authentication.")); + } + + IClientCredential credential; + if (clientSecret != null) { + credential = ClientCredentialFactory.create(clientSecret); } else { - accessToken = authenticationCallback.getAccessToken(authorization, resource, scope); + throw logger.logExceptionAsError( + new IllegalArgumentException("Must provide client secret.")); + } + ConfidentialClientApplication.Builder applicationBuilder = + ConfidentialClientApplication.builder(clientId, credential); + try { + applicationBuilder = applicationBuilder.authority(authorization); + } catch (MalformedURLException e) { + throw logger.logExceptionAsWarning(new IllegalStateException(e)); } - return accessToken; + return applicationBuilder.build(); } - private static AuthenticationResult getAccessTokenFromClientCredentials(String authorization, String resource, - String clientId, String clientKey) { - AuthenticationContext context = null; - AuthenticationResult result = null; - ExecutorService service = null; - try { - service = Executors.newFixedThreadPool(1); - context = new AuthenticationContext(authorization, false, service); - ClientCredential credentials = new ClientCredential(clientId, clientKey); - Future future = context.acquireToken(resource, credentials, null); - result = future.get(); - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - if (null != service) { - service.shutdown(); + private Mono authenticateWithConfidentialClientCache(TokenRequestContext request) { + return Mono.fromFuture(() -> { + SilentParameters.SilentParametersBuilder parametersBuilder = SilentParameters + .builder(new HashSet<>(request.getScopes())); + try { + return confidentialClientApplication.acquireTokenSilently(parametersBuilder.build()); + } catch (MalformedURLException e) { + return getFailedCompletableFuture(logger.logExceptionAsError(new RuntimeException(e))); } - } + }).map(ar -> new AccessToken(ar.accessToken(), + OffsetDateTime.ofInstant(ar.expiresOnDate().toInstant(), ZoneOffset.UTC))) + .filter(t -> !t.isExpired()); + } - if (null == result) { - throw new RuntimeException("authentication result was null"); - } - return result; + private CompletableFuture getFailedCompletableFuture(Exception e) { + CompletableFuture completableFuture = new CompletableFuture<>(); + completableFuture.completeExceptionally(e); + return completableFuture; + } + + private Mono authenticateWithConfidentialClient(TokenRequestContext request) { + return Mono.fromFuture(() -> confidentialClientApplication + .acquireToken(ClientCredentialParameters.builder(new HashSet<>(request.getScopes())).build())) + .map(ar -> new AccessToken(ar.accessToken(), + OffsetDateTime.ofInstant(ar.expiresOnDate().toInstant(), ZoneOffset.UTC))); } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCustomCredentialPolicy.java b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCustomCredentialPolicy.java new file mode 100644 index 000000000..45bec37b2 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCustomCredentialPolicy.java @@ -0,0 +1,102 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +import com.azure.core.credential.TokenRequestContext; +import com.azure.core.http.HttpPipelineCallContext; +import com.azure.core.http.HttpPipelineNextPolicy; +import com.azure.core.http.HttpResponse; +import com.azure.core.http.policy.HttpPipelinePolicy; +import com.azure.core.util.CoreUtils; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import reactor.core.publisher.Mono; + +/** + * A policy that authenticates requests with Azure Key Vault service. + */ +class KeyVaultCustomCredentialPolicy implements HttpPipelinePolicy { + private static final String WWW_AUTHENTICATE = "WWW-Authenticate"; + private static final String BEARER_TOKEN_PREFIX = "Bearer "; + private static final String AUTHORIZATION = "Authorization"; + private final ScopeTokenCache cache; + private final KeyVaultCredential keyVaultCredential; + + /** + * Creates KeyVaultCustomCredentialPolicy. + * + * @param credential the token credential to authenticate the request + */ + public KeyVaultCustomCredentialPolicy(KeyVaultCredential credential) { + Objects.requireNonNull(credential, "'credential' cannot be null."); + this.cache = new ScopeTokenCache(credential::getToken); + this.keyVaultCredential = credential; + } + + /** + * Adds the required header to authenticate a request to Azure Key Vault service. + * + * @param context The request context + * @param next The next HTTP pipeline policy to process the {@code context's} request after this policy completes. + * @return A {@link Mono} representing the HTTP response that will arrive asynchronously. + */ + @Override + public Mono process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { + if ("http".equals(context.getHttpRequest().getUrl().getProtocol())) { + return Mono.error(new RuntimeException("Token credentials require a URL using the HTTPS protocol scheme")); + } + return next.clone().process() + // Ignore body + .doOnNext(HttpResponse::close) + .map(res -> res.getHeaderValue(WWW_AUTHENTICATE)) + .map(header -> extractChallenge(header, BEARER_TOKEN_PREFIX)) + .flatMap(map -> { + keyVaultCredential.setAuthorization(map.get("authorization")); + cache.setRequest(new TokenRequestContext().addScopes(map.get("resource") + "/.default")); + return cache.getToken(); + }) + .flatMap(token -> { + context.getHttpRequest().setHeader(AUTHORIZATION, BEARER_TOKEN_PREFIX + token.getToken()); + return next.process(); + }); + } + + /** + * Extracts the challenge off the authentication header. + * + * @param authenticateHeader The authentication header containing all the challenges. + * @param authChallengePrefix The authentication challenge name. + * @return a challenge map. + */ + private static Map extractChallenge(String authenticateHeader, String authChallengePrefix) { + if (!isValidChallenge(authenticateHeader, authChallengePrefix)) { + return null; + } + authenticateHeader = authenticateHeader.toLowerCase(Locale.ROOT).replace(authChallengePrefix.toLowerCase(Locale.ROOT), ""); + + String[] challenges = authenticateHeader.split(", "); + Map challengeMap = new HashMap<>(); + for (String pair : challenges) { + String[] keyValue = pair.split("="); + challengeMap.put(keyValue[0].replaceAll("\"", ""), keyValue[1].replaceAll("\"", "")); + } + return challengeMap; + } + + /** + * Verifies whether a challenge is bearer or not. + * + * @param authenticateHeader The authentication header containing all the challenges. + * @param authChallengePrefix The authentication challenge name. + * @return A boolean indicating tha challenge is valid or not. + */ + private static boolean isValidChallenge(String authenticateHeader, String authChallengePrefix) { + return (!CoreUtils.isNullOrEmpty(authenticateHeader) + && authenticateHeader.toLowerCase(Locale.ROOT).startsWith(authChallengePrefix.toLowerCase(Locale.ROOT))); + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultHttpPipelineBuilder.java b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultHttpPipelineBuilder.java new file mode 100644 index 000000000..fc97daa45 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultHttpPipelineBuilder.java @@ -0,0 +1,79 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +import com.azure.core.http.HttpPipeline; +import com.azure.core.http.HttpPipelineBuilder; +import com.azure.core.http.policy.HttpLogOptions; +import com.azure.core.http.policy.HttpLoggingPolicy; +import com.azure.core.http.policy.HttpPipelinePolicy; +import com.azure.core.http.policy.HttpPolicyProviders; +import com.azure.core.http.policy.RetryPolicy; +import com.azure.core.http.policy.UserAgentPolicy; +import com.azure.core.util.Configuration; +import com.azure.core.util.logging.ClientLogger; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +final class KeyVaultHttpPipelineBuilder { + public static final String APPLICATION_ID = "ms-sql-jdbc"; + private static final String SDK_NAME = "azure-security-keyvault-keys"; + private static final String SDK_VERSION = "4.2.0"; + + private final ClientLogger logger = new ClientLogger(KeyVaultHttpPipelineBuilder.class); + + private final List policies; + private KeyVaultCredential credential; + private HttpLogOptions httpLogOptions; + private final RetryPolicy retryPolicy; + + /** + * The constructor with defaults. + */ + public KeyVaultHttpPipelineBuilder() { + retryPolicy = new RetryPolicy(); + httpLogOptions = new HttpLogOptions(); + policies = new ArrayList<>(); + } + + public HttpPipeline buildPipeline() { + Configuration buildConfiguration = Configuration.getGlobalConfiguration().clone(); + + if (credential == null) { + throw logger.logExceptionAsError(new IllegalStateException("Token Credential should be specified.")); + } + + // Closest to API goes first, closest to wire goes last. + final List policies = new ArrayList<>(); + + policies.add(new UserAgentPolicy(APPLICATION_ID, SDK_NAME, SDK_VERSION, buildConfiguration)); + HttpPolicyProviders.addBeforeRetryPolicies(policies); + policies.add(retryPolicy); + policies.add(new KeyVaultCustomCredentialPolicy(credential)); + policies.addAll(this.policies); + HttpPolicyProviders.addAfterRetryPolicies(policies); + policies.add(new HttpLoggingPolicy(httpLogOptions)); + + return new HttpPipelineBuilder() + .policies(policies.toArray(new HttpPipelinePolicy[0])) + .build(); + } + + /** + * Sets the credential to use when authenticating HTTP requests. + * + * @param credential The credential to use for authenticating HTTP requests. + * @return the updated KVHttpPipelineBuilder object. + * @throws NullPointerException if {@code credential} is {@code null}. + */ + public KeyVaultHttpPipelineBuilder credential(KeyVaultCredential credential) { + Objects.requireNonNull(credential); + this.credential = credential; + return this; + } +} + diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java index 90fbf756f..a575fc166 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java @@ -7,6 +7,7 @@ import static java.nio.charset.StandardCharsets.UTF_16LE; +import com.azure.core.http.HttpPipeline; import java.io.FileInputStream; import java.io.IOException; import java.net.URI; @@ -19,28 +20,31 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.Objects; import java.util.Properties; -import java.util.concurrent.ExecutorService; +import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Level; -import com.microsoft.azure.AzureResponseBuilder; -import com.microsoft.azure.keyvault.KeyVaultClient; -import com.microsoft.azure.keyvault.models.KeyBundle; -import com.microsoft.azure.keyvault.models.KeyOperationResult; -import com.microsoft.azure.keyvault.models.KeyVerifyResult; -import com.microsoft.azure.keyvault.webkey.JsonWebKeyEncryptionAlgorithm; -import com.microsoft.azure.keyvault.webkey.JsonWebKeySignatureAlgorithm; -import com.microsoft.azure.serializer.AzureJacksonAdapter; -import com.microsoft.rest.RestClient; - -import okhttp3.OkHttpClient; -import retrofit2.Retrofit; - +import com.azure.core.credential.TokenCredential; +import com.azure.identity.ManagedIdentityCredentialBuilder; +import com.azure.security.keyvault.keys.KeyClient; +import com.azure.security.keyvault.keys.KeyClientBuilder; +import com.azure.security.keyvault.keys.cryptography.CryptographyClient; +import com.azure.security.keyvault.keys.cryptography.CryptographyClientBuilder; +import com.azure.security.keyvault.keys.cryptography.models.KeyWrapAlgorithm; +import com.azure.security.keyvault.keys.cryptography.models.SignResult; +import com.azure.security.keyvault.keys.cryptography.models.SignatureAlgorithm; +import com.azure.security.keyvault.keys.cryptography.models.UnwrapResult; +import com.azure.security.keyvault.keys.cryptography.models.VerifyResult; +import com.azure.security.keyvault.keys.cryptography.models.WrapResult; +import com.azure.security.keyvault.keys.models.KeyType; +import com.azure.security.keyvault.keys.models.KeyVaultKey; /** * Provides implementation similar to certificate store provider. A CEK encrypted with certificate store provider should * be decryptable by this provider and vice versa. - * + * * Envelope Format for the encrypted column encryption key version + keyPathLength + ciphertextLength + keyPath + * ciphertext + signature version: A single byte indicating the format version. keyPathLength: Length of the keyPath. * ciphertextLength: ciphertext length keyPath: keyPath used to encrypt the column encryption key. This is only used for @@ -49,668 +53,675 @@ */ public class SQLServerColumnEncryptionAzureKeyVaultProvider extends SQLServerColumnEncryptionKeyStoreProvider { - private final static java.util.logging.Logger akvLogger = java.util.logging.Logger - .getLogger("com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider"); - /** - * Column Encryption Key Store Provider string - */ - String name = "AZURE_KEY_VAULT"; - - private final String baseUrl = "https://{vaultBaseUrl}"; - - private static final String MSSQL_JDBC_PROPERTIES = "mssql-jdbc.properties"; - private static final String AKV_TRUSTED_ENDPOINTS_KEYWORD = "AKVTrustedEndpoints"; - private static final List akvTrustedEndpoints; - static { - akvTrustedEndpoints = getTrustedEndpoints(); - } - private final String rsaEncryptionAlgorithmWithOAEPForAKV = "RSA-OAEP"; - - /** - * Algorithm version - */ - private final byte[] firstVersion = new byte[] {0x01}; - - private KeyVaultClient keyVaultClient; - - private KeyVaultCredential credentials; - - public void setName(String name) { - this.name = name; - } - - public String getName() { - return this.name; - } - - /** - * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a client id and client key to authenticate to - * AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault. - * - * @param clientId - * Identifier of the client requesting the token. - * @param clientKey - * Key of the client requesting the token. - * @throws SQLServerException - * when an error occurs - */ - public SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId, String clientKey) throws SQLServerException { - credentials = new KeyVaultCredential(clientId, clientKey); - keyVaultClient = new KeyVaultClient(credentials); - } - - /** - * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a callback function to authenticate to AAD and - * an executor service.. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault. - * - * This constructor is present to maintain backwards compatibility with 6.0 version of the driver. Deprecated for - * removal in next stable release. - * - * @param authenticationCallback - * - Callback function used for authenticating to AAD. - * @param executorService - * - The ExecutorService, previously used to create the keyVaultClient, but not in use anymore. - This - * parameter can be passed as 'null' - * @throws SQLServerException - * when an error occurs - */ - @Deprecated - public SQLServerColumnEncryptionAzureKeyVaultProvider( - SQLServerKeyVaultAuthenticationCallback authenticationCallback, - ExecutorService executorService) throws SQLServerException { - this(authenticationCallback); - } - - /** - * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a callback function to authenticate to AAD. This - * is used by KeyVaultClient at runtime to authenticate to Azure Key Vault. - * - * @param authenticationCallback - * - Callback function used for authenticating to AAD. - * @throws SQLServerException - * when an error occurs - */ - public SQLServerColumnEncryptionAzureKeyVaultProvider( - SQLServerKeyVaultAuthenticationCallback authenticationCallback) throws SQLServerException { - if (null == authenticationCallback) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue")); - Object[] msgArgs1 = {"SQLServerKeyVaultAuthenticationCallback"}; - throw new SQLServerException(form.format(msgArgs1), null); + private final static java.util.logging.Logger akvLogger = java.util.logging.Logger + .getLogger("com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider"); + private HttpPipeline keyVaultPipeline; + private KeyVaultCredential keyVaultCredential; + /** + * Column Encryption Key Store Provider string + */ + String name = "AZURE_KEY_VAULT"; + + private static final String MSSQL_JDBC_PROPERTIES = "mssql-jdbc.properties"; + private static final String AKV_TRUSTED_ENDPOINTS_KEYWORD = "AKVTrustedEndpoints"; + private static final String RSA_ENCRYPTION_ALGORITHM_WITH_OAEP_FOR_AKV = "RSA-OAEP"; + + private static final List akvTrustedEndpoints; + /** + * Algorithm version + */ + private final byte[] firstVersion = new byte[] {0x01}; + + private Map cachedKeyClients = new ConcurrentHashMap<>(); + private Map cachedCryptographyClients = new ConcurrentHashMap<>(); + private TokenCredential credential; + + static { + akvTrustedEndpoints = getTrustedEndpoints(); } - credentials = new KeyVaultCredential(authenticationCallback); - RestClient restClient = new RestClient.Builder(new OkHttpClient.Builder(), new Retrofit.Builder()) - .withBaseUrl(baseUrl).withCredentials(credentials).withSerializerAdapter(new AzureJacksonAdapter()) - .withResponseBuilderFactory(new AzureResponseBuilder.Factory()).build(); - keyVaultClient = new KeyVaultClient(restClient); - } - - /** - * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by - * KeyVaultClient at runtime to authenticate to Azure Key Vault. - * - * @throws SQLServerException - * when an error occurs - */ - SQLServerColumnEncryptionAzureKeyVaultProvider() throws SQLServerException { - credentials = new KeyVaultCredential(); - keyVaultClient = new KeyVaultClient(credentials); - } - - /** - * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by - * KeyVaultClient at runtime to authenticate to Azure Key Vault. - * - * @param clientId - * Identifier of the client requesting the token. - * - * @throws SQLServerException - * when an error occurs - */ - SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId) throws SQLServerException { - credentials = new KeyVaultCredential(clientId); - keyVaultClient = new KeyVaultClient(credentials); - } - - /** - * Decrypts an encrypted CEK with RSA encryption algorithm using the asymmetric key specified by the key path - * - * @param masterKeyPath - * - Complete path of an asymmetric key in AKV - * @param encryptionAlgorithm - * - Asymmetric Key Encryption Algorithm - * @param encryptedColumnEncryptionKey - * - Encrypted Column Encryption Key - * @return Plain text column encryption key - */ - @Override - public byte[] decryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm, - byte[] encryptedColumnEncryptionKey) throws SQLServerException { - - // Validate the input parameters - this.ValidateNonEmptyAKVPath(masterKeyPath); - - if (null == encryptedColumnEncryptionKey) { - throw new SQLServerException(SQLServerException.getErrString("R_NullEncryptedColumnEncryptionKey"), null); + + public void setName(String name) { + this.name = name; } - if (0 == encryptedColumnEncryptionKey.length) { - throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedColumnEncryptionKey"), null); + public String getName() { + return this.name; } - // Validate encryptionAlgorithm - encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm); - - // Validate whether the key is RSA one or not and then get the key size - int keySizeInBytes = getAKVKeySize(masterKeyPath); - - // Validate and decrypt the EncryptedColumnEncryptionKey - // Format is - // version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature - // - // keyPath is present in the encrypted column encryption key for identifying the original source of the - // asymmetric key pair and - // we will not validate it against the data contained in the CMK metadata (masterKeyPath). - - // Validate the version byte - if (encryptedColumnEncryptionKey[0] != firstVersion[0]) { - MessageFormat form = new MessageFormat( - SQLServerException.getErrString("R_InvalidEcryptionAlgorithmVersion")); - Object[] msgArgs = {String.format("%02X ", encryptedColumnEncryptionKey[0]), - String.format("%02X ", firstVersion[0])}; - throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + public SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId, String clientKey) + throws SQLServerException { + if (clientId == null || clientId.isEmpty()) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue")); + Object[] msgArgs1 = {"Client ID"}; + throw new SQLServerException(form.format(msgArgs1), null); + } + if (clientKey == null || clientKey.isEmpty()) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue")); + Object[] msgArgs1 = {"Client Key"}; + throw new SQLServerException(form.format(msgArgs1), null); + } + + keyVaultCredential = new KeyVaultCredential(clientId, clientKey); + keyVaultPipeline = new KeyVaultHttpPipelineBuilder().credential(keyVaultCredential) + .buildPipeline(); } - // Get key path length - int currentIndex = firstVersion.length; - short keyPathLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex); - // We just read 2 bytes - currentIndex += 2; - - // Get ciphertext length - short cipherTextLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex); - currentIndex += 2; - - // Skip KeyPath - // KeyPath exists only for troubleshooting purposes and doesnt need validation. - currentIndex += keyPathLength; - - // validate the ciphertext length - if (cipherTextLength != keySizeInBytes) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyLengthError")); - Object[] msgArgs = {cipherTextLength, keySizeInBytes, masterKeyPath}; - throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + /** + * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by + * KeyVaultClient at runtime to authenticate to Azure Key Vault. + */ + SQLServerColumnEncryptionAzureKeyVaultProvider() throws SQLServerException { + createKeyvaultClients(new ManagedIdentityCredentialBuilder().build()); } - // Validate the signature length - int signatureLength = encryptedColumnEncryptionKey.length - currentIndex - cipherTextLength; + /** + * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by + * KeyVaultClient at runtime to authenticate to Azure Key Vault. + * + * @param clientId Identifier of the client requesting the token. + */ + SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId) throws SQLServerException { + if (clientId == null || clientId.isEmpty()) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue")); + Object[] msgArgs1 = {"Client ID"}; + throw new SQLServerException(form.format(msgArgs1), null); + } + createKeyvaultClients(new ManagedIdentityCredentialBuilder().clientId(clientId).build()); + } - if (signatureLength != keySizeInBytes) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVSignatureLengthError")); - Object[] msgArgs = {signatureLength, keySizeInBytes, masterKeyPath}; - throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + /** + * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider using the provided TokenCredential to + * authenticate to AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault. + * + * @param tokenCredential The TokenCredential to use to authenticate to Azure Key Vault. + */ + public SQLServerColumnEncryptionAzureKeyVaultProvider(TokenCredential tokenCredential) + throws SQLServerException { + createKeyvaultClients(tokenCredential); } - // Get ciphertext - byte[] cipherText = new byte[cipherTextLength]; - System.arraycopy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength); - currentIndex += cipherTextLength; + private void createKeyvaultClients(TokenCredential credential) throws SQLServerException { + this.credential = Objects.requireNonNull(credential); + } - // Get signature - byte[] signature = new byte[signatureLength]; - System.arraycopy(encryptedColumnEncryptionKey, currentIndex, signature, 0, signatureLength); + /** + * Decrypts an encrypted CEK with RSA encryption algorithm using the asymmetric key specified by the key path + * + * @param masterKeyPath - Complete path of an asymmetric key in AKV + * @param encryptionAlgorithm - Asymmetric Key Encryption Algorithm + * @param encryptedColumnEncryptionKey - Encrypted Column Encryption Key + * @return Plain text column encryption key + */ + @Override public byte[] decryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm, + byte[] encryptedColumnEncryptionKey) throws SQLServerException { - // Compute the hash to validate the signature - byte[] hash = new byte[encryptedColumnEncryptionKey.length - signature.length]; + // Validate the input parameters + this.ValidateNonEmptyAKVPath(masterKeyPath); - System.arraycopy(encryptedColumnEncryptionKey, 0, hash, 0, - encryptedColumnEncryptionKey.length - signature.length); + if (null == encryptedColumnEncryptionKey) { + throw new SQLServerException( + SQLServerException.getErrString("R_NullEncryptedColumnEncryptionKey"), null); + } - MessageDigest md = null; - try { - md = MessageDigest.getInstance("SHA-256"); - } catch (NoSuchAlgorithmException e) { - throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e); - } - md.update(hash); - byte dataToVerify[] = md.digest(); + if (0 == encryptedColumnEncryptionKey.length) { + throw new SQLServerException( + SQLServerException.getErrString("R_EmptyEncryptedColumnEncryptionKey"), null); + } - if (null == dataToVerify) { - throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null); - } + // Validate encryptionAlgorithm + KeyWrapAlgorithm _encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm); + + // Validate whether the key is RSA one or not and then get the key size + int keySizeInBytes = getAKVKeySize(masterKeyPath); + + // Validate and decrypt the EncryptedColumnEncryptionKey + // Format is + // version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature + // + // keyPath is present in the encrypted column encryption key for identifying the original source of the + // asymmetric key pair and + // we will not validate it against the data contained in the CMK metadata (masterKeyPath). + + // Validate the version byte + if (encryptedColumnEncryptionKey[0] != firstVersion[0]) { + MessageFormat form = new MessageFormat( + SQLServerException.getErrString("R_InvalidEcryptionAlgorithmVersion")); + Object[] msgArgs = {String.format("%02X ", encryptedColumnEncryptionKey[0]), + String.format("%02X ", firstVersion[0])}; + throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + } - // Validate the signature - if (!AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath)) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_CEKSignatureNotMatchCMK")); - Object[] msgArgs = {masterKeyPath}; - throw new SQLServerException(this, form.format(msgArgs), null, 0, false); - } + // Get key path length + int currentIndex = firstVersion.length; + short keyPathLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex); + // We just read 2 bytes + currentIndex += 2; + + // Get ciphertext length + short cipherTextLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex); + currentIndex += 2; + + // Skip KeyPath + // KeyPath exists only for troubleshooting purposes and doesnt need validation. + currentIndex += keyPathLength; + + // validate the ciphertext length + if (cipherTextLength != keySizeInBytes) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyLengthError")); + Object[] msgArgs = {cipherTextLength, keySizeInBytes, masterKeyPath}; + throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + } - // Decrypt the CEK - byte[] decryptedCEK = this.AzureKeyVaultUnWrap(masterKeyPath, encryptionAlgorithm, cipherText); + // Validate the signature length + int signatureLength = encryptedColumnEncryptionKey.length - currentIndex - cipherTextLength; - return decryptedCEK; - } + if (signatureLength != keySizeInBytes) { + MessageFormat form = new MessageFormat( + SQLServerException.getErrString("R_AKVSignatureLengthError")); + Object[] msgArgs = {signatureLength, keySizeInBytes, masterKeyPath}; + throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + } - private short convertTwoBytesToShort(byte[] input, int index) throws SQLServerException { + // Get ciphertext + byte[] cipherText = new byte[cipherTextLength]; + System.arraycopy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength); + currentIndex += cipherTextLength; - short shortVal; - if (index + 1 >= input.length) { - throw new SQLServerException(null, SQLServerException.getErrString("R_ByteToShortConversion"), null, 0, - false); - } - ByteBuffer byteBuffer = ByteBuffer.allocate(2); - byteBuffer.order(ByteOrder.LITTLE_ENDIAN); - byteBuffer.put(input[index]); - byteBuffer.put(input[index + 1]); - shortVal = byteBuffer.getShort(0); - return shortVal; - - } - - /** - * Encrypts CEK with RSA encryption algorithm using the asymmetric key specified by the key path. - * - * @param masterKeyPath - * - Complete path of an asymmetric key in AKV - * @param encryptionAlgorithm - * - Asymmetric Key Encryption Algorithm - * @param columnEncryptionKey - * - Plain text column encryption key - * @return Encrypted column encryption key - */ - @Override - public byte[] encryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm, - byte[] columnEncryptionKey) throws SQLServerException { - - // Validate the input parameters - this.ValidateNonEmptyAKVPath(masterKeyPath); - - if (null == columnEncryptionKey) { - throw new SQLServerException(SQLServerException.getErrString("R_NullColumnEncryptionKey"), null); + // Get signature + byte[] signature = new byte[signatureLength]; + System.arraycopy(encryptedColumnEncryptionKey, currentIndex, signature, 0, signatureLength); + + // Compute the hash to validate the signature + byte[] hash = new byte[encryptedColumnEncryptionKey.length - signature.length]; + + System.arraycopy(encryptedColumnEncryptionKey, 0, hash, 0, + encryptedColumnEncryptionKey.length - signature.length); + + MessageDigest md = null; + try { + md = MessageDigest.getInstance("SHA-256"); + } catch (NoSuchAlgorithmException e) { + throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e); + } + md.update(hash); + byte dataToVerify[] = md.digest(); + + if (null == dataToVerify) { + throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null); + } + + // Validate the signature + if (!AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath)) { + MessageFormat form = new MessageFormat( + SQLServerException.getErrString("R_CEKSignatureNotMatchCMK")); + Object[] msgArgs = {masterKeyPath}; + throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + } + + // Decrypt the CEK + byte[] decryptedCEK = this.AzureKeyVaultUnWrap(masterKeyPath, _encryptionAlgorithm, cipherText); + + return decryptedCEK; } - if (0 == columnEncryptionKey.length) { - throw new SQLServerException(SQLServerException.getErrString("R_EmptyCEK"), null); + private short convertTwoBytesToShort(byte[] input, int index) throws SQLServerException { + + short shortVal; + if (index + 1 >= input.length) { + throw new SQLServerException(null, SQLServerException.getErrString("R_ByteToShortConversion"), + null, 0, false); + } + ByteBuffer byteBuffer = ByteBuffer.allocate(2); + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + byteBuffer.put(input[index]); + byteBuffer.put(input[index + 1]); + shortVal = byteBuffer.getShort(0); + return shortVal; + } - // Validate encryptionAlgorithm - encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm); + /** + * Encrypts CEK with RSA encryption algorithm using the asymmetric key specified by the key path. + * + * @param masterKeyPath - Complete path of an asymmetric key in AKV + * @param encryptionAlgorithm - Asymmetric Key Encryption Algorithm + * @param columnEncryptionKey - Plain text column encryption key + * @return Encrypted column encryption key + */ + @Override public byte[] encryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm, + byte[] columnEncryptionKey) throws SQLServerException { + + // Validate the input parameters + this.ValidateNonEmptyAKVPath(masterKeyPath); - // Validate whether the key is RSA one or not and then get the key size - int keySizeInBytes = getAKVKeySize(masterKeyPath); + if (null == columnEncryptionKey) { + throw new SQLServerException(SQLServerException.getErrString("R_NullColumnEncryptionKey"), + null); + } - // Construct the encryptedColumnEncryptionKey - // Format is - // version + keyPathLength + ciphertextLength + ciphertext + keyPath + signature - // - // We currently only support one version - byte[] version = new byte[] {firstVersion[0]}; + if (0 == columnEncryptionKey.length) { + throw new SQLServerException(SQLServerException.getErrString("R_EmptyCEK"), null); + } - // Get the Unicode encoded bytes of cultureinvariant lower case masterKeyPath - byte[] masterKeyPathBytes = masterKeyPath.toLowerCase(Locale.ENGLISH).getBytes(UTF_16LE); + // Validate encryptionAlgorithm + KeyWrapAlgorithm _encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm); - byte[] keyPathLength = new byte[2]; - keyPathLength[0] = (byte) (((short) masterKeyPathBytes.length) & 0xff); - keyPathLength[1] = (byte) (((short) masterKeyPathBytes.length) >> 8 & 0xff); + // Validate whether the key is RSA one or not and then get the key size + int keySizeInBytes = getAKVKeySize(masterKeyPath); - // Encrypt the plain text - byte[] cipherText = this.AzureKeyVaultWrap(masterKeyPath, encryptionAlgorithm, columnEncryptionKey); + // Construct the encryptedColumnEncryptionKey + // Format is + // version + keyPathLength + ciphertextLength + ciphertext + keyPath + signature + // + // We currently only support one version + byte[] version = new byte[] {firstVersion[0]}; - byte[] cipherTextLength = new byte[2]; - cipherTextLength[0] = (byte) (((short) cipherText.length) & 0xff); - cipherTextLength[1] = (byte) (((short) cipherText.length) >> 8 & 0xff); + // Get the Unicode encoded bytes of cultureinvariant lower case masterKeyPath + byte[] masterKeyPathBytes = masterKeyPath.toLowerCase(Locale.ENGLISH).getBytes(UTF_16LE); - if (cipherText.length != keySizeInBytes) { - throw new SQLServerException(SQLServerException.getErrString("R_CipherTextLengthNotMatchRSASize"), null); - } + byte[] keyPathLength = new byte[2]; + keyPathLength[0] = (byte) (((short) masterKeyPathBytes.length) & 0xff); + keyPathLength[1] = (byte) (((short) masterKeyPathBytes.length) >> 8 & 0xff); - // Compute hash - // SHA-2-256(version + keyPathLength + ciphertextLength + keyPath + ciphertext) - byte[] dataToHash = new byte[version.length + keyPathLength.length + cipherTextLength.length - + masterKeyPathBytes.length + cipherText.length]; - int destinationPosition = version.length; - System.arraycopy(version, 0, dataToHash, 0, version.length); + // Encrypt the plain text + byte[] cipherText = this.AzureKeyVaultWrap(masterKeyPath, _encryptionAlgorithm, columnEncryptionKey); - System.arraycopy(keyPathLength, 0, dataToHash, destinationPosition, keyPathLength.length); - destinationPosition += keyPathLength.length; + byte[] cipherTextLength = new byte[2]; + cipherTextLength[0] = (byte) (((short) cipherText.length) & 0xff); + cipherTextLength[1] = (byte) (((short) cipherText.length) >> 8 & 0xff); - System.arraycopy(cipherTextLength, 0, dataToHash, destinationPosition, cipherTextLength.length); - destinationPosition += cipherTextLength.length; + if (cipherText.length != keySizeInBytes) { + throw new SQLServerException( + SQLServerException.getErrString("R_CipherTextLengthNotMatchRSASize"), null); + } - System.arraycopy(masterKeyPathBytes, 0, dataToHash, destinationPosition, masterKeyPathBytes.length); - destinationPosition += masterKeyPathBytes.length; + // Compute hash + // SHA-2-256(version + keyPathLength + ciphertextLength + keyPath + ciphertext) + byte[] dataToHash = new byte[version.length + keyPathLength.length + cipherTextLength.length + + masterKeyPathBytes.length + cipherText.length]; + int destinationPosition = version.length; + System.arraycopy(version, 0, dataToHash, 0, version.length); - System.arraycopy(cipherText, 0, dataToHash, destinationPosition, cipherText.length); + System.arraycopy(keyPathLength, 0, dataToHash, destinationPosition, keyPathLength.length); + destinationPosition += keyPathLength.length; - MessageDigest md = null; - try { - md = MessageDigest.getInstance("SHA-256"); - } catch (NoSuchAlgorithmException e) { - throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e); - } - md.update(dataToHash); - byte dataToSign[] = md.digest(); + System.arraycopy(cipherTextLength, 0, dataToHash, destinationPosition, cipherTextLength.length); + destinationPosition += cipherTextLength.length; - // Sign the hash - byte[] signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath); + System.arraycopy(masterKeyPathBytes, 0, dataToHash, destinationPosition, masterKeyPathBytes.length); + destinationPosition += masterKeyPathBytes.length; - if (signedHash.length != keySizeInBytes) { - throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null); - } + System.arraycopy(cipherText, 0, dataToHash, destinationPosition, cipherText.length); - if (!this.AzureKeyVaultVerifySignature(dataToSign, signedHash, masterKeyPath)) { - throw new SQLServerException(SQLServerException.getErrString("R_InvalidSignatureComputed"), null); - } + MessageDigest md = null; + try { + md = MessageDigest.getInstance("SHA-256"); + } catch (NoSuchAlgorithmException e) { + throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e); + } + md.update(dataToHash); + byte dataToSign[] = md.digest(); + + // Sign the hash + byte[] signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath); + + if (signedHash.length != keySizeInBytes) { + throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null); + } + + if (!this.AzureKeyVaultVerifySignature(dataToSign, signedHash, masterKeyPath)) { + throw new SQLServerException(SQLServerException.getErrString("R_InvalidSignatureComputed"), + null); + } + + // Construct the encrypted column encryption key + // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature + int encryptedColumnEncryptionKeyLength = + version.length + cipherTextLength.length + keyPathLength.length + cipherText.length + + masterKeyPathBytes.length + signedHash.length; + byte[] encryptedColumnEncryptionKey = new byte[encryptedColumnEncryptionKeyLength]; - // Construct the encrypted column encryption key - // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature - int encryptedColumnEncryptionKeyLength = version.length + cipherTextLength.length + keyPathLength.length - + cipherText.length + masterKeyPathBytes.length + signedHash.length; - byte[] encryptedColumnEncryptionKey = new byte[encryptedColumnEncryptionKeyLength]; - - // Copy version byte - int currentIndex = 0; - System.arraycopy(version, 0, encryptedColumnEncryptionKey, currentIndex, version.length); - currentIndex += version.length; - - // Copy key path length - System.arraycopy(keyPathLength, 0, encryptedColumnEncryptionKey, currentIndex, keyPathLength.length); - currentIndex += keyPathLength.length; - - // Copy ciphertext length - System.arraycopy(cipherTextLength, 0, encryptedColumnEncryptionKey, currentIndex, cipherTextLength.length); - currentIndex += cipherTextLength.length; - - // Copy key path - System.arraycopy(masterKeyPathBytes, 0, encryptedColumnEncryptionKey, currentIndex, masterKeyPathBytes.length); - currentIndex += masterKeyPathBytes.length; - - // Copy ciphertext - System.arraycopy(cipherText, 0, encryptedColumnEncryptionKey, currentIndex, cipherText.length); - currentIndex += cipherText.length; - - // copy the signature - System.arraycopy(signedHash, 0, encryptedColumnEncryptionKey, currentIndex, signedHash.length); - - return encryptedColumnEncryptionKey; - } - - /** - * Validates that the encryption algorithm is RSA_OAEP and if it is not, then throws an exception. - * - * @param encryptionAlgorithm - * - Asymmetric key encryptio algorithm - * @return The encryption algorithm that is going to be used. - * @throws SQLServerException - */ - private String validateEncryptionAlgorithm(String encryptionAlgorithm) throws SQLServerException { - - if (null == encryptionAlgorithm) { - throw new SQLServerException(null, SQLServerException.getErrString("R_NullKeyEncryptionAlgorithm"), null, 0, - false); + // Copy version byte + int currentIndex = 0; + System.arraycopy(version, 0, encryptedColumnEncryptionKey, currentIndex, version.length); + currentIndex += version.length; + + // Copy key path length + System.arraycopy(keyPathLength, 0, encryptedColumnEncryptionKey, currentIndex, keyPathLength.length); + currentIndex += keyPathLength.length; + + // Copy ciphertext length + System.arraycopy(cipherTextLength, 0, encryptedColumnEncryptionKey, currentIndex, + cipherTextLength.length); + currentIndex += cipherTextLength.length; + + // Copy key path + System.arraycopy(masterKeyPathBytes, 0, encryptedColumnEncryptionKey, currentIndex, + masterKeyPathBytes.length); + currentIndex += masterKeyPathBytes.length; + + // Copy ciphertext + System.arraycopy(cipherText, 0, encryptedColumnEncryptionKey, currentIndex, cipherText.length); + currentIndex += cipherText.length; + + // copy the signature + System.arraycopy(signedHash, 0, encryptedColumnEncryptionKey, currentIndex, signedHash.length); + + return encryptedColumnEncryptionKey; } - // Transform to standard format (dash instead of underscore) to support both "RSA_OAEP" and "RSA-OAEP" - if ("RSA_OAEP".equalsIgnoreCase(encryptionAlgorithm)) { - encryptionAlgorithm = "RSA-OAEP"; + /** + * Validates that the encryption algorithm is RSA_OAEP and if it is not, then throws an exception. + * + * @param encryptionAlgorithm - Asymmetric key encryptio algorithm + * @return The encryption algorithm that is going to be used. + * @throws SQLServerException + */ + private KeyWrapAlgorithm validateEncryptionAlgorithm(String encryptionAlgorithm) throws SQLServerException { + + if (null == encryptionAlgorithm) { + throw new SQLServerException(null, + SQLServerException.getErrString("R_NullKeyEncryptionAlgorithm"), null, 0, false); + } + + // Transform to standard format (dash instead of underscore) to support enum lookup + if ("RSA_OAEP".equalsIgnoreCase(encryptionAlgorithm)) { + encryptionAlgorithm = RSA_ENCRYPTION_ALGORITHM_WITH_OAEP_FOR_AKV; + } + + if (!RSA_ENCRYPTION_ALGORITHM_WITH_OAEP_FOR_AKV.equalsIgnoreCase(encryptionAlgorithm.trim())) { + MessageFormat form = new MessageFormat( + SQLServerException.getErrString("R_InvalidKeyEncryptionAlgorithm")); + Object[] msgArgs = {encryptionAlgorithm, RSA_ENCRYPTION_ALGORITHM_WITH_OAEP_FOR_AKV}; + throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + } + + return KeyWrapAlgorithm.fromString(encryptionAlgorithm); } - if (!rsaEncryptionAlgorithmWithOAEPForAKV.equalsIgnoreCase(encryptionAlgorithm.trim())) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_InvalidKeyEncryptionAlgorithm")); - Object[] msgArgs = {encryptionAlgorithm, rsaEncryptionAlgorithmWithOAEPForAKV}; - throw new SQLServerException(this, form.format(msgArgs), null, 0, false); + /** + * Checks if the Azure Key Vault key path is Empty or Null (and raises exception if they are). + * + * @param masterKeyPath + * @throws SQLServerException + */ + private void ValidateNonEmptyAKVPath(String masterKeyPath) throws SQLServerException { + // throw appropriate error if masterKeyPath is null or empty + if (null == masterKeyPath || masterKeyPath.trim().isEmpty()) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVPathNull")); + Object[] msgArgs = {masterKeyPath}; + throw new SQLServerException(null, form.format(msgArgs), null, 0, false); + } else { + URI parsedUri = null; + try { + parsedUri = new URI(masterKeyPath); + + // A valid URI. + // Check if it is pointing to a trusted endpoint. + String host = parsedUri.getHost(); + if (null != host) { + host = host.toLowerCase(Locale.ENGLISH); + } + for (final String endpoint : akvTrustedEndpoints) { + if (null != host && host.endsWith(endpoint)) { + return; + } + } + } catch (URISyntaxException e) { + MessageFormat form = new MessageFormat( + SQLServerException.getErrString("R_AKVURLInvalid")); + Object[] msgArgs = {masterKeyPath}; + throw new SQLServerException(form.format(msgArgs), null, 0, e); + } + + MessageFormat form = new MessageFormat( + SQLServerException.getErrString("R_AKVMasterKeyPathInvalid")); + Object[] msgArgs = {masterKeyPath}; + throw new SQLServerException(null, form.format(msgArgs), null, 0, false); + } } - return encryptionAlgorithm; - } - - /** - * Checks if the Azure Key Vault key path is Empty or Null (and raises exception if they are). - * - * @param masterKeyPath - * @throws SQLServerException - */ - private void ValidateNonEmptyAKVPath(String masterKeyPath) throws SQLServerException { - // throw appropriate error if masterKeyPath is null or empty - if (null == masterKeyPath || masterKeyPath.trim().isEmpty()) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVPathNull")); - Object[] msgArgs = {masterKeyPath}; - throw new SQLServerException(null, form.format(msgArgs), null, 0, false); - } else { - URI parsedUri = null; - try { - parsedUri = new URI(masterKeyPath); - - // A valid URI. - // Check if it is pointing to a trusted endpoint. - String host = parsedUri.getHost(); - if (null != host) { - host = host.toLowerCase(Locale.ENGLISH); - } - for (final String endpoint : akvTrustedEndpoints) { - if (null != host && host.endsWith(endpoint)) { - return; - } - } - } catch (URISyntaxException e) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVURLInvalid")); - Object[] msgArgs = {masterKeyPath}; - throw new SQLServerException(form.format(msgArgs), null, 0, e); - } + /** + * Encrypts the text using specified Azure Key Vault key. + * + * @param masterKeyPath - Azure Key Vault key url. + * @param encryptionAlgorithm - Encryption Algorithm. + * @param columnEncryptionKey - Plain text Column Encryption Key. + * @return Returns an encrypted blob or throws an exception if there are any errors. + * @throws SQLServerException + */ + private byte[] AzureKeyVaultWrap(String masterKeyPath, KeyWrapAlgorithm encryptionAlgorithm, + byte[] columnEncryptionKey) throws SQLServerException { + if (null == columnEncryptionKey) { + throw new SQLServerException(SQLServerException.getErrString("R_CEKNull"), null); + } - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVMasterKeyPathInvalid")); - Object[] msgArgs = {masterKeyPath}; - throw new SQLServerException(null, form.format(msgArgs), null, 0, false); + CryptographyClient cryptoClient = getCryptographyClient(masterKeyPath); + WrapResult wrappedKey = cryptoClient.wrapKey(KeyWrapAlgorithm.RSA_OAEP, columnEncryptionKey); + return wrappedKey.getEncryptedKey(); } - } - - /** - * Encrypts the text using specified Azure Key Vault key. - * - * @param masterKeyPath - * - Azure Key Vault key url. - * @param encryptionAlgorithm - * - Encryption Algorithm. - * @param columnEncryptionKey - * - Plain text Column Encryption Key. - * @return Returns an encrypted blob or throws an exception if there are any errors. - * @throws SQLServerException - */ - private byte[] AzureKeyVaultWrap(String masterKeyPath, String encryptionAlgorithm, - byte[] columnEncryptionKey) throws SQLServerException { - if (null == columnEncryptionKey) { - throw new SQLServerException(SQLServerException.getErrString("R_CEKNull"), null); + + /** + * Encrypts the text using specified Azure Key Vault key. + * + * @param masterKeyPath - Azure Key Vault key url. + * @param encryptionAlgorithm - Encrypted Column Encryption Key. + * @param encryptedColumnEncryptionKey - Encrypted Column Encryption Key. + * @return Returns the decrypted plaintext Column Encryption Key or throws an exception if there are any errors. + * @throws SQLServerException + */ + private byte[] AzureKeyVaultUnWrap(String masterKeyPath, KeyWrapAlgorithm encryptionAlgorithm, + byte[] encryptedColumnEncryptionKey) throws SQLServerException { + if (null == encryptedColumnEncryptionKey) { + throw new SQLServerException(SQLServerException.getErrString("R_EncryptedCEKNull"), null); + } + + if (0 == encryptedColumnEncryptionKey.length) { + throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedCEK"), null); + } + + CryptographyClient cryptoClient = getCryptographyClient(masterKeyPath); + + UnwrapResult unwrappedKey = cryptoClient.unwrapKey(encryptionAlgorithm, encryptedColumnEncryptionKey); + + return unwrappedKey.getKey(); } - JsonWebKeyEncryptionAlgorithm jsonEncryptionAlgorithm = new JsonWebKeyEncryptionAlgorithm(encryptionAlgorithm); - KeyOperationResult wrappedKey = keyVaultClient.wrapKey(masterKeyPath, jsonEncryptionAlgorithm, - columnEncryptionKey); - - return wrappedKey.result(); - } - - /** - * Encrypts the text using specified Azure Key Vault key. - * - * @param masterKeyPath - * - Azure Key Vault key url. - * @param encryptionAlgorithm - * - Encrypted Column Encryption Key. - * @param encryptedColumnEncryptionKey - * - Encrypted Column Encryption Key. - * @return Returns the decrypted plaintext Column Encryption Key or throws an exception if there are any errors. - * @throws SQLServerException - */ - private byte[] AzureKeyVaultUnWrap(String masterKeyPath, String encryptionAlgorithm, - byte[] encryptedColumnEncryptionKey) throws SQLServerException { - if (null == encryptedColumnEncryptionKey) { - throw new SQLServerException(SQLServerException.getErrString("R_EncryptedCEKNull"), null); + private CryptographyClient getCryptographyClient(String masterKeyPath) throws SQLServerException { + if (this.cachedCryptographyClients.containsKey(masterKeyPath)) { + return cachedCryptographyClients.get(masterKeyPath); + } + + KeyVaultKey retrievedKey = getKeyVaultKey(masterKeyPath); + + CryptographyClient cryptoClient; + if (credential != null) { + cryptoClient = new CryptographyClientBuilder().credential(credential) + .keyIdentifier(retrievedKey.getId()).buildClient(); + } else { + cryptoClient = new CryptographyClientBuilder().pipeline(keyVaultPipeline) + .keyIdentifier(retrievedKey.getId()).buildClient(); + } + cachedCryptographyClients.putIfAbsent(masterKeyPath, cryptoClient); + return cachedCryptographyClients.get(masterKeyPath); } - if (0 == encryptedColumnEncryptionKey.length) { - throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedCEK"), null); + /** + * Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL. + * + * @param dataToSign - Text to sign. + * @param masterKeyPath - Azure Key Vault key url. + * @return Signature + * @throws SQLServerException + */ + private byte[] AzureKeyVaultSignHashedData(byte[] dataToSign, String masterKeyPath) throws SQLServerException { + assert ((null != dataToSign) && (0 != dataToSign.length)); + + CryptographyClient cryptoClient = getCryptographyClient(masterKeyPath); + SignResult signedData = cryptoClient.sign(SignatureAlgorithm.RS256, dataToSign); + return signedData.getSignature(); } - JsonWebKeyEncryptionAlgorithm jsonEncryptionAlgorithm = new JsonWebKeyEncryptionAlgorithm(encryptionAlgorithm); - KeyOperationResult unwrappedKey = keyVaultClient.unwrapKey(masterKeyPath, jsonEncryptionAlgorithm, - encryptedColumnEncryptionKey); - - return unwrappedKey.result(); - } - - /** - * Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL. - * - * @param dataToSign - * - Text to sign. - * @param masterKeyPath - * - Azure Key Vault key url. - * @return Signature - * @throws SQLServerException - */ - private byte[] AzureKeyVaultSignHashedData(byte[] dataToSign, String masterKeyPath) throws SQLServerException { - assert ((null != dataToSign) && (0 != dataToSign.length)); - - KeyOperationResult signedData = keyVaultClient.sign(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, - dataToSign); - - return signedData.result(); - } - - /** - * Verifies the given RSA PKCSv1.5 signature. - * - * @param dataToVerify - * @param signature - * @param masterKeyPath - * - Azure Key Vault key url. - * @return true if signature is valid, false if it is not valid - * @throws SQLServerException - */ - private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify, byte[] signature, - String masterKeyPath) throws SQLServerException { - assert ((null != dataToVerify) && (0 != dataToVerify.length)); - assert ((null != signature) && (0 != signature.length)); - - KeyVerifyResult valid = keyVaultClient.verify(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToVerify, - signature); - - return valid.value(); - } - - /** - * Returns the public Key size in bytes. - * - * @param masterKeyPath - * - Azure Key Vault Key path - * @return Key size in bytes - * @throws SQLServerException - * when an error occurs - */ - private int getAKVKeySize(String masterKeyPath) throws SQLServerException { - KeyBundle retrievedKey = keyVaultClient.getKey(masterKeyPath); - - if (null == retrievedKey) { - String[] keyTokens = masterKeyPath.split("/"); - - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyNotFound")); - Object[] msgArgs = {keyTokens[keyTokens.length - 1]}; - throw new SQLServerException(null, form.format(msgArgs), null, 0, false); + /** + * Verifies the given RSA PKCSv1.5 signature. + * + * @param dataToVerify + * @param signature + * @param masterKeyPath - Azure Key Vault key url. + * @return true if signature is valid, false if it is not valid + * @throws SQLServerException + */ + private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify, byte[] signature, String masterKeyPath) + throws SQLServerException { + assert ((null != dataToVerify) && (0 != dataToVerify.length)); + assert ((null != signature) && (0 != signature.length)); + + CryptographyClient cryptoClient = getCryptographyClient(masterKeyPath); + VerifyResult valid = cryptoClient.verify(SignatureAlgorithm.RS256, dataToVerify, signature); + + return valid.isValid(); } - if (!"RSA".equalsIgnoreCase(retrievedKey.key().kty().toString()) - && !"RSA-HSM".equalsIgnoreCase(retrievedKey.key().kty().toString())) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NonRSAKey")); - Object[] msgArgs = {retrievedKey.key().kty().toString()}; - throw new SQLServerException(null, form.format(msgArgs), null, 0, false); + /** + * Returns the public Key size in bytes. + * + * @param masterKeyPath - Azure Key Vault Key path + * @return Key size in bytes + * @throws SQLServerException when an error occurs + */ + private int getAKVKeySize(String masterKeyPath) throws SQLServerException { + KeyVaultKey retrievedKey = getKeyVaultKey(masterKeyPath); + return retrievedKey.getKey().getN().length; } - return retrievedKey.key().n().length; - } + private KeyVaultKey getKeyVaultKey(String masterKeyPath) + throws SQLServerException { + String[] keyTokens = masterKeyPath.split("/"); + String keyName = keyTokens[keyTokens.length - 2]; + String keyVersion = keyTokens[keyTokens.length - 1]; + KeyClient keyClient = getKeyClient(masterKeyPath); + KeyVaultKey retrievedKey = keyClient.getKey(keyName, keyVersion); + + if (null == retrievedKey) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyNotFound")); + Object[] msgArgs = {keyTokens[keyTokens.length - 1]}; + throw new SQLServerException(null, form.format(msgArgs), null, 0, false); + } - @Override - public boolean verifyColumnMasterKeyMetadata(String masterKeyPath, boolean allowEnclaveComputations, - byte[] signature) throws SQLServerException { - if (!allowEnclaveComputations) - return false; + if (retrievedKey.getKeyType() != KeyType.RSA && retrievedKey.getKeyType() != KeyType.RSA_HSM) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NonRSAKey")); + Object[] msgArgs = {retrievedKey.getKeyType().toString()}; + throw new SQLServerException(null, form.format(msgArgs), null, 0, false); + } + return retrievedKey; + } - KeyStoreProviderCommon.validateNonEmptyMasterKeyPath(masterKeyPath); + private KeyClient getKeyClient(String masterKeyPath) { + if (cachedKeyClients.containsKey(masterKeyPath)) { + return cachedKeyClients.get(masterKeyPath); + } + String vaultUrl = getVaultUrl(masterKeyPath); - try { - MessageDigest md = MessageDigest.getInstance("SHA-256"); - md.update(name.toLowerCase().getBytes(java.nio.charset.StandardCharsets.UTF_16LE)); - md.update(masterKeyPath.toLowerCase().getBytes(java.nio.charset.StandardCharsets.UTF_16LE)); - // value of allowEnclaveComputations is always true here - md.update("true".getBytes(java.nio.charset.StandardCharsets.UTF_16LE)); + KeyClient keyClient; + if (credential != null) { + keyClient = new KeyClientBuilder().credential(credential).vaultUrl(vaultUrl).buildClient(); + } else { + keyClient = new KeyClientBuilder().pipeline(keyVaultPipeline).vaultUrl(vaultUrl).buildClient(); + } + cachedKeyClients.putIfAbsent(masterKeyPath, keyClient); + return cachedKeyClients.get(masterKeyPath); + } - byte[] dataToVerify = md.digest(); - if (null == dataToVerify) { - throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null); - } + private static String getVaultUrl(String masterKeyPath) { + String[] keyTokens = masterKeyPath.split("/"); + String hostName = keyTokens[2]; + return "https://" + hostName; + } - // Sign the hash - byte[] signedHash = AzureKeyVaultSignHashedData(dataToVerify, masterKeyPath); - if (null == signedHash) { - throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null); + @Override public boolean verifyColumnMasterKeyMetadata(String masterKeyPath, boolean allowEnclaveComputations, + byte[] signature) throws SQLServerException { + if (!allowEnclaveComputations) { + return false; } - // Validate the signature - return AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath); - } catch (NoSuchAlgorithmException e) { - throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e); + KeyStoreProviderCommon.validateNonEmptyMasterKeyPath(masterKeyPath); + + try { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + md.update(name.toLowerCase().getBytes(java.nio.charset.StandardCharsets.UTF_16LE)); + md.update(masterKeyPath.toLowerCase().getBytes(java.nio.charset.StandardCharsets.UTF_16LE)); + // value of allowEnclaveComputations is always true here + md.update("true".getBytes(java.nio.charset.StandardCharsets.UTF_16LE)); + + byte[] dataToVerify = md.digest(); + if (null == dataToVerify) { + throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null); + } + + // Sign the hash + byte[] signedHash = AzureKeyVaultSignHashedData(dataToVerify, masterKeyPath); + if (null == signedHash) { + throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), + null); + } + + // Validate the signature + return AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath); + } catch (NoSuchAlgorithmException e) { + throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e); + } } - } - - private static List getTrustedEndpoints() { - Properties mssqlJdbcProperties = getMssqlJdbcProperties(); - List trustedEndpoints = new ArrayList(); - boolean append = true; - if (null != mssqlJdbcProperties) { - String endpoints = mssqlJdbcProperties.getProperty(AKV_TRUSTED_ENDPOINTS_KEYWORD); - if (null != endpoints && !endpoints.trim().isEmpty()) { - endpoints = endpoints.trim(); - // Append if the list starts with a semicolon. - if (';' != endpoints.charAt(0)) { - append = false; - } else { - endpoints = endpoints.substring(1); + + private static List getTrustedEndpoints() { + Properties mssqlJdbcProperties = getMssqlJdbcProperties(); + List trustedEndpoints = new ArrayList(); + boolean append = true; + if (null != mssqlJdbcProperties) { + String endpoints = mssqlJdbcProperties.getProperty(AKV_TRUSTED_ENDPOINTS_KEYWORD); + if (null != endpoints && !endpoints.trim().isEmpty()) { + endpoints = endpoints.trim(); + // Append if the list starts with a semicolon. + if (';' != endpoints.charAt(0)) { + append = false; + } else { + endpoints = endpoints.substring(1); + } + String[] entries = endpoints.split(";"); + for (String entry : entries) { + if (null != entry && !entry.trim().isEmpty()) { + trustedEndpoints.add(entry.trim()); + } + } + } } - String[] entries = endpoints.split(";"); - for (String entry : entries) { - if (null != entry && !entry.trim().isEmpty()) { - trustedEndpoints.add(entry.trim()); - } + /* + * List of Azure trusted endpoints + * https://docs.microsoft.com/en-us/azure/key-vault/key-vault-secure-your-key-vault + */ + if (append) { + trustedEndpoints.add("vault.azure.net"); + trustedEndpoints.add("vault.azure.cn"); + trustedEndpoints.add("vault.usgovcloudapi.net"); + trustedEndpoints.add("vault.microsoftazure.de"); } - } + return trustedEndpoints; } - /* - * List of Azure trusted endpoints - * https://docs.microsoft.com/en-us/azure/key-vault/key-vault-secure-your-key-vault + + /** + * Attempt to read MSSQL_JDBC_PROPERTIES. + * + * @return corresponding Properties object or null if failed to read the file. */ - if (append) { - trustedEndpoints.add("vault.azure.net"); - trustedEndpoints.add("vault.azure.cn"); - trustedEndpoints.add("vault.usgovcloudapi.net"); - trustedEndpoints.add("vault.microsoftazure.de"); - } - return trustedEndpoints; - } - - /** - * Attempt to read MSSQL_JDBC_PROPERTIES. - * - * @return corresponding Properties object or null if failed to read the file. - */ - private static Properties getMssqlJdbcProperties() { - Properties props = null; - try (FileInputStream in = new FileInputStream(MSSQL_JDBC_PROPERTIES)) { - props = new Properties(); - props.load(in); - } catch (IOException e) { - if (akvLogger.isLoggable(Level.FINER)) { - akvLogger.finer("Unable to load the mssql-jdbc.properties file: " + e); - } + private static Properties getMssqlJdbcProperties() { + Properties props = null; + try (FileInputStream in = new FileInputStream(MSSQL_JDBC_PROPERTIES)) { + props = new Properties(); + props.load(in); + } catch (IOException e) { + if (akvLogger.isLoggable(Level.FINER)) { + akvLogger.finer("Unable to load the mssql-jdbc.properties file: " + e); + } + } + return (null != props && !props.isEmpty()) ? props : null; } - return (null != props && !props.isEmpty()) ? props : null; - } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 61a0e29ab..f236258bf 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -669,6 +669,8 @@ boolean getSendTemporalDataTypesAsStringForBulkCopy() { String keyStoreLocation = null; String keyStorePrincipalId = null; + String keyVaultProviderTenantId = null; + private ColumnEncryptionVersion serverColumnEncryptionVersion = ColumnEncryptionVersion.AE_NotSupported; private String enclaveType = null; @@ -1366,16 +1368,11 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke } break; case KeyVaultClientSecret: - // need a secret use use the secret method + // need a secret to use the secret method if (null == keyStoreSecret) { throw new SQLServerException(SQLServerException.getErrString("R_keyStoreSecretNotSet"), null); - } else { - SQLServerColumnEncryptionAzureKeyVaultProvider provider = new SQLServerColumnEncryptionAzureKeyVaultProvider( - keyStorePrincipalId, keyStoreSecret); - Map keyStoreMap = new HashMap(); - keyStoreMap.put(provider.getName(), provider); - registerColumnEncryptionKeyStoreProviders(keyStoreMap); } + registerKeyVaultProvider(keyStorePrincipalId, keyStoreSecret); break; case KeyVaultManagedIdentity: SQLServerColumnEncryptionAzureKeyVaultProvider provider; @@ -1384,7 +1381,7 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke } else { provider = new SQLServerColumnEncryptionAzureKeyVaultProvider(); } - Map keyStoreMap = new HashMap(); + Map keyStoreMap = new HashMap<>(); keyStoreMap.put(provider.getName(), provider); registerColumnEncryptionKeyStoreProviders(keyStoreMap); break; @@ -1395,6 +1392,15 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke } } + private void registerKeyVaultProvider(String clientId, String clientKey) throws SQLServerException { + // need a secret to use the secret method + SQLServerColumnEncryptionAzureKeyVaultProvider provider = new SQLServerColumnEncryptionAzureKeyVaultProvider( + clientId, clientKey); + Map keyStoreMap = new HashMap<>(); + keyStoreMap.put(provider.getName(), provider); + registerColumnEncryptionKeyStoreProviders(keyStoreMap); + } + /** * Establish a physical database connection based on the user specified connection properties. Logon to the * database. @@ -1620,6 +1626,12 @@ Connection connectInternal(Properties propsIn, keyStorePrincipalId = sPropValue; } + sPropKey = SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_TENANT_ID.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + keyVaultProviderTenantId = sPropValue; + } + registerKeyStoreProviderOnConnection(keyStoreAuthentication, keyStoreSecret, keyStoreLocation); if (null == globalCustomColumnEncryptionKeyStoreProviders) { @@ -1631,11 +1643,10 @@ Connection connectInternal(Properties propsIn, sPropValue = activeConnectionProperties.getProperty(sPropKey); if (null != sPropValue) { String keyVaultColumnEncryptionProviderClientKey = sPropValue; - SQLServerColumnEncryptionAzureKeyVaultProvider akvProvider = new SQLServerColumnEncryptionAzureKeyVaultProvider( - keyVaultColumnEncryptionProviderClientId, keyVaultColumnEncryptionProviderClientKey); - Map keyStoreMap = new HashMap(); - keyStoreMap.put(akvProvider.getName(), akvProvider); - registerColumnEncryptionKeyStoreProviders(keyStoreMap); + + registerKeyVaultProvider( + keyVaultColumnEncryptionProviderClientId, + keyVaultColumnEncryptionProviderClientKey); } } } @@ -4435,11 +4446,11 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe while (true) { if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryPassword.toString())) { - if (!adalContextExists()) { + if (!msalContextExists()) { MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_ADALMissing")); throw new SQLServerException(form.format(new Object[] {authenticationString}), null, 0, null); } - fedAuthToken = SQLServerADAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user, + fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), authenticationString); @@ -4521,17 +4532,17 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe sleepInterval = sleepInterval * 2; } } - // else choose ADAL4J for integrated authentication. This option is supported for both windows and unix, + // else choose MSAL4J for integrated authentication. This option is supported for both windows and unix, // so we don't need to check the // OS version here. else { - // Check if ADAL4J library is available - if (!adalContextExists()) { + // Check if MSAL4J library is available + if (!msalContextExists()) { MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_DLLandADALMissing")); Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString}; throw new SQLServerException(form.format(msgArgs), null, 0, null); } - fedAuthToken = SQLServerADAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString); + fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString); } // Break out of the retry loop in successful case. break; @@ -4541,9 +4552,9 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe return fedAuthToken; } - private boolean adalContextExists() { + private boolean msalContextExists() { try { - Class.forName("com.microsoft.aad.adal4j.AuthenticationContext"); + Class.forName("com.microsoft.aad.msal4j.PublicClientApplication"); } catch (ClassNotFoundException e) { return false; } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java index d9f41510a..4eb192037 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java @@ -363,6 +363,7 @@ enum SQLServerDriverStringProperty { MSI_CLIENT_ID("msiClientId", ""), KEY_VAULT_PROVIDER_CLIENT_ID("keyVaultProviderClientId", ""), KEY_VAULT_PROVIDER_CLIENT_KEY("keyVaultProviderClientKey", ""), + KEY_VAULT_PROVIDER_TENANT_ID("keyVaultProviderTenantId", ""), KEY_STORE_PRINCIPAL_ID("keyStorePrincipalId", ""), CLIENT_CERTIFICATE("clientCertificate", ""), CLIENT_KEY("clientKey", ""), diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerKeyVaultAuthenticationCallback.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerKeyVaultAuthenticationCallback.java deleted file mode 100644 index c7ac13f5c..000000000 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerKeyVaultAuthenticationCallback.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made - * available under the terms of the MIT License. See the LICENSE file in the project root for more information. - */ - -package com.microsoft.sqlserver.jdbc; - -/** - * Provides a callback delegate which is to be implemented by the client code - * - */ -public interface SQLServerKeyVaultAuthenticationCallback { - - /** - * Returns the acesss token of the authentication request - * - * @param authority - * - Identifier of the authority, a URL. - * @param resource - * - Identifier of the target resource that is the recipient of the requested token, a URL. - * @param scope - * - The scope of the authentication request. - * @return access token - */ - String getAccessToken(String authority, String resource, String scope); -} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java similarity index 62% rename from src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java rename to src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index a94ca3cc5..e8a318375 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -8,37 +8,46 @@ import java.io.IOException; import java.net.MalformedURLException; import java.text.MessageFormat; +import java.util.Collections; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.Future; import java.util.logging.Level; import javax.security.auth.kerberos.KerberosPrincipal; -import com.microsoft.aad.adal4j.AuthenticationContext; -import com.microsoft.aad.adal4j.AuthenticationException; -import com.microsoft.aad.adal4j.AuthenticationResult; +import com.microsoft.aad.msal4j.IAuthenticationResult; +import com.microsoft.aad.msal4j.IntegratedWindowsAuthenticationParameters; +import com.microsoft.aad.msal4j.PublicClientApplication; +import com.microsoft.aad.msal4j.UserNamePasswordParameters; import com.microsoft.sqlserver.jdbc.SQLServerConnection.ActiveDirectoryAuthentication; import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo; -class SQLServerADAL4JUtils { +class SQLServerMSAL4JUtils { - static final private java.util.logging.Logger adal4jLogger = java.util.logging.Logger - .getLogger("com.microsoft.sqlserver.jdbc.internals.SQLServerADAL4JUtils"); + static final private java.util.logging.Logger logger = java.util.logging.Logger + .getLogger("com.microsoft.sqlserver.jdbc.SQLServerMSAL4JUtils"); static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password, String authenticationString) throws SQLServerException { ExecutorService executorService = Executors.newFixedThreadPool(1); try { - AuthenticationContext context = new AuthenticationContext(fedAuthInfo.stsurl, false, executorService); - Future future = context.acquireToken(fedAuthInfo.spn, - ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID, user, password, null); + final PublicClientApplication clientApplication = PublicClientApplication + .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID) + .executorService(executorService) + .authority(fedAuthInfo.stsurl) + .build(); + final CompletableFuture future = clientApplication.acquireToken(UserNamePasswordParameters.builder( + Collections.singleton(fedAuthInfo.spn + "/.default"), + user, + password.toCharArray() + ).build()); + + final IAuthenticationResult authenticationResult = future.get(); + return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); - AuthenticationResult authenticationResult = future.get(); - - return new SqlFedAuthToken(authenticationResult.getAccessToken(), authenticationResult.getExpiresOnDate()); } catch (MalformedURLException | InterruptedException e) { throw new SQLServerException(e.getMessage(), e); } catch (ExecutionException e) { @@ -50,8 +59,7 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use * correct format */ String correctedErrorMessage = e.getCause().getMessage().replaceAll("\\\\r\\\\n", "\r\n"); - AuthenticationException correctedAuthenticationException = new AuthenticationException( - correctedErrorMessage); + RuntimeException correctedAuthenticationException = new RuntimeException(correctedErrorMessage); /* * SQLServerException is caused by ExecutionException, which is caused by AuthenticationException to match @@ -75,19 +83,24 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, * principal_name@realm_name format */ KerberosPrincipal kerberosPrincipal = new KerberosPrincipal("username"); - String username = kerberosPrincipal.getName(); + String user = kerberosPrincipal.getName(); - if (adal4jLogger.isLoggable(Level.FINE)) { - adal4jLogger.fine(adal4jLogger.toString() + " realm name is:" + kerberosPrincipal.getRealm()); + if (logger.isLoggable(Level.FINE)) { + logger.fine(logger.toString() + " realm name is:" + kerberosPrincipal.getRealm()); } - AuthenticationContext context = new AuthenticationContext(fedAuthInfo.stsurl, false, executorService); - Future future = context.acquireToken(fedAuthInfo.spn, - ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID, username, null, null); - - AuthenticationResult authenticationResult = future.get(); - - return new SqlFedAuthToken(authenticationResult.getAccessToken(), authenticationResult.getExpiresOnDate()); + final PublicClientApplication clientApplication = PublicClientApplication + .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID) + .executorService(executorService) + .authority(fedAuthInfo.stsurl) + .build(); + final CompletableFuture future = clientApplication.acquireToken(IntegratedWindowsAuthenticationParameters.builder( + Collections.singleton(fedAuthInfo.spn + "/.default"), + user + ).build()); + + final IAuthenticationResult authenticationResult = future.get(); + return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); } catch (InterruptedException | IOException e) { throw new SQLServerException(e.getMessage(), e); } catch (ExecutionException e) { @@ -103,8 +116,7 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, * correct format */ String correctedErrorMessage = e.getCause().getMessage().replaceAll("\\\\r\\\\n", "\r\n"); - AuthenticationException correctedAuthenticationException = new AuthenticationException( - correctedErrorMessage); + RuntimeException correctedAuthenticationException = new RuntimeException(correctedErrorMessage); /* * SQLServerException is caused by ExecutionException, which is caused by AuthenticationException to diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java index a361c0f59..23cb72514 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java @@ -537,6 +537,8 @@ protected Object[][] getContents() { "Both \"keyStoreSecret\" and \"keyStoreLocation\" must be set, if \"keyStoreAuthentication=JavaKeyStorePassword\" has been specified in the connection string."}, {"R_keyStoreSecretNotSet", "\"keyStoreSecret\" must be set, if \"keyStoreAuthentication=KeyVaultClientSecret\" has been specified in the connection string."}, + {"R_keyVaultProviderTenantIdNotSet", + "\"keyVaultProviderTenantId\" must be set, if \"keyStoreAuthentication=KeyVaultClientSecret\" has been specified in the connection string."}, {"R_certificateStoreInvalidKeyword", "Cannot set \"keyStoreSecret\", if \"keyStoreAuthentication=CertificateStore\" has been specified in the connection string."}, {"R_certificateStoreLocationNotSet", diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ScopeTokenCache.java b/src/main/java/com/microsoft/sqlserver/jdbc/ScopeTokenCache.java new file mode 100644 index 000000000..a3758fb79 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/ScopeTokenCache.java @@ -0,0 +1,57 @@ +package com.microsoft.sqlserver.jdbc; + +import com.azure.core.credential.AccessToken; +import com.azure.core.credential.TokenRequestContext; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.core.publisher.ReplayProcessor; + + +/** + * A token cache that supports caching a token and refreshing it. + */ +class ScopeTokenCache { + + private final AtomicBoolean wip; + private AccessToken cache; + private final ReplayProcessor emitterProcessor = ReplayProcessor.create(1); + private final FluxSink sink = emitterProcessor.sink(FluxSink.OverflowStrategy.BUFFER); + private final Function> getNew; + private TokenRequestContext request; + + /** + * Creates an instance of RefreshableTokenCredential with default scheme "Bearer". + * + * @param getNew a method to get a new token + */ + ScopeTokenCache(Function> getNew) { + this.wip = new AtomicBoolean(false); + this.getNew = getNew; + } + + void setRequest(TokenRequestContext request) { + this.request = request; + } + + /** + * Asynchronously get a token from either the cache or replenish the cache with a new token. + * @return a Publisher that emits an AccessToken + */ + Mono getToken() { + if (cache != null && !cache.isExpired()) { + return Mono.just(cache); + } + return Mono.defer(() -> { + if (!wip.getAndSet(true)) { + return getNew.apply(request).doOnNext(ac -> cache = ac) + .doOnNext(sink::next) + .doOnError(sink::error) + .doOnTerminate(() -> wip.set(false)); + } else { + return emitterProcessor.next(); + } + }); + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java index ba4125f05..c2fe9e4f9 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java @@ -4,6 +4,7 @@ */ package com.microsoft.sqlserver.jdbc.AlwaysEncrypted; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -17,25 +18,19 @@ import java.sql.Time; import java.sql.Timestamp; import java.util.LinkedList; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; +import com.azure.core.credential.TokenCredential; import org.junit.jupiter.api.Tag; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import com.microsoft.aad.adal4j.AuthenticationContext; -import com.microsoft.aad.adal4j.AuthenticationResult; -import com.microsoft.aad.adal4j.ClientCredential; import com.microsoft.sqlserver.jdbc.RandomData; import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider; import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionJavaKeyStoreProvider; import com.microsoft.sqlserver.jdbc.SQLServerConnection; import com.microsoft.sqlserver.jdbc.SQLServerException; -import com.microsoft.sqlserver.jdbc.SQLServerKeyVaultAuthenticationCallback; import com.microsoft.sqlserver.jdbc.SQLServerPreparedStatement; import com.microsoft.sqlserver.jdbc.SQLServerResultSet; import com.microsoft.sqlserver.jdbc.SQLServerStatement; @@ -91,18 +86,15 @@ public void testJksName(String serverName, String url, String protocol) throws E */ @ParameterizedTest @MethodSource("enclaveParams") + @Tag(Constants.reqExternalSetup) public void testAkvName(String serverName, String url, String protocol) throws Exception { setAEConnectionString(serverName, url, protocol); - try { - SQLServerColumnEncryptionAzureKeyVaultProvider akv = new SQLServerColumnEncryptionAzureKeyVaultProvider( - authenticationCallback); - String keystoreName = "keystoreName"; - akv.setName(keystoreName); - assertTrue(akv.getName().equals(keystoreName)); - } catch (SQLServerException e) { - fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); - } + SQLServerColumnEncryptionAzureKeyVaultProvider akv = new SQLServerColumnEncryptionAzureKeyVaultProvider( + applicationClientID, applicationKey); + String keystoreName = "keystoreName"; + akv.setName(keystoreName); + assertTrue(akv.getName().equals(keystoreName)); } /* @@ -134,10 +126,10 @@ public void testBadAkv(String serverName, String url, String protocol) throws Ex try { SQLServerColumnEncryptionAzureKeyVaultProvider akv = new SQLServerColumnEncryptionAzureKeyVaultProvider( - (SQLServerKeyVaultAuthenticationCallback) null); + (TokenCredential) null); fail(TestResource.getResource("R_expectedExceptionNotThrown")); - } catch (SQLServerException e) { - assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_NullValue"))); + } catch (NullPointerException exception) { + assertNull(exception.getMessage()); } } @@ -185,11 +177,7 @@ public void testAkvBadEncryptColumnEncryptionKey(String serverName, String url, setAEConnectionString(serverName, url, protocol); SQLServerColumnEncryptionAzureKeyVaultProvider akv = null; - try { - akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(authenticationCallback); - } catch (SQLServerException e) { - fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); - } + akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(applicationClientID, applicationKey); // null encryptedColumnEncryptionKey try { @@ -268,11 +256,7 @@ public void testAkvDecryptColumnEncryptionKey(String serverName, String url, Str setAEConnectionString(serverName, url, protocol); SQLServerColumnEncryptionAzureKeyVaultProvider akv = null; - try { - akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(authenticationCallback); - } catch (SQLServerException e) { - fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); - } + akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(applicationClientID, applicationKey); // null akvpath try { @@ -2243,23 +2227,4 @@ void testNumerics(SQLServerStatement stmt, String cekName, String[][] table, Str testRichQuery(stmt, NUMERIC_TABLE_AE, table, values2); } } - - SQLServerKeyVaultAuthenticationCallback authenticationCallback = new SQLServerKeyVaultAuthenticationCallback() { - // @Override - ExecutorService service = Executors.newFixedThreadPool(2); - - public String getAccessToken(String authority, String resource, String scope) { - - AuthenticationResult result = null; - try { - AuthenticationContext context = new AuthenticationContext(authority, false, service); - ClientCredential cred = new ClientCredential(applicationClientID, applicationKey); - Future future = context.acquireToken(resource, cred, null); - result = future.get(); - } catch (Exception e) { - fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); - } - return result.getAccessToken(); - } - }; } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java index 552ef2012..b4c28a17c 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java @@ -7,6 +7,7 @@ import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.assertTrue; +import com.microsoft.aad.msal4j.MsalServiceException; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -299,7 +300,9 @@ public void testNumericAkvWithBadCred() throws SQLException { testNumericAKV(connStr); fail(TestResource.getResource("R_expectedFailPassed")); } catch (Exception e) { - assert (e.getMessage().contains("AuthenticationException")); + assertTrue(e.getCause() instanceof MsalServiceException); + // https://docs.microsoft.com/en-us/azure/active-directory/develop/reference-aadsts-error-codes + assertTrue(e.getMessage().contains("AADSTS700016")); } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java index 10e4d82fc..0c85e7ea5 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java @@ -307,11 +307,10 @@ public static void testVerifyCMKNoEnclave() { } try { - SQLServerColumnEncryptionAzureKeyVaultProvider aksp = new SQLServerColumnEncryptionAzureKeyVaultProvider("", - ""); - assertFalse(aksp.verifyColumnMasterKeyMetadata(null, false, null)); + SQLServerColumnEncryptionAzureKeyVaultProvider aksp = new SQLServerColumnEncryptionAzureKeyVaultProvider( + "",""); } catch (SQLServerException e) { - fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + assertEquals(e.getMessage(), "Client ID cannot be null."); } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java index 675fa4341..30c696d3d 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java @@ -7,22 +7,27 @@ import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import com.microsoft.aad.msal4j.IAuthenticationResult; +import com.microsoft.aad.msal4j.PublicClientApplication; +import com.microsoft.aad.msal4j.UserNamePasswordParameters; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.util.Collections; +import java.util.Date; import java.util.Locale; +import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; -import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.logging.LogManager; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; -import com.microsoft.aad.adal4j.AuthenticationContext; -import com.microsoft.aad.adal4j.AuthenticationResult; import com.microsoft.sqlserver.testframework.Constants; import com.microsoft.sqlserver.jdbc.SQLServerException; import com.microsoft.sqlserver.jdbc.TestResource; @@ -142,11 +147,24 @@ public static void getConfigs() throws Exception { */ static void getFedauthInfo() { try { - AuthenticationContext context = new AuthenticationContext(stsurl, false, Executors.newFixedThreadPool(1)); - Future future = context.acquireToken(spn, fedauthClientId, azureUserName, - azurePassword, null); - secondsBeforeExpiration = future.get().getExpiresAfter(); - accessToken = future.get().getAccessToken(); + + final PublicClientApplication clientApplication = PublicClientApplication + .builder(fedauthClientId) + .executorService(Executors.newFixedThreadPool(1)) + .authority(stsurl) + .build(); + final CompletableFuture future = clientApplication.acquireToken( + UserNamePasswordParameters.builder( + Collections.singleton(spn + "/.default"), + azureUserName, + azurePassword.toCharArray() + ).build()); + + final IAuthenticationResult authenticationResult = future.get(); + + secondsBeforeExpiration = TimeUnit.MILLISECONDS + .toSeconds(authenticationResult.expiresOnDate().getTime() - new Date().getTime()); + accessToken = authenticationResult.accessToken(); } catch (Exception e) { fail(e.getMessage()); } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java index 65f1c2ac0..08942715b 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java @@ -15,9 +15,6 @@ import java.sql.Statement; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -26,9 +23,6 @@ import org.junit.platform.runner.JUnitPlatform; import org.junit.runner.RunWith; -import com.microsoft.aad.adal4j.AuthenticationContext; -import com.microsoft.aad.adal4j.AuthenticationResult; -import com.microsoft.aad.adal4j.ClientCredential; import com.microsoft.sqlserver.jdbc.RandomUtil; import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider; import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionJavaKeyStoreProvider; @@ -36,7 +30,6 @@ import com.microsoft.sqlserver.jdbc.SQLServerConnection; import com.microsoft.sqlserver.jdbc.SQLServerDataSource; import com.microsoft.sqlserver.jdbc.SQLServerException; -import com.microsoft.sqlserver.jdbc.SQLServerKeyVaultAuthenticationCallback; import com.microsoft.sqlserver.jdbc.TestUtils; import com.microsoft.sqlserver.testframework.AbstractSQLGenerator; import com.microsoft.sqlserver.testframework.Constants; @@ -128,7 +121,7 @@ public void testFedAuthWithAE_AKV() throws SQLException { dropCMK(stmt, cmkName3); setupCMK_AKVOld(cmkName3, stmt); - createCEK(cmkName3, setupKeyStoreProvider_AKVOld(), stmt, keyIDs[0]); + createCEK(cmkName3, setupKeyStoreProvider_AKVNew(), stmt, keyIDs[0]); createCharTable(stmt, charTableOld); populateCharNormalCase(charValues, connection, charTableOld); @@ -307,27 +300,27 @@ private SQLServerColumnEncryptionKeyStoreProvider setupKeyStoreProvider_AKVNew() new SQLServerColumnEncryptionAzureKeyVaultProvider(applicationClientID, applicationKey)); } - private SQLServerColumnEncryptionKeyStoreProvider setupKeyStoreProvider_AKVOld() throws SQLServerException { - ExecutorService service = Executors.newFixedThreadPool(2); - SQLServerKeyVaultAuthenticationCallback authenticationCallback = new SQLServerKeyVaultAuthenticationCallback() { - @Override - public String getAccessToken(String authority, String resource, String scope) { - AuthenticationResult result = null; - try { - AuthenticationContext context = new AuthenticationContext(authority, false, service); - ClientCredential cred = new ClientCredential(applicationClientID, applicationKey); - - Future future = context.acquireToken(resource, cred, null); - result = future.get(); - return result.getAccessToken(); - } catch (Exception e) { - fail(e.getMessage()); - return null; - } - } - }; - return new SQLServerColumnEncryptionAzureKeyVaultProvider(authenticationCallback); - } +// private SQLServerColumnEncryptionKeyStoreProvider setupKeyStoreProvider_AKVOld() throws SQLServerException { +// ExecutorService service = Executors.newFixedThreadPool(2); +// SQLServerKeyVaultAuthenticationCallback authenticationCallback = new SQLServerKeyVaultAuthenticationCallback() { +// @Override +// public String getAccessToken(String authority, String resource, String scope) { +// AuthenticationResult result = null; +// try { +// AuthenticationContext context = new AuthenticationContext(authority, false, service); +// ClientCredential cred = new ClientCredential(applicationClientID, applicationKey); +// +// Future future = context.acquireToken(resource, cred, null); +// result = future.get(); +// return result.getAccessToken(); +// } catch (Exception e) { +// fail(e.getMessage()); +// return null; +// } +// } +// }; +// return new SQLServerColumnEncryptionAzureKeyVaultProvider(authenticationCallback); +// } private SQLServerColumnEncryptionKeyStoreProvider registerAKVProvider( SQLServerColumnEncryptionKeyStoreProvider provider) throws SQLServerException { diff --git a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java index 914831a3c..56bd627f0 100644 --- a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java +++ b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java @@ -59,6 +59,7 @@ public abstract class AbstractTest { protected static String applicationClientID = null; protected static String applicationKey = null; protected static String[] keyIDs = null; + protected static String tenantID = null; protected static String[] enclaveServer = null; protected static String[] enclaveAttestationUrl = null; @@ -136,6 +137,7 @@ public static void setup() throws Exception { javaKeyPath = TestUtils.getCurrentClassPath() + Constants.JKS_NAME; keyIDs = getConfiguredProperty("keyID", "").split(Constants.SEMI_COLON); + tenantID = getConfiguredProperty("tenantID"); windowsKeyPath = getConfiguredProperty("windowsKeyPath"); String prop;