diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index a2f9116b2..ab090c4f6 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -38,8 +38,8 @@ jobs:
displayName: 'Maven build jre14'
inputs:
mavenPomFile: 'pom.xml'
- goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre14 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath)
--DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer)'
+ goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre14 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath)
+-DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer) -DtenantID=$(tenantID)'
testResultsFiles: '**/TEST-*.xml'
testRunTitle: 'Maven build jre14'
javaHomeOption: Path
@@ -49,7 +49,7 @@ jobs:
inputs:
mavenPomFile: 'pom.xml'
goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre11 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath)
--DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer)'
+-DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer) -DtenantID=$(tenantID)'
testResultsFiles: '**/TEST-*.xml'
testRunTitle: 'Maven build jre11'
javaHomeOption: Path
@@ -58,8 +58,8 @@ jobs:
displayName: 'Maven build jre8'
inputs:
mavenPomFile: 'pom.xml'
- goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre8 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath)
--DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer)'
+ goals: 'clean dependency:purge-local-repository -Dmssql_jdbc_test_connection_properties=jdbc:sqlserver://$(Target_SQL)$(server_domain);$(database);$(user);$(password); install -Pjre8 -DuserNTLM=$(userNTLM) -DpasswordNTLM=$(passwordNTLM) -DdomainNTLM=$(domainNTLM) -DexcludedGroups=$(Ex_Groups) -Dpkcs12_truststore_password=$(pkcs12_truststore_password) -Dpkcs12_truststore=$(pkcs12_truststore.secureFilePath)
+-DapplicationClientID=$(applicationClientID) -DapplicationKey=$(applicationKey) -DkeyID=$(keyID) -DwindowsKeyPath=$(windowsKeyPath) -DenclaveAttestationUrl=$(enclaveAttestationUrl) -DenclaveAttestationProtocol=$(enclaveAttestationProtocol) -DenclaveServer=$(enclaveServer) -DtenantID=$(tenantID)'
testResultsFiles: '**/TEST-*.xml'
testRunTitle: 'Maven build jre8'
javaHomeOption: Path
diff --git a/build.gradle b/build.gradle
index 7c0c08bbb..aef00225f 100644
--- a/build.gradle
+++ b/build.gradle
@@ -110,14 +110,12 @@ repositories {
dependencies {
compile 'org.osgi:org.osgi.core:6.0.0',
'org.osgi:org.osgi.compendium:5.0.0'
- compileOnly 'com.microsoft.azure:azure-keyvault:1.2.2',
- 'com.microsoft.azure:azure-keyvault-webkey:1.2.1',
- 'com.microsoft.rest:client-runtime:1.7.0',
- 'com.microsoft.azure:adal4j:1.6.4',
+ compileOnly 'com.azure:azure-security-keyvault-keys:4.2.0',
+ 'com.azure:azure-identity:1.1.0',
'org.antlr:antlr4-runtime:4.7.2',
'com.google.code.gson:gson:2.8.6',
- 'org.bouncycastle:bcprov-jdk15on:1.64',
- 'org.bouncycastle:bcpkix-jdk15on:1.64'
+ 'org.bouncycastle:bcprov-jdk15on:1.65',
+ 'org.bouncycastle:bcpkix-jdk15on:1.65'
testCompile 'org.junit.platform:junit-platform-console:1.5.2',
'org.junit.platform:junit-platform-commons:1.5.2',
'org.junit.platform:junit-platform-engine:1.5.2',
@@ -127,15 +125,14 @@ dependencies {
'org.junit.jupiter:junit-jupiter-api:5.6.0',
'org.junit.jupiter:junit-jupiter-engine:5.6.0',
'org.junit.jupiter:junit-jupiter-params:5.6.0',
- 'com.zaxxer:HikariCP:3.4.1',
- 'org.apache.commons:commons-dbcp2:2.6.0',
+ 'com.zaxxer:HikariCP:3.4.2',
+ 'org.apache.commons:commons-dbcp2:2.7.0',
'org.slf4j:slf4j-nop:1.7.29',
'org.antlr:antlr4-runtime:4.7.2',
'org.eclipse.gemini.blueprint:gemini-blueprint-mock:2.1.0.RELEASE',
'com.google.code.gson:gson:2.8.6',
- 'org.bouncycastle:bcprov-jdk15on:1.64',
- 'com.microsoft.azure:adal4j:1.6.4',
- 'com.microsoft.azure:azure-keyvault:1.2.2',
- 'com.microsoft.azure:azure-keyvault-webkey:1.2.1',
+ 'org.bouncycastle:bcprov-jdk15on:1.65',
+ 'com.azure:azure-security-keyvault-keys:4.2.0',
+ 'com.azure:azure-identity:1.1.0',
'com.h2database:h2:1.4.200'
}
diff --git a/pom.xml b/pom.xml
index feb5cbc75..fef1f5b8d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -60,9 +60,8 @@
-preview
- 1.2.4
- 1.6.5
- 1.7.4
+ 4.1.4
+ 1.0.7
6.0.0
5.0.0
4.7.2
@@ -86,22 +85,14 @@
- com.microsoft.azure
- azure-keyvault
+ com.azure
+ azure-security-keyvault-keys
${azure.keyvault.version}
- true
- com.microsoft.azure
- adal4j
- ${azure.adal4j.version}
- true
-
-
- com.microsoft.rest
- client-runtime
- ${rest.client.version}
- true
+ com.azure
+ azure-identity
+ ${azure.identity.version}
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java
index ca4f21f68..0e8068780 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java
@@ -5,86 +5,117 @@
package com.microsoft.sqlserver.jdbc;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-
-import com.microsoft.aad.adal4j.AuthenticationContext;
-import com.microsoft.aad.adal4j.AuthenticationResult;
-import com.microsoft.aad.adal4j.ClientCredential;
-import com.microsoft.azure.keyvault.authentication.KeyVaultCredentials;
-
+import com.azure.core.annotation.Immutable;
+import com.azure.core.credential.AccessToken;
+import com.azure.core.credential.TokenCredential;
+import com.azure.core.credential.TokenRequestContext;
+import com.azure.core.util.logging.ClientLogger;
+import com.microsoft.aad.msal4j.ClientCredentialFactory;
+import com.microsoft.aad.msal4j.ClientCredentialParameters;
+import com.microsoft.aad.msal4j.ConfidentialClientApplication;
+import com.microsoft.aad.msal4j.IAuthenticationResult;
+import com.microsoft.aad.msal4j.IClientCredential;
+import com.microsoft.aad.msal4j.SilentParameters;
+import java.net.MalformedURLException;
+import java.time.OffsetDateTime;
+import java.time.ZoneOffset;
+import java.util.HashSet;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import reactor.core.publisher.Mono;
/**
- *
- * An implementation of ServiceClientCredentials that supports automatic bearer token refresh.
- *
+ * An AAD credential that acquires a token with a client secret for an AAD application.
*/
-class KeyVaultCredential extends KeyVaultCredentials {
-
- SQLServerKeyVaultAuthenticationCallback authenticationCallback = null;
- String clientId = null;
- String clientKey = null;
- String accessToken = null;
+@Immutable
+class KeyVaultCredential implements TokenCredential {
+ private final ClientLogger logger = new ClientLogger(KeyVaultCredential.class);
+ private final String clientId;
+ private final String clientSecret;
+ private String authorization;
+ private ConfidentialClientApplication confidentialClientApplication;
- KeyVaultCredential(String clientId) throws SQLServerException {
+ /**
+ * Creates a KeyVaultCredential with the given identity client options.
+ *
+ * @param clientId the client ID of the application
+ * @param clientSecret the secret value of the AAD application.
+ */
+ KeyVaultCredential(String clientId, String clientSecret) {
+ Objects.requireNonNull(clientSecret, "'clientSecret' cannot be null.");
+ Objects.requireNonNull(clientSecret, "'clientId' cannot be null.");
this.clientId = clientId;
+ this.clientSecret = clientSecret;
}
- KeyVaultCredential() {}
-
- KeyVaultCredential(String clientId, String clientKey) {
- this.clientId = clientId;
- this.clientKey = clientKey;
+ @Override
+ public Mono getToken(TokenRequestContext request) {
+ return authenticateWithConfidentialClientCache(request)
+ .onErrorResume(t -> Mono.empty())
+ .switchIfEmpty(Mono.defer(() -> authenticateWithConfidentialClient(request)));
}
- KeyVaultCredential(SQLServerKeyVaultAuthenticationCallback authenticationCallback) {
- this.authenticationCallback = authenticationCallback;
+ public KeyVaultCredential setAuthorization(String authorization) {
+ if (this.authorization != null && this.authorization.equals(authorization)) {
+ return this;
+ }
+ this.authorization = authorization;
+ confidentialClientApplication = getConfidentialClientApplication();
+ return this;
}
- public String doAuthenticate(String authorization, String resource, String scope) {
- String accessToken = null;
- if (null == authenticationCallback) {
- if (null == clientKey) {
- try {
- SqlFedAuthToken token = SQLServerSecurityUtility.getMSIAuthToken(resource, clientId);
- accessToken = (null != token) ? token.accessToken : null;
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- } else {
- AuthenticationResult token = getAccessTokenFromClientCredentials(authorization, resource, clientId,
- clientKey);
- accessToken = token.getAccessToken();
- }
+ private ConfidentialClientApplication getConfidentialClientApplication() {
+ if (clientId == null) {
+ throw logger.logExceptionAsError(new IllegalArgumentException(
+ "A non-null value for client ID must be provided for user authentication."));
+ }
+
+ if (authorization == null) {
+ throw logger.logExceptionAsError(new IllegalArgumentException(
+ "A non-null value for authorization must be provided for user authentication."));
+ }
+
+ IClientCredential credential;
+ if (clientSecret != null) {
+ credential = ClientCredentialFactory.create(clientSecret);
} else {
- accessToken = authenticationCallback.getAccessToken(authorization, resource, scope);
+ throw logger.logExceptionAsError(
+ new IllegalArgumentException("Must provide client secret."));
+ }
+ ConfidentialClientApplication.Builder applicationBuilder =
+ ConfidentialClientApplication.builder(clientId, credential);
+ try {
+ applicationBuilder = applicationBuilder.authority(authorization);
+ } catch (MalformedURLException e) {
+ throw logger.logExceptionAsWarning(new IllegalStateException(e));
}
- return accessToken;
+ return applicationBuilder.build();
}
- private static AuthenticationResult getAccessTokenFromClientCredentials(String authorization, String resource,
- String clientId, String clientKey) {
- AuthenticationContext context = null;
- AuthenticationResult result = null;
- ExecutorService service = null;
- try {
- service = Executors.newFixedThreadPool(1);
- context = new AuthenticationContext(authorization, false, service);
- ClientCredential credentials = new ClientCredential(clientId, clientKey);
- Future future = context.acquireToken(resource, credentials, null);
- result = future.get();
- } catch (Exception e) {
- throw new RuntimeException(e);
- } finally {
- if (null != service) {
- service.shutdown();
+ private Mono authenticateWithConfidentialClientCache(TokenRequestContext request) {
+ return Mono.fromFuture(() -> {
+ SilentParameters.SilentParametersBuilder parametersBuilder = SilentParameters
+ .builder(new HashSet<>(request.getScopes()));
+ try {
+ return confidentialClientApplication.acquireTokenSilently(parametersBuilder.build());
+ } catch (MalformedURLException e) {
+ return getFailedCompletableFuture(logger.logExceptionAsError(new RuntimeException(e)));
}
- }
+ }).map(ar -> new AccessToken(ar.accessToken(),
+ OffsetDateTime.ofInstant(ar.expiresOnDate().toInstant(), ZoneOffset.UTC)))
+ .filter(t -> !t.isExpired());
+ }
- if (null == result) {
- throw new RuntimeException("authentication result was null");
- }
- return result;
+ private CompletableFuture getFailedCompletableFuture(Exception e) {
+ CompletableFuture completableFuture = new CompletableFuture<>();
+ completableFuture.completeExceptionally(e);
+ return completableFuture;
+ }
+
+ private Mono authenticateWithConfidentialClient(TokenRequestContext request) {
+ return Mono.fromFuture(() -> confidentialClientApplication
+ .acquireToken(ClientCredentialParameters.builder(new HashSet<>(request.getScopes())).build()))
+ .map(ar -> new AccessToken(ar.accessToken(),
+ OffsetDateTime.ofInstant(ar.expiresOnDate().toInstant(), ZoneOffset.UTC)));
}
}
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCustomCredentialPolicy.java b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCustomCredentialPolicy.java
new file mode 100644
index 000000000..45bec37b2
--- /dev/null
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCustomCredentialPolicy.java
@@ -0,0 +1,102 @@
+/*
+ * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
+ * available under the terms of the MIT License. See the LICENSE file in the project root for more information.
+ */
+
+package com.microsoft.sqlserver.jdbc;
+
+import com.azure.core.credential.TokenRequestContext;
+import com.azure.core.http.HttpPipelineCallContext;
+import com.azure.core.http.HttpPipelineNextPolicy;
+import com.azure.core.http.HttpResponse;
+import com.azure.core.http.policy.HttpPipelinePolicy;
+import com.azure.core.util.CoreUtils;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import reactor.core.publisher.Mono;
+
+/**
+ * A policy that authenticates requests with Azure Key Vault service.
+ */
+class KeyVaultCustomCredentialPolicy implements HttpPipelinePolicy {
+ private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
+ private static final String BEARER_TOKEN_PREFIX = "Bearer ";
+ private static final String AUTHORIZATION = "Authorization";
+ private final ScopeTokenCache cache;
+ private final KeyVaultCredential keyVaultCredential;
+
+ /**
+ * Creates KeyVaultCustomCredentialPolicy.
+ *
+ * @param credential the token credential to authenticate the request
+ */
+ public KeyVaultCustomCredentialPolicy(KeyVaultCredential credential) {
+ Objects.requireNonNull(credential, "'credential' cannot be null.");
+ this.cache = new ScopeTokenCache(credential::getToken);
+ this.keyVaultCredential = credential;
+ }
+
+ /**
+ * Adds the required header to authenticate a request to Azure Key Vault service.
+ *
+ * @param context The request context
+ * @param next The next HTTP pipeline policy to process the {@code context's} request after this policy completes.
+ * @return A {@link Mono} representing the HTTP response that will arrive asynchronously.
+ */
+ @Override
+ public Mono process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
+ if ("http".equals(context.getHttpRequest().getUrl().getProtocol())) {
+ return Mono.error(new RuntimeException("Token credentials require a URL using the HTTPS protocol scheme"));
+ }
+ return next.clone().process()
+ // Ignore body
+ .doOnNext(HttpResponse::close)
+ .map(res -> res.getHeaderValue(WWW_AUTHENTICATE))
+ .map(header -> extractChallenge(header, BEARER_TOKEN_PREFIX))
+ .flatMap(map -> {
+ keyVaultCredential.setAuthorization(map.get("authorization"));
+ cache.setRequest(new TokenRequestContext().addScopes(map.get("resource") + "/.default"));
+ return cache.getToken();
+ })
+ .flatMap(token -> {
+ context.getHttpRequest().setHeader(AUTHORIZATION, BEARER_TOKEN_PREFIX + token.getToken());
+ return next.process();
+ });
+ }
+
+ /**
+ * Extracts the challenge off the authentication header.
+ *
+ * @param authenticateHeader The authentication header containing all the challenges.
+ * @param authChallengePrefix The authentication challenge name.
+ * @return a challenge map.
+ */
+ private static Map extractChallenge(String authenticateHeader, String authChallengePrefix) {
+ if (!isValidChallenge(authenticateHeader, authChallengePrefix)) {
+ return null;
+ }
+ authenticateHeader = authenticateHeader.toLowerCase(Locale.ROOT).replace(authChallengePrefix.toLowerCase(Locale.ROOT), "");
+
+ String[] challenges = authenticateHeader.split(", ");
+ Map challengeMap = new HashMap<>();
+ for (String pair : challenges) {
+ String[] keyValue = pair.split("=");
+ challengeMap.put(keyValue[0].replaceAll("\"", ""), keyValue[1].replaceAll("\"", ""));
+ }
+ return challengeMap;
+ }
+
+ /**
+ * Verifies whether a challenge is bearer or not.
+ *
+ * @param authenticateHeader The authentication header containing all the challenges.
+ * @param authChallengePrefix The authentication challenge name.
+ * @return A boolean indicating tha challenge is valid or not.
+ */
+ private static boolean isValidChallenge(String authenticateHeader, String authChallengePrefix) {
+ return (!CoreUtils.isNullOrEmpty(authenticateHeader)
+ && authenticateHeader.toLowerCase(Locale.ROOT).startsWith(authChallengePrefix.toLowerCase(Locale.ROOT)));
+ }
+}
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultHttpPipelineBuilder.java b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultHttpPipelineBuilder.java
new file mode 100644
index 000000000..fc97daa45
--- /dev/null
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultHttpPipelineBuilder.java
@@ -0,0 +1,79 @@
+/*
+ * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
+ * available under the terms of the MIT License. See the LICENSE file in the project root for more information.
+ */
+
+package com.microsoft.sqlserver.jdbc;
+
+import com.azure.core.http.HttpPipeline;
+import com.azure.core.http.HttpPipelineBuilder;
+import com.azure.core.http.policy.HttpLogOptions;
+import com.azure.core.http.policy.HttpLoggingPolicy;
+import com.azure.core.http.policy.HttpPipelinePolicy;
+import com.azure.core.http.policy.HttpPolicyProviders;
+import com.azure.core.http.policy.RetryPolicy;
+import com.azure.core.http.policy.UserAgentPolicy;
+import com.azure.core.util.Configuration;
+import com.azure.core.util.logging.ClientLogger;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+final class KeyVaultHttpPipelineBuilder {
+ public static final String APPLICATION_ID = "ms-sql-jdbc";
+ private static final String SDK_NAME = "azure-security-keyvault-keys";
+ private static final String SDK_VERSION = "4.2.0";
+
+ private final ClientLogger logger = new ClientLogger(KeyVaultHttpPipelineBuilder.class);
+
+ private final List policies;
+ private KeyVaultCredential credential;
+ private HttpLogOptions httpLogOptions;
+ private final RetryPolicy retryPolicy;
+
+ /**
+ * The constructor with defaults.
+ */
+ public KeyVaultHttpPipelineBuilder() {
+ retryPolicy = new RetryPolicy();
+ httpLogOptions = new HttpLogOptions();
+ policies = new ArrayList<>();
+ }
+
+ public HttpPipeline buildPipeline() {
+ Configuration buildConfiguration = Configuration.getGlobalConfiguration().clone();
+
+ if (credential == null) {
+ throw logger.logExceptionAsError(new IllegalStateException("Token Credential should be specified."));
+ }
+
+ // Closest to API goes first, closest to wire goes last.
+ final List policies = new ArrayList<>();
+
+ policies.add(new UserAgentPolicy(APPLICATION_ID, SDK_NAME, SDK_VERSION, buildConfiguration));
+ HttpPolicyProviders.addBeforeRetryPolicies(policies);
+ policies.add(retryPolicy);
+ policies.add(new KeyVaultCustomCredentialPolicy(credential));
+ policies.addAll(this.policies);
+ HttpPolicyProviders.addAfterRetryPolicies(policies);
+ policies.add(new HttpLoggingPolicy(httpLogOptions));
+
+ return new HttpPipelineBuilder()
+ .policies(policies.toArray(new HttpPipelinePolicy[0]))
+ .build();
+ }
+
+ /**
+ * Sets the credential to use when authenticating HTTP requests.
+ *
+ * @param credential The credential to use for authenticating HTTP requests.
+ * @return the updated KVHttpPipelineBuilder object.
+ * @throws NullPointerException if {@code credential} is {@code null}.
+ */
+ public KeyVaultHttpPipelineBuilder credential(KeyVaultCredential credential) {
+ Objects.requireNonNull(credential);
+ this.credential = credential;
+ return this;
+ }
+}
+
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java
index 90fbf756f..a575fc166 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java
@@ -7,6 +7,7 @@
import static java.nio.charset.StandardCharsets.UTF_16LE;
+import com.azure.core.http.HttpPipeline;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.URI;
@@ -19,28 +20,31 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
import java.util.Properties;
-import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
-import com.microsoft.azure.AzureResponseBuilder;
-import com.microsoft.azure.keyvault.KeyVaultClient;
-import com.microsoft.azure.keyvault.models.KeyBundle;
-import com.microsoft.azure.keyvault.models.KeyOperationResult;
-import com.microsoft.azure.keyvault.models.KeyVerifyResult;
-import com.microsoft.azure.keyvault.webkey.JsonWebKeyEncryptionAlgorithm;
-import com.microsoft.azure.keyvault.webkey.JsonWebKeySignatureAlgorithm;
-import com.microsoft.azure.serializer.AzureJacksonAdapter;
-import com.microsoft.rest.RestClient;
-
-import okhttp3.OkHttpClient;
-import retrofit2.Retrofit;
-
+import com.azure.core.credential.TokenCredential;
+import com.azure.identity.ManagedIdentityCredentialBuilder;
+import com.azure.security.keyvault.keys.KeyClient;
+import com.azure.security.keyvault.keys.KeyClientBuilder;
+import com.azure.security.keyvault.keys.cryptography.CryptographyClient;
+import com.azure.security.keyvault.keys.cryptography.CryptographyClientBuilder;
+import com.azure.security.keyvault.keys.cryptography.models.KeyWrapAlgorithm;
+import com.azure.security.keyvault.keys.cryptography.models.SignResult;
+import com.azure.security.keyvault.keys.cryptography.models.SignatureAlgorithm;
+import com.azure.security.keyvault.keys.cryptography.models.UnwrapResult;
+import com.azure.security.keyvault.keys.cryptography.models.VerifyResult;
+import com.azure.security.keyvault.keys.cryptography.models.WrapResult;
+import com.azure.security.keyvault.keys.models.KeyType;
+import com.azure.security.keyvault.keys.models.KeyVaultKey;
/**
* Provides implementation similar to certificate store provider. A CEK encrypted with certificate store provider should
* be decryptable by this provider and vice versa.
- *
+ *
* Envelope Format for the encrypted column encryption key version + keyPathLength + ciphertextLength + keyPath +
* ciphertext + signature version: A single byte indicating the format version. keyPathLength: Length of the keyPath.
* ciphertextLength: ciphertext length keyPath: keyPath used to encrypt the column encryption key. This is only used for
@@ -49,668 +53,675 @@
*/
public class SQLServerColumnEncryptionAzureKeyVaultProvider extends SQLServerColumnEncryptionKeyStoreProvider {
- private final static java.util.logging.Logger akvLogger = java.util.logging.Logger
- .getLogger("com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider");
- /**
- * Column Encryption Key Store Provider string
- */
- String name = "AZURE_KEY_VAULT";
-
- private final String baseUrl = "https://{vaultBaseUrl}";
-
- private static final String MSSQL_JDBC_PROPERTIES = "mssql-jdbc.properties";
- private static final String AKV_TRUSTED_ENDPOINTS_KEYWORD = "AKVTrustedEndpoints";
- private static final List akvTrustedEndpoints;
- static {
- akvTrustedEndpoints = getTrustedEndpoints();
- }
- private final String rsaEncryptionAlgorithmWithOAEPForAKV = "RSA-OAEP";
-
- /**
- * Algorithm version
- */
- private final byte[] firstVersion = new byte[] {0x01};
-
- private KeyVaultClient keyVaultClient;
-
- private KeyVaultCredential credentials;
-
- public void setName(String name) {
- this.name = name;
- }
-
- public String getName() {
- return this.name;
- }
-
- /**
- * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a client id and client key to authenticate to
- * AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault.
- *
- * @param clientId
- * Identifier of the client requesting the token.
- * @param clientKey
- * Key of the client requesting the token.
- * @throws SQLServerException
- * when an error occurs
- */
- public SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId, String clientKey) throws SQLServerException {
- credentials = new KeyVaultCredential(clientId, clientKey);
- keyVaultClient = new KeyVaultClient(credentials);
- }
-
- /**
- * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a callback function to authenticate to AAD and
- * an executor service.. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault.
- *
- * This constructor is present to maintain backwards compatibility with 6.0 version of the driver. Deprecated for
- * removal in next stable release.
- *
- * @param authenticationCallback
- * - Callback function used for authenticating to AAD.
- * @param executorService
- * - The ExecutorService, previously used to create the keyVaultClient, but not in use anymore. - This
- * parameter can be passed as 'null'
- * @throws SQLServerException
- * when an error occurs
- */
- @Deprecated
- public SQLServerColumnEncryptionAzureKeyVaultProvider(
- SQLServerKeyVaultAuthenticationCallback authenticationCallback,
- ExecutorService executorService) throws SQLServerException {
- this(authenticationCallback);
- }
-
- /**
- * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a callback function to authenticate to AAD. This
- * is used by KeyVaultClient at runtime to authenticate to Azure Key Vault.
- *
- * @param authenticationCallback
- * - Callback function used for authenticating to AAD.
- * @throws SQLServerException
- * when an error occurs
- */
- public SQLServerColumnEncryptionAzureKeyVaultProvider(
- SQLServerKeyVaultAuthenticationCallback authenticationCallback) throws SQLServerException {
- if (null == authenticationCallback) {
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
- Object[] msgArgs1 = {"SQLServerKeyVaultAuthenticationCallback"};
- throw new SQLServerException(form.format(msgArgs1), null);
+ private final static java.util.logging.Logger akvLogger = java.util.logging.Logger
+ .getLogger("com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider");
+ private HttpPipeline keyVaultPipeline;
+ private KeyVaultCredential keyVaultCredential;
+ /**
+ * Column Encryption Key Store Provider string
+ */
+ String name = "AZURE_KEY_VAULT";
+
+ private static final String MSSQL_JDBC_PROPERTIES = "mssql-jdbc.properties";
+ private static final String AKV_TRUSTED_ENDPOINTS_KEYWORD = "AKVTrustedEndpoints";
+ private static final String RSA_ENCRYPTION_ALGORITHM_WITH_OAEP_FOR_AKV = "RSA-OAEP";
+
+ private static final List akvTrustedEndpoints;
+ /**
+ * Algorithm version
+ */
+ private final byte[] firstVersion = new byte[] {0x01};
+
+ private Map cachedKeyClients = new ConcurrentHashMap<>();
+ private Map cachedCryptographyClients = new ConcurrentHashMap<>();
+ private TokenCredential credential;
+
+ static {
+ akvTrustedEndpoints = getTrustedEndpoints();
}
- credentials = new KeyVaultCredential(authenticationCallback);
- RestClient restClient = new RestClient.Builder(new OkHttpClient.Builder(), new Retrofit.Builder())
- .withBaseUrl(baseUrl).withCredentials(credentials).withSerializerAdapter(new AzureJacksonAdapter())
- .withResponseBuilderFactory(new AzureResponseBuilder.Factory()).build();
- keyVaultClient = new KeyVaultClient(restClient);
- }
-
- /**
- * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by
- * KeyVaultClient at runtime to authenticate to Azure Key Vault.
- *
- * @throws SQLServerException
- * when an error occurs
- */
- SQLServerColumnEncryptionAzureKeyVaultProvider() throws SQLServerException {
- credentials = new KeyVaultCredential();
- keyVaultClient = new KeyVaultClient(credentials);
- }
-
- /**
- * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by
- * KeyVaultClient at runtime to authenticate to Azure Key Vault.
- *
- * @param clientId
- * Identifier of the client requesting the token.
- *
- * @throws SQLServerException
- * when an error occurs
- */
- SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId) throws SQLServerException {
- credentials = new KeyVaultCredential(clientId);
- keyVaultClient = new KeyVaultClient(credentials);
- }
-
- /**
- * Decrypts an encrypted CEK with RSA encryption algorithm using the asymmetric key specified by the key path
- *
- * @param masterKeyPath
- * - Complete path of an asymmetric key in AKV
- * @param encryptionAlgorithm
- * - Asymmetric Key Encryption Algorithm
- * @param encryptedColumnEncryptionKey
- * - Encrypted Column Encryption Key
- * @return Plain text column encryption key
- */
- @Override
- public byte[] decryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm,
- byte[] encryptedColumnEncryptionKey) throws SQLServerException {
-
- // Validate the input parameters
- this.ValidateNonEmptyAKVPath(masterKeyPath);
-
- if (null == encryptedColumnEncryptionKey) {
- throw new SQLServerException(SQLServerException.getErrString("R_NullEncryptedColumnEncryptionKey"), null);
+
+ public void setName(String name) {
+ this.name = name;
}
- if (0 == encryptedColumnEncryptionKey.length) {
- throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedColumnEncryptionKey"), null);
+ public String getName() {
+ return this.name;
}
- // Validate encryptionAlgorithm
- encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm);
-
- // Validate whether the key is RSA one or not and then get the key size
- int keySizeInBytes = getAKVKeySize(masterKeyPath);
-
- // Validate and decrypt the EncryptedColumnEncryptionKey
- // Format is
- // version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature
- //
- // keyPath is present in the encrypted column encryption key for identifying the original source of the
- // asymmetric key pair and
- // we will not validate it against the data contained in the CMK metadata (masterKeyPath).
-
- // Validate the version byte
- if (encryptedColumnEncryptionKey[0] != firstVersion[0]) {
- MessageFormat form = new MessageFormat(
- SQLServerException.getErrString("R_InvalidEcryptionAlgorithmVersion"));
- Object[] msgArgs = {String.format("%02X ", encryptedColumnEncryptionKey[0]),
- String.format("%02X ", firstVersion[0])};
- throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ public SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId, String clientKey)
+ throws SQLServerException {
+ if (clientId == null || clientId.isEmpty()) {
+ MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
+ Object[] msgArgs1 = {"Client ID"};
+ throw new SQLServerException(form.format(msgArgs1), null);
+ }
+ if (clientKey == null || clientKey.isEmpty()) {
+ MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
+ Object[] msgArgs1 = {"Client Key"};
+ throw new SQLServerException(form.format(msgArgs1), null);
+ }
+
+ keyVaultCredential = new KeyVaultCredential(clientId, clientKey);
+ keyVaultPipeline = new KeyVaultHttpPipelineBuilder().credential(keyVaultCredential)
+ .buildPipeline();
}
- // Get key path length
- int currentIndex = firstVersion.length;
- short keyPathLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex);
- // We just read 2 bytes
- currentIndex += 2;
-
- // Get ciphertext length
- short cipherTextLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex);
- currentIndex += 2;
-
- // Skip KeyPath
- // KeyPath exists only for troubleshooting purposes and doesnt need validation.
- currentIndex += keyPathLength;
-
- // validate the ciphertext length
- if (cipherTextLength != keySizeInBytes) {
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyLengthError"));
- Object[] msgArgs = {cipherTextLength, keySizeInBytes, masterKeyPath};
- throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ /**
+ * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by
+ * KeyVaultClient at runtime to authenticate to Azure Key Vault.
+ */
+ SQLServerColumnEncryptionAzureKeyVaultProvider() throws SQLServerException {
+ createKeyvaultClients(new ManagedIdentityCredentialBuilder().build());
}
- // Validate the signature length
- int signatureLength = encryptedColumnEncryptionKey.length - currentIndex - cipherTextLength;
+ /**
+ * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by
+ * KeyVaultClient at runtime to authenticate to Azure Key Vault.
+ *
+ * @param clientId Identifier of the client requesting the token.
+ */
+ SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId) throws SQLServerException {
+ if (clientId == null || clientId.isEmpty()) {
+ MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
+ Object[] msgArgs1 = {"Client ID"};
+ throw new SQLServerException(form.format(msgArgs1), null);
+ }
+ createKeyvaultClients(new ManagedIdentityCredentialBuilder().clientId(clientId).build());
+ }
- if (signatureLength != keySizeInBytes) {
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVSignatureLengthError"));
- Object[] msgArgs = {signatureLength, keySizeInBytes, masterKeyPath};
- throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ /**
+ * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider using the provided TokenCredential to
+ * authenticate to AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault.
+ *
+ * @param tokenCredential The TokenCredential to use to authenticate to Azure Key Vault.
+ */
+ public SQLServerColumnEncryptionAzureKeyVaultProvider(TokenCredential tokenCredential)
+ throws SQLServerException {
+ createKeyvaultClients(tokenCredential);
}
- // Get ciphertext
- byte[] cipherText = new byte[cipherTextLength];
- System.arraycopy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);
- currentIndex += cipherTextLength;
+ private void createKeyvaultClients(TokenCredential credential) throws SQLServerException {
+ this.credential = Objects.requireNonNull(credential);
+ }
- // Get signature
- byte[] signature = new byte[signatureLength];
- System.arraycopy(encryptedColumnEncryptionKey, currentIndex, signature, 0, signatureLength);
+ /**
+ * Decrypts an encrypted CEK with RSA encryption algorithm using the asymmetric key specified by the key path
+ *
+ * @param masterKeyPath - Complete path of an asymmetric key in AKV
+ * @param encryptionAlgorithm - Asymmetric Key Encryption Algorithm
+ * @param encryptedColumnEncryptionKey - Encrypted Column Encryption Key
+ * @return Plain text column encryption key
+ */
+ @Override public byte[] decryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm,
+ byte[] encryptedColumnEncryptionKey) throws SQLServerException {
- // Compute the hash to validate the signature
- byte[] hash = new byte[encryptedColumnEncryptionKey.length - signature.length];
+ // Validate the input parameters
+ this.ValidateNonEmptyAKVPath(masterKeyPath);
- System.arraycopy(encryptedColumnEncryptionKey, 0, hash, 0,
- encryptedColumnEncryptionKey.length - signature.length);
+ if (null == encryptedColumnEncryptionKey) {
+ throw new SQLServerException(
+ SQLServerException.getErrString("R_NullEncryptedColumnEncryptionKey"), null);
+ }
- MessageDigest md = null;
- try {
- md = MessageDigest.getInstance("SHA-256");
- } catch (NoSuchAlgorithmException e) {
- throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
- }
- md.update(hash);
- byte dataToVerify[] = md.digest();
+ if (0 == encryptedColumnEncryptionKey.length) {
+ throw new SQLServerException(
+ SQLServerException.getErrString("R_EmptyEncryptedColumnEncryptionKey"), null);
+ }
- if (null == dataToVerify) {
- throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null);
- }
+ // Validate encryptionAlgorithm
+ KeyWrapAlgorithm _encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm);
+
+ // Validate whether the key is RSA one or not and then get the key size
+ int keySizeInBytes = getAKVKeySize(masterKeyPath);
+
+ // Validate and decrypt the EncryptedColumnEncryptionKey
+ // Format is
+ // version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature
+ //
+ // keyPath is present in the encrypted column encryption key for identifying the original source of the
+ // asymmetric key pair and
+ // we will not validate it against the data contained in the CMK metadata (masterKeyPath).
+
+ // Validate the version byte
+ if (encryptedColumnEncryptionKey[0] != firstVersion[0]) {
+ MessageFormat form = new MessageFormat(
+ SQLServerException.getErrString("R_InvalidEcryptionAlgorithmVersion"));
+ Object[] msgArgs = {String.format("%02X ", encryptedColumnEncryptionKey[0]),
+ String.format("%02X ", firstVersion[0])};
+ throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ }
- // Validate the signature
- if (!AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath)) {
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_CEKSignatureNotMatchCMK"));
- Object[] msgArgs = {masterKeyPath};
- throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
- }
+ // Get key path length
+ int currentIndex = firstVersion.length;
+ short keyPathLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex);
+ // We just read 2 bytes
+ currentIndex += 2;
+
+ // Get ciphertext length
+ short cipherTextLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex);
+ currentIndex += 2;
+
+ // Skip KeyPath
+ // KeyPath exists only for troubleshooting purposes and doesnt need validation.
+ currentIndex += keyPathLength;
+
+ // validate the ciphertext length
+ if (cipherTextLength != keySizeInBytes) {
+ MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyLengthError"));
+ Object[] msgArgs = {cipherTextLength, keySizeInBytes, masterKeyPath};
+ throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ }
- // Decrypt the CEK
- byte[] decryptedCEK = this.AzureKeyVaultUnWrap(masterKeyPath, encryptionAlgorithm, cipherText);
+ // Validate the signature length
+ int signatureLength = encryptedColumnEncryptionKey.length - currentIndex - cipherTextLength;
- return decryptedCEK;
- }
+ if (signatureLength != keySizeInBytes) {
+ MessageFormat form = new MessageFormat(
+ SQLServerException.getErrString("R_AKVSignatureLengthError"));
+ Object[] msgArgs = {signatureLength, keySizeInBytes, masterKeyPath};
+ throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ }
- private short convertTwoBytesToShort(byte[] input, int index) throws SQLServerException {
+ // Get ciphertext
+ byte[] cipherText = new byte[cipherTextLength];
+ System.arraycopy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);
+ currentIndex += cipherTextLength;
- short shortVal;
- if (index + 1 >= input.length) {
- throw new SQLServerException(null, SQLServerException.getErrString("R_ByteToShortConversion"), null, 0,
- false);
- }
- ByteBuffer byteBuffer = ByteBuffer.allocate(2);
- byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
- byteBuffer.put(input[index]);
- byteBuffer.put(input[index + 1]);
- shortVal = byteBuffer.getShort(0);
- return shortVal;
-
- }
-
- /**
- * Encrypts CEK with RSA encryption algorithm using the asymmetric key specified by the key path.
- *
- * @param masterKeyPath
- * - Complete path of an asymmetric key in AKV
- * @param encryptionAlgorithm
- * - Asymmetric Key Encryption Algorithm
- * @param columnEncryptionKey
- * - Plain text column encryption key
- * @return Encrypted column encryption key
- */
- @Override
- public byte[] encryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm,
- byte[] columnEncryptionKey) throws SQLServerException {
-
- // Validate the input parameters
- this.ValidateNonEmptyAKVPath(masterKeyPath);
-
- if (null == columnEncryptionKey) {
- throw new SQLServerException(SQLServerException.getErrString("R_NullColumnEncryptionKey"), null);
+ // Get signature
+ byte[] signature = new byte[signatureLength];
+ System.arraycopy(encryptedColumnEncryptionKey, currentIndex, signature, 0, signatureLength);
+
+ // Compute the hash to validate the signature
+ byte[] hash = new byte[encryptedColumnEncryptionKey.length - signature.length];
+
+ System.arraycopy(encryptedColumnEncryptionKey, 0, hash, 0,
+ encryptedColumnEncryptionKey.length - signature.length);
+
+ MessageDigest md = null;
+ try {
+ md = MessageDigest.getInstance("SHA-256");
+ } catch (NoSuchAlgorithmException e) {
+ throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
+ }
+ md.update(hash);
+ byte dataToVerify[] = md.digest();
+
+ if (null == dataToVerify) {
+ throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null);
+ }
+
+ // Validate the signature
+ if (!AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath)) {
+ MessageFormat form = new MessageFormat(
+ SQLServerException.getErrString("R_CEKSignatureNotMatchCMK"));
+ Object[] msgArgs = {masterKeyPath};
+ throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ }
+
+ // Decrypt the CEK
+ byte[] decryptedCEK = this.AzureKeyVaultUnWrap(masterKeyPath, _encryptionAlgorithm, cipherText);
+
+ return decryptedCEK;
}
- if (0 == columnEncryptionKey.length) {
- throw new SQLServerException(SQLServerException.getErrString("R_EmptyCEK"), null);
+ private short convertTwoBytesToShort(byte[] input, int index) throws SQLServerException {
+
+ short shortVal;
+ if (index + 1 >= input.length) {
+ throw new SQLServerException(null, SQLServerException.getErrString("R_ByteToShortConversion"),
+ null, 0, false);
+ }
+ ByteBuffer byteBuffer = ByteBuffer.allocate(2);
+ byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
+ byteBuffer.put(input[index]);
+ byteBuffer.put(input[index + 1]);
+ shortVal = byteBuffer.getShort(0);
+ return shortVal;
+
}
- // Validate encryptionAlgorithm
- encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm);
+ /**
+ * Encrypts CEK with RSA encryption algorithm using the asymmetric key specified by the key path.
+ *
+ * @param masterKeyPath - Complete path of an asymmetric key in AKV
+ * @param encryptionAlgorithm - Asymmetric Key Encryption Algorithm
+ * @param columnEncryptionKey - Plain text column encryption key
+ * @return Encrypted column encryption key
+ */
+ @Override public byte[] encryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm,
+ byte[] columnEncryptionKey) throws SQLServerException {
+
+ // Validate the input parameters
+ this.ValidateNonEmptyAKVPath(masterKeyPath);
- // Validate whether the key is RSA one or not and then get the key size
- int keySizeInBytes = getAKVKeySize(masterKeyPath);
+ if (null == columnEncryptionKey) {
+ throw new SQLServerException(SQLServerException.getErrString("R_NullColumnEncryptionKey"),
+ null);
+ }
- // Construct the encryptedColumnEncryptionKey
- // Format is
- // version + keyPathLength + ciphertextLength + ciphertext + keyPath + signature
- //
- // We currently only support one version
- byte[] version = new byte[] {firstVersion[0]};
+ if (0 == columnEncryptionKey.length) {
+ throw new SQLServerException(SQLServerException.getErrString("R_EmptyCEK"), null);
+ }
- // Get the Unicode encoded bytes of cultureinvariant lower case masterKeyPath
- byte[] masterKeyPathBytes = masterKeyPath.toLowerCase(Locale.ENGLISH).getBytes(UTF_16LE);
+ // Validate encryptionAlgorithm
+ KeyWrapAlgorithm _encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm);
- byte[] keyPathLength = new byte[2];
- keyPathLength[0] = (byte) (((short) masterKeyPathBytes.length) & 0xff);
- keyPathLength[1] = (byte) (((short) masterKeyPathBytes.length) >> 8 & 0xff);
+ // Validate whether the key is RSA one or not and then get the key size
+ int keySizeInBytes = getAKVKeySize(masterKeyPath);
- // Encrypt the plain text
- byte[] cipherText = this.AzureKeyVaultWrap(masterKeyPath, encryptionAlgorithm, columnEncryptionKey);
+ // Construct the encryptedColumnEncryptionKey
+ // Format is
+ // version + keyPathLength + ciphertextLength + ciphertext + keyPath + signature
+ //
+ // We currently only support one version
+ byte[] version = new byte[] {firstVersion[0]};
- byte[] cipherTextLength = new byte[2];
- cipherTextLength[0] = (byte) (((short) cipherText.length) & 0xff);
- cipherTextLength[1] = (byte) (((short) cipherText.length) >> 8 & 0xff);
+ // Get the Unicode encoded bytes of cultureinvariant lower case masterKeyPath
+ byte[] masterKeyPathBytes = masterKeyPath.toLowerCase(Locale.ENGLISH).getBytes(UTF_16LE);
- if (cipherText.length != keySizeInBytes) {
- throw new SQLServerException(SQLServerException.getErrString("R_CipherTextLengthNotMatchRSASize"), null);
- }
+ byte[] keyPathLength = new byte[2];
+ keyPathLength[0] = (byte) (((short) masterKeyPathBytes.length) & 0xff);
+ keyPathLength[1] = (byte) (((short) masterKeyPathBytes.length) >> 8 & 0xff);
- // Compute hash
- // SHA-2-256(version + keyPathLength + ciphertextLength + keyPath + ciphertext)
- byte[] dataToHash = new byte[version.length + keyPathLength.length + cipherTextLength.length
- + masterKeyPathBytes.length + cipherText.length];
- int destinationPosition = version.length;
- System.arraycopy(version, 0, dataToHash, 0, version.length);
+ // Encrypt the plain text
+ byte[] cipherText = this.AzureKeyVaultWrap(masterKeyPath, _encryptionAlgorithm, columnEncryptionKey);
- System.arraycopy(keyPathLength, 0, dataToHash, destinationPosition, keyPathLength.length);
- destinationPosition += keyPathLength.length;
+ byte[] cipherTextLength = new byte[2];
+ cipherTextLength[0] = (byte) (((short) cipherText.length) & 0xff);
+ cipherTextLength[1] = (byte) (((short) cipherText.length) >> 8 & 0xff);
- System.arraycopy(cipherTextLength, 0, dataToHash, destinationPosition, cipherTextLength.length);
- destinationPosition += cipherTextLength.length;
+ if (cipherText.length != keySizeInBytes) {
+ throw new SQLServerException(
+ SQLServerException.getErrString("R_CipherTextLengthNotMatchRSASize"), null);
+ }
- System.arraycopy(masterKeyPathBytes, 0, dataToHash, destinationPosition, masterKeyPathBytes.length);
- destinationPosition += masterKeyPathBytes.length;
+ // Compute hash
+ // SHA-2-256(version + keyPathLength + ciphertextLength + keyPath + ciphertext)
+ byte[] dataToHash = new byte[version.length + keyPathLength.length + cipherTextLength.length
+ + masterKeyPathBytes.length + cipherText.length];
+ int destinationPosition = version.length;
+ System.arraycopy(version, 0, dataToHash, 0, version.length);
- System.arraycopy(cipherText, 0, dataToHash, destinationPosition, cipherText.length);
+ System.arraycopy(keyPathLength, 0, dataToHash, destinationPosition, keyPathLength.length);
+ destinationPosition += keyPathLength.length;
- MessageDigest md = null;
- try {
- md = MessageDigest.getInstance("SHA-256");
- } catch (NoSuchAlgorithmException e) {
- throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
- }
- md.update(dataToHash);
- byte dataToSign[] = md.digest();
+ System.arraycopy(cipherTextLength, 0, dataToHash, destinationPosition, cipherTextLength.length);
+ destinationPosition += cipherTextLength.length;
- // Sign the hash
- byte[] signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath);
+ System.arraycopy(masterKeyPathBytes, 0, dataToHash, destinationPosition, masterKeyPathBytes.length);
+ destinationPosition += masterKeyPathBytes.length;
- if (signedHash.length != keySizeInBytes) {
- throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null);
- }
+ System.arraycopy(cipherText, 0, dataToHash, destinationPosition, cipherText.length);
- if (!this.AzureKeyVaultVerifySignature(dataToSign, signedHash, masterKeyPath)) {
- throw new SQLServerException(SQLServerException.getErrString("R_InvalidSignatureComputed"), null);
- }
+ MessageDigest md = null;
+ try {
+ md = MessageDigest.getInstance("SHA-256");
+ } catch (NoSuchAlgorithmException e) {
+ throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
+ }
+ md.update(dataToHash);
+ byte dataToSign[] = md.digest();
+
+ // Sign the hash
+ byte[] signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath);
+
+ if (signedHash.length != keySizeInBytes) {
+ throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null);
+ }
+
+ if (!this.AzureKeyVaultVerifySignature(dataToSign, signedHash, masterKeyPath)) {
+ throw new SQLServerException(SQLServerException.getErrString("R_InvalidSignatureComputed"),
+ null);
+ }
+
+ // Construct the encrypted column encryption key
+ // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature
+ int encryptedColumnEncryptionKeyLength =
+ version.length + cipherTextLength.length + keyPathLength.length + cipherText.length
+ + masterKeyPathBytes.length + signedHash.length;
+ byte[] encryptedColumnEncryptionKey = new byte[encryptedColumnEncryptionKeyLength];
- // Construct the encrypted column encryption key
- // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature
- int encryptedColumnEncryptionKeyLength = version.length + cipherTextLength.length + keyPathLength.length
- + cipherText.length + masterKeyPathBytes.length + signedHash.length;
- byte[] encryptedColumnEncryptionKey = new byte[encryptedColumnEncryptionKeyLength];
-
- // Copy version byte
- int currentIndex = 0;
- System.arraycopy(version, 0, encryptedColumnEncryptionKey, currentIndex, version.length);
- currentIndex += version.length;
-
- // Copy key path length
- System.arraycopy(keyPathLength, 0, encryptedColumnEncryptionKey, currentIndex, keyPathLength.length);
- currentIndex += keyPathLength.length;
-
- // Copy ciphertext length
- System.arraycopy(cipherTextLength, 0, encryptedColumnEncryptionKey, currentIndex, cipherTextLength.length);
- currentIndex += cipherTextLength.length;
-
- // Copy key path
- System.arraycopy(masterKeyPathBytes, 0, encryptedColumnEncryptionKey, currentIndex, masterKeyPathBytes.length);
- currentIndex += masterKeyPathBytes.length;
-
- // Copy ciphertext
- System.arraycopy(cipherText, 0, encryptedColumnEncryptionKey, currentIndex, cipherText.length);
- currentIndex += cipherText.length;
-
- // copy the signature
- System.arraycopy(signedHash, 0, encryptedColumnEncryptionKey, currentIndex, signedHash.length);
-
- return encryptedColumnEncryptionKey;
- }
-
- /**
- * Validates that the encryption algorithm is RSA_OAEP and if it is not, then throws an exception.
- *
- * @param encryptionAlgorithm
- * - Asymmetric key encryptio algorithm
- * @return The encryption algorithm that is going to be used.
- * @throws SQLServerException
- */
- private String validateEncryptionAlgorithm(String encryptionAlgorithm) throws SQLServerException {
-
- if (null == encryptionAlgorithm) {
- throw new SQLServerException(null, SQLServerException.getErrString("R_NullKeyEncryptionAlgorithm"), null, 0,
- false);
+ // Copy version byte
+ int currentIndex = 0;
+ System.arraycopy(version, 0, encryptedColumnEncryptionKey, currentIndex, version.length);
+ currentIndex += version.length;
+
+ // Copy key path length
+ System.arraycopy(keyPathLength, 0, encryptedColumnEncryptionKey, currentIndex, keyPathLength.length);
+ currentIndex += keyPathLength.length;
+
+ // Copy ciphertext length
+ System.arraycopy(cipherTextLength, 0, encryptedColumnEncryptionKey, currentIndex,
+ cipherTextLength.length);
+ currentIndex += cipherTextLength.length;
+
+ // Copy key path
+ System.arraycopy(masterKeyPathBytes, 0, encryptedColumnEncryptionKey, currentIndex,
+ masterKeyPathBytes.length);
+ currentIndex += masterKeyPathBytes.length;
+
+ // Copy ciphertext
+ System.arraycopy(cipherText, 0, encryptedColumnEncryptionKey, currentIndex, cipherText.length);
+ currentIndex += cipherText.length;
+
+ // copy the signature
+ System.arraycopy(signedHash, 0, encryptedColumnEncryptionKey, currentIndex, signedHash.length);
+
+ return encryptedColumnEncryptionKey;
}
- // Transform to standard format (dash instead of underscore) to support both "RSA_OAEP" and "RSA-OAEP"
- if ("RSA_OAEP".equalsIgnoreCase(encryptionAlgorithm)) {
- encryptionAlgorithm = "RSA-OAEP";
+ /**
+ * Validates that the encryption algorithm is RSA_OAEP and if it is not, then throws an exception.
+ *
+ * @param encryptionAlgorithm - Asymmetric key encryptio algorithm
+ * @return The encryption algorithm that is going to be used.
+ * @throws SQLServerException
+ */
+ private KeyWrapAlgorithm validateEncryptionAlgorithm(String encryptionAlgorithm) throws SQLServerException {
+
+ if (null == encryptionAlgorithm) {
+ throw new SQLServerException(null,
+ SQLServerException.getErrString("R_NullKeyEncryptionAlgorithm"), null, 0, false);
+ }
+
+ // Transform to standard format (dash instead of underscore) to support enum lookup
+ if ("RSA_OAEP".equalsIgnoreCase(encryptionAlgorithm)) {
+ encryptionAlgorithm = RSA_ENCRYPTION_ALGORITHM_WITH_OAEP_FOR_AKV;
+ }
+
+ if (!RSA_ENCRYPTION_ALGORITHM_WITH_OAEP_FOR_AKV.equalsIgnoreCase(encryptionAlgorithm.trim())) {
+ MessageFormat form = new MessageFormat(
+ SQLServerException.getErrString("R_InvalidKeyEncryptionAlgorithm"));
+ Object[] msgArgs = {encryptionAlgorithm, RSA_ENCRYPTION_ALGORITHM_WITH_OAEP_FOR_AKV};
+ throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ }
+
+ return KeyWrapAlgorithm.fromString(encryptionAlgorithm);
}
- if (!rsaEncryptionAlgorithmWithOAEPForAKV.equalsIgnoreCase(encryptionAlgorithm.trim())) {
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_InvalidKeyEncryptionAlgorithm"));
- Object[] msgArgs = {encryptionAlgorithm, rsaEncryptionAlgorithmWithOAEPForAKV};
- throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
+ /**
+ * Checks if the Azure Key Vault key path is Empty or Null (and raises exception if they are).
+ *
+ * @param masterKeyPath
+ * @throws SQLServerException
+ */
+ private void ValidateNonEmptyAKVPath(String masterKeyPath) throws SQLServerException {
+ // throw appropriate error if masterKeyPath is null or empty
+ if (null == masterKeyPath || masterKeyPath.trim().isEmpty()) {
+ MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVPathNull"));
+ Object[] msgArgs = {masterKeyPath};
+ throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
+ } else {
+ URI parsedUri = null;
+ try {
+ parsedUri = new URI(masterKeyPath);
+
+ // A valid URI.
+ // Check if it is pointing to a trusted endpoint.
+ String host = parsedUri.getHost();
+ if (null != host) {
+ host = host.toLowerCase(Locale.ENGLISH);
+ }
+ for (final String endpoint : akvTrustedEndpoints) {
+ if (null != host && host.endsWith(endpoint)) {
+ return;
+ }
+ }
+ } catch (URISyntaxException e) {
+ MessageFormat form = new MessageFormat(
+ SQLServerException.getErrString("R_AKVURLInvalid"));
+ Object[] msgArgs = {masterKeyPath};
+ throw new SQLServerException(form.format(msgArgs), null, 0, e);
+ }
+
+ MessageFormat form = new MessageFormat(
+ SQLServerException.getErrString("R_AKVMasterKeyPathInvalid"));
+ Object[] msgArgs = {masterKeyPath};
+ throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
+ }
}
- return encryptionAlgorithm;
- }
-
- /**
- * Checks if the Azure Key Vault key path is Empty or Null (and raises exception if they are).
- *
- * @param masterKeyPath
- * @throws SQLServerException
- */
- private void ValidateNonEmptyAKVPath(String masterKeyPath) throws SQLServerException {
- // throw appropriate error if masterKeyPath is null or empty
- if (null == masterKeyPath || masterKeyPath.trim().isEmpty()) {
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVPathNull"));
- Object[] msgArgs = {masterKeyPath};
- throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
- } else {
- URI parsedUri = null;
- try {
- parsedUri = new URI(masterKeyPath);
-
- // A valid URI.
- // Check if it is pointing to a trusted endpoint.
- String host = parsedUri.getHost();
- if (null != host) {
- host = host.toLowerCase(Locale.ENGLISH);
- }
- for (final String endpoint : akvTrustedEndpoints) {
- if (null != host && host.endsWith(endpoint)) {
- return;
- }
- }
- } catch (URISyntaxException e) {
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVURLInvalid"));
- Object[] msgArgs = {masterKeyPath};
- throw new SQLServerException(form.format(msgArgs), null, 0, e);
- }
+ /**
+ * Encrypts the text using specified Azure Key Vault key.
+ *
+ * @param masterKeyPath - Azure Key Vault key url.
+ * @param encryptionAlgorithm - Encryption Algorithm.
+ * @param columnEncryptionKey - Plain text Column Encryption Key.
+ * @return Returns an encrypted blob or throws an exception if there are any errors.
+ * @throws SQLServerException
+ */
+ private byte[] AzureKeyVaultWrap(String masterKeyPath, KeyWrapAlgorithm encryptionAlgorithm,
+ byte[] columnEncryptionKey) throws SQLServerException {
+ if (null == columnEncryptionKey) {
+ throw new SQLServerException(SQLServerException.getErrString("R_CEKNull"), null);
+ }
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVMasterKeyPathInvalid"));
- Object[] msgArgs = {masterKeyPath};
- throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
+ CryptographyClient cryptoClient = getCryptographyClient(masterKeyPath);
+ WrapResult wrappedKey = cryptoClient.wrapKey(KeyWrapAlgorithm.RSA_OAEP, columnEncryptionKey);
+ return wrappedKey.getEncryptedKey();
}
- }
-
- /**
- * Encrypts the text using specified Azure Key Vault key.
- *
- * @param masterKeyPath
- * - Azure Key Vault key url.
- * @param encryptionAlgorithm
- * - Encryption Algorithm.
- * @param columnEncryptionKey
- * - Plain text Column Encryption Key.
- * @return Returns an encrypted blob or throws an exception if there are any errors.
- * @throws SQLServerException
- */
- private byte[] AzureKeyVaultWrap(String masterKeyPath, String encryptionAlgorithm,
- byte[] columnEncryptionKey) throws SQLServerException {
- if (null == columnEncryptionKey) {
- throw new SQLServerException(SQLServerException.getErrString("R_CEKNull"), null);
+
+ /**
+ * Encrypts the text using specified Azure Key Vault key.
+ *
+ * @param masterKeyPath - Azure Key Vault key url.
+ * @param encryptionAlgorithm - Encrypted Column Encryption Key.
+ * @param encryptedColumnEncryptionKey - Encrypted Column Encryption Key.
+ * @return Returns the decrypted plaintext Column Encryption Key or throws an exception if there are any errors.
+ * @throws SQLServerException
+ */
+ private byte[] AzureKeyVaultUnWrap(String masterKeyPath, KeyWrapAlgorithm encryptionAlgorithm,
+ byte[] encryptedColumnEncryptionKey) throws SQLServerException {
+ if (null == encryptedColumnEncryptionKey) {
+ throw new SQLServerException(SQLServerException.getErrString("R_EncryptedCEKNull"), null);
+ }
+
+ if (0 == encryptedColumnEncryptionKey.length) {
+ throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedCEK"), null);
+ }
+
+ CryptographyClient cryptoClient = getCryptographyClient(masterKeyPath);
+
+ UnwrapResult unwrappedKey = cryptoClient.unwrapKey(encryptionAlgorithm, encryptedColumnEncryptionKey);
+
+ return unwrappedKey.getKey();
}
- JsonWebKeyEncryptionAlgorithm jsonEncryptionAlgorithm = new JsonWebKeyEncryptionAlgorithm(encryptionAlgorithm);
- KeyOperationResult wrappedKey = keyVaultClient.wrapKey(masterKeyPath, jsonEncryptionAlgorithm,
- columnEncryptionKey);
-
- return wrappedKey.result();
- }
-
- /**
- * Encrypts the text using specified Azure Key Vault key.
- *
- * @param masterKeyPath
- * - Azure Key Vault key url.
- * @param encryptionAlgorithm
- * - Encrypted Column Encryption Key.
- * @param encryptedColumnEncryptionKey
- * - Encrypted Column Encryption Key.
- * @return Returns the decrypted plaintext Column Encryption Key or throws an exception if there are any errors.
- * @throws SQLServerException
- */
- private byte[] AzureKeyVaultUnWrap(String masterKeyPath, String encryptionAlgorithm,
- byte[] encryptedColumnEncryptionKey) throws SQLServerException {
- if (null == encryptedColumnEncryptionKey) {
- throw new SQLServerException(SQLServerException.getErrString("R_EncryptedCEKNull"), null);
+ private CryptographyClient getCryptographyClient(String masterKeyPath) throws SQLServerException {
+ if (this.cachedCryptographyClients.containsKey(masterKeyPath)) {
+ return cachedCryptographyClients.get(masterKeyPath);
+ }
+
+ KeyVaultKey retrievedKey = getKeyVaultKey(masterKeyPath);
+
+ CryptographyClient cryptoClient;
+ if (credential != null) {
+ cryptoClient = new CryptographyClientBuilder().credential(credential)
+ .keyIdentifier(retrievedKey.getId()).buildClient();
+ } else {
+ cryptoClient = new CryptographyClientBuilder().pipeline(keyVaultPipeline)
+ .keyIdentifier(retrievedKey.getId()).buildClient();
+ }
+ cachedCryptographyClients.putIfAbsent(masterKeyPath, cryptoClient);
+ return cachedCryptographyClients.get(masterKeyPath);
}
- if (0 == encryptedColumnEncryptionKey.length) {
- throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedCEK"), null);
+ /**
+ * Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL.
+ *
+ * @param dataToSign - Text to sign.
+ * @param masterKeyPath - Azure Key Vault key url.
+ * @return Signature
+ * @throws SQLServerException
+ */
+ private byte[] AzureKeyVaultSignHashedData(byte[] dataToSign, String masterKeyPath) throws SQLServerException {
+ assert ((null != dataToSign) && (0 != dataToSign.length));
+
+ CryptographyClient cryptoClient = getCryptographyClient(masterKeyPath);
+ SignResult signedData = cryptoClient.sign(SignatureAlgorithm.RS256, dataToSign);
+ return signedData.getSignature();
}
- JsonWebKeyEncryptionAlgorithm jsonEncryptionAlgorithm = new JsonWebKeyEncryptionAlgorithm(encryptionAlgorithm);
- KeyOperationResult unwrappedKey = keyVaultClient.unwrapKey(masterKeyPath, jsonEncryptionAlgorithm,
- encryptedColumnEncryptionKey);
-
- return unwrappedKey.result();
- }
-
- /**
- * Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL.
- *
- * @param dataToSign
- * - Text to sign.
- * @param masterKeyPath
- * - Azure Key Vault key url.
- * @return Signature
- * @throws SQLServerException
- */
- private byte[] AzureKeyVaultSignHashedData(byte[] dataToSign, String masterKeyPath) throws SQLServerException {
- assert ((null != dataToSign) && (0 != dataToSign.length));
-
- KeyOperationResult signedData = keyVaultClient.sign(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256,
- dataToSign);
-
- return signedData.result();
- }
-
- /**
- * Verifies the given RSA PKCSv1.5 signature.
- *
- * @param dataToVerify
- * @param signature
- * @param masterKeyPath
- * - Azure Key Vault key url.
- * @return true if signature is valid, false if it is not valid
- * @throws SQLServerException
- */
- private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify, byte[] signature,
- String masterKeyPath) throws SQLServerException {
- assert ((null != dataToVerify) && (0 != dataToVerify.length));
- assert ((null != signature) && (0 != signature.length));
-
- KeyVerifyResult valid = keyVaultClient.verify(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToVerify,
- signature);
-
- return valid.value();
- }
-
- /**
- * Returns the public Key size in bytes.
- *
- * @param masterKeyPath
- * - Azure Key Vault Key path
- * @return Key size in bytes
- * @throws SQLServerException
- * when an error occurs
- */
- private int getAKVKeySize(String masterKeyPath) throws SQLServerException {
- KeyBundle retrievedKey = keyVaultClient.getKey(masterKeyPath);
-
- if (null == retrievedKey) {
- String[] keyTokens = masterKeyPath.split("/");
-
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyNotFound"));
- Object[] msgArgs = {keyTokens[keyTokens.length - 1]};
- throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
+ /**
+ * Verifies the given RSA PKCSv1.5 signature.
+ *
+ * @param dataToVerify
+ * @param signature
+ * @param masterKeyPath - Azure Key Vault key url.
+ * @return true if signature is valid, false if it is not valid
+ * @throws SQLServerException
+ */
+ private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify, byte[] signature, String masterKeyPath)
+ throws SQLServerException {
+ assert ((null != dataToVerify) && (0 != dataToVerify.length));
+ assert ((null != signature) && (0 != signature.length));
+
+ CryptographyClient cryptoClient = getCryptographyClient(masterKeyPath);
+ VerifyResult valid = cryptoClient.verify(SignatureAlgorithm.RS256, dataToVerify, signature);
+
+ return valid.isValid();
}
- if (!"RSA".equalsIgnoreCase(retrievedKey.key().kty().toString())
- && !"RSA-HSM".equalsIgnoreCase(retrievedKey.key().kty().toString())) {
- MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NonRSAKey"));
- Object[] msgArgs = {retrievedKey.key().kty().toString()};
- throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
+ /**
+ * Returns the public Key size in bytes.
+ *
+ * @param masterKeyPath - Azure Key Vault Key path
+ * @return Key size in bytes
+ * @throws SQLServerException when an error occurs
+ */
+ private int getAKVKeySize(String masterKeyPath) throws SQLServerException {
+ KeyVaultKey retrievedKey = getKeyVaultKey(masterKeyPath);
+ return retrievedKey.getKey().getN().length;
}
- return retrievedKey.key().n().length;
- }
+ private KeyVaultKey getKeyVaultKey(String masterKeyPath)
+ throws SQLServerException {
+ String[] keyTokens = masterKeyPath.split("/");
+ String keyName = keyTokens[keyTokens.length - 2];
+ String keyVersion = keyTokens[keyTokens.length - 1];
+ KeyClient keyClient = getKeyClient(masterKeyPath);
+ KeyVaultKey retrievedKey = keyClient.getKey(keyName, keyVersion);
+
+ if (null == retrievedKey) {
+ MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyNotFound"));
+ Object[] msgArgs = {keyTokens[keyTokens.length - 1]};
+ throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
+ }
- @Override
- public boolean verifyColumnMasterKeyMetadata(String masterKeyPath, boolean allowEnclaveComputations,
- byte[] signature) throws SQLServerException {
- if (!allowEnclaveComputations)
- return false;
+ if (retrievedKey.getKeyType() != KeyType.RSA && retrievedKey.getKeyType() != KeyType.RSA_HSM) {
+ MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NonRSAKey"));
+ Object[] msgArgs = {retrievedKey.getKeyType().toString()};
+ throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
+ }
+ return retrievedKey;
+ }
- KeyStoreProviderCommon.validateNonEmptyMasterKeyPath(masterKeyPath);
+ private KeyClient getKeyClient(String masterKeyPath) {
+ if (cachedKeyClients.containsKey(masterKeyPath)) {
+ return cachedKeyClients.get(masterKeyPath);
+ }
+ String vaultUrl = getVaultUrl(masterKeyPath);
- try {
- MessageDigest md = MessageDigest.getInstance("SHA-256");
- md.update(name.toLowerCase().getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
- md.update(masterKeyPath.toLowerCase().getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
- // value of allowEnclaveComputations is always true here
- md.update("true".getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
+ KeyClient keyClient;
+ if (credential != null) {
+ keyClient = new KeyClientBuilder().credential(credential).vaultUrl(vaultUrl).buildClient();
+ } else {
+ keyClient = new KeyClientBuilder().pipeline(keyVaultPipeline).vaultUrl(vaultUrl).buildClient();
+ }
+ cachedKeyClients.putIfAbsent(masterKeyPath, keyClient);
+ return cachedKeyClients.get(masterKeyPath);
+ }
- byte[] dataToVerify = md.digest();
- if (null == dataToVerify) {
- throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null);
- }
+ private static String getVaultUrl(String masterKeyPath) {
+ String[] keyTokens = masterKeyPath.split("/");
+ String hostName = keyTokens[2];
+ return "https://" + hostName;
+ }
- // Sign the hash
- byte[] signedHash = AzureKeyVaultSignHashedData(dataToVerify, masterKeyPath);
- if (null == signedHash) {
- throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null);
+ @Override public boolean verifyColumnMasterKeyMetadata(String masterKeyPath, boolean allowEnclaveComputations,
+ byte[] signature) throws SQLServerException {
+ if (!allowEnclaveComputations) {
+ return false;
}
- // Validate the signature
- return AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath);
- } catch (NoSuchAlgorithmException e) {
- throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
+ KeyStoreProviderCommon.validateNonEmptyMasterKeyPath(masterKeyPath);
+
+ try {
+ MessageDigest md = MessageDigest.getInstance("SHA-256");
+ md.update(name.toLowerCase().getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
+ md.update(masterKeyPath.toLowerCase().getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
+ // value of allowEnclaveComputations is always true here
+ md.update("true".getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
+
+ byte[] dataToVerify = md.digest();
+ if (null == dataToVerify) {
+ throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null);
+ }
+
+ // Sign the hash
+ byte[] signedHash = AzureKeyVaultSignHashedData(dataToVerify, masterKeyPath);
+ if (null == signedHash) {
+ throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"),
+ null);
+ }
+
+ // Validate the signature
+ return AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath);
+ } catch (NoSuchAlgorithmException e) {
+ throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
+ }
}
- }
-
- private static List getTrustedEndpoints() {
- Properties mssqlJdbcProperties = getMssqlJdbcProperties();
- List trustedEndpoints = new ArrayList();
- boolean append = true;
- if (null != mssqlJdbcProperties) {
- String endpoints = mssqlJdbcProperties.getProperty(AKV_TRUSTED_ENDPOINTS_KEYWORD);
- if (null != endpoints && !endpoints.trim().isEmpty()) {
- endpoints = endpoints.trim();
- // Append if the list starts with a semicolon.
- if (';' != endpoints.charAt(0)) {
- append = false;
- } else {
- endpoints = endpoints.substring(1);
+
+ private static List getTrustedEndpoints() {
+ Properties mssqlJdbcProperties = getMssqlJdbcProperties();
+ List trustedEndpoints = new ArrayList();
+ boolean append = true;
+ if (null != mssqlJdbcProperties) {
+ String endpoints = mssqlJdbcProperties.getProperty(AKV_TRUSTED_ENDPOINTS_KEYWORD);
+ if (null != endpoints && !endpoints.trim().isEmpty()) {
+ endpoints = endpoints.trim();
+ // Append if the list starts with a semicolon.
+ if (';' != endpoints.charAt(0)) {
+ append = false;
+ } else {
+ endpoints = endpoints.substring(1);
+ }
+ String[] entries = endpoints.split(";");
+ for (String entry : entries) {
+ if (null != entry && !entry.trim().isEmpty()) {
+ trustedEndpoints.add(entry.trim());
+ }
+ }
+ }
}
- String[] entries = endpoints.split(";");
- for (String entry : entries) {
- if (null != entry && !entry.trim().isEmpty()) {
- trustedEndpoints.add(entry.trim());
- }
+ /*
+ * List of Azure trusted endpoints
+ * https://docs.microsoft.com/en-us/azure/key-vault/key-vault-secure-your-key-vault
+ */
+ if (append) {
+ trustedEndpoints.add("vault.azure.net");
+ trustedEndpoints.add("vault.azure.cn");
+ trustedEndpoints.add("vault.usgovcloudapi.net");
+ trustedEndpoints.add("vault.microsoftazure.de");
}
- }
+ return trustedEndpoints;
}
- /*
- * List of Azure trusted endpoints
- * https://docs.microsoft.com/en-us/azure/key-vault/key-vault-secure-your-key-vault
+
+ /**
+ * Attempt to read MSSQL_JDBC_PROPERTIES.
+ *
+ * @return corresponding Properties object or null if failed to read the file.
*/
- if (append) {
- trustedEndpoints.add("vault.azure.net");
- trustedEndpoints.add("vault.azure.cn");
- trustedEndpoints.add("vault.usgovcloudapi.net");
- trustedEndpoints.add("vault.microsoftazure.de");
- }
- return trustedEndpoints;
- }
-
- /**
- * Attempt to read MSSQL_JDBC_PROPERTIES.
- *
- * @return corresponding Properties object or null if failed to read the file.
- */
- private static Properties getMssqlJdbcProperties() {
- Properties props = null;
- try (FileInputStream in = new FileInputStream(MSSQL_JDBC_PROPERTIES)) {
- props = new Properties();
- props.load(in);
- } catch (IOException e) {
- if (akvLogger.isLoggable(Level.FINER)) {
- akvLogger.finer("Unable to load the mssql-jdbc.properties file: " + e);
- }
+ private static Properties getMssqlJdbcProperties() {
+ Properties props = null;
+ try (FileInputStream in = new FileInputStream(MSSQL_JDBC_PROPERTIES)) {
+ props = new Properties();
+ props.load(in);
+ } catch (IOException e) {
+ if (akvLogger.isLoggable(Level.FINER)) {
+ akvLogger.finer("Unable to load the mssql-jdbc.properties file: " + e);
+ }
+ }
+ return (null != props && !props.isEmpty()) ? props : null;
}
- return (null != props && !props.isEmpty()) ? props : null;
- }
}
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
index 61a0e29ab..f236258bf 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
@@ -669,6 +669,8 @@ boolean getSendTemporalDataTypesAsStringForBulkCopy() {
String keyStoreLocation = null;
String keyStorePrincipalId = null;
+ String keyVaultProviderTenantId = null;
+
private ColumnEncryptionVersion serverColumnEncryptionVersion = ColumnEncryptionVersion.AE_NotSupported;
private String enclaveType = null;
@@ -1366,16 +1368,11 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke
}
break;
case KeyVaultClientSecret:
- // need a secret use use the secret method
+ // need a secret to use the secret method
if (null == keyStoreSecret) {
throw new SQLServerException(SQLServerException.getErrString("R_keyStoreSecretNotSet"), null);
- } else {
- SQLServerColumnEncryptionAzureKeyVaultProvider provider = new SQLServerColumnEncryptionAzureKeyVaultProvider(
- keyStorePrincipalId, keyStoreSecret);
- Map keyStoreMap = new HashMap();
- keyStoreMap.put(provider.getName(), provider);
- registerColumnEncryptionKeyStoreProviders(keyStoreMap);
}
+ registerKeyVaultProvider(keyStorePrincipalId, keyStoreSecret);
break;
case KeyVaultManagedIdentity:
SQLServerColumnEncryptionAzureKeyVaultProvider provider;
@@ -1384,7 +1381,7 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke
} else {
provider = new SQLServerColumnEncryptionAzureKeyVaultProvider();
}
- Map keyStoreMap = new HashMap();
+ Map keyStoreMap = new HashMap<>();
keyStoreMap.put(provider.getName(), provider);
registerColumnEncryptionKeyStoreProviders(keyStoreMap);
break;
@@ -1395,6 +1392,15 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke
}
}
+ private void registerKeyVaultProvider(String clientId, String clientKey) throws SQLServerException {
+ // need a secret to use the secret method
+ SQLServerColumnEncryptionAzureKeyVaultProvider provider = new SQLServerColumnEncryptionAzureKeyVaultProvider(
+ clientId, clientKey);
+ Map keyStoreMap = new HashMap<>();
+ keyStoreMap.put(provider.getName(), provider);
+ registerColumnEncryptionKeyStoreProviders(keyStoreMap);
+ }
+
/**
* Establish a physical database connection based on the user specified connection properties. Logon to the
* database.
@@ -1620,6 +1626,12 @@ Connection connectInternal(Properties propsIn,
keyStorePrincipalId = sPropValue;
}
+ sPropKey = SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_TENANT_ID.toString();
+ sPropValue = activeConnectionProperties.getProperty(sPropKey);
+ if (null != sPropValue) {
+ keyVaultProviderTenantId = sPropValue;
+ }
+
registerKeyStoreProviderOnConnection(keyStoreAuthentication, keyStoreSecret, keyStoreLocation);
if (null == globalCustomColumnEncryptionKeyStoreProviders) {
@@ -1631,11 +1643,10 @@ Connection connectInternal(Properties propsIn,
sPropValue = activeConnectionProperties.getProperty(sPropKey);
if (null != sPropValue) {
String keyVaultColumnEncryptionProviderClientKey = sPropValue;
- SQLServerColumnEncryptionAzureKeyVaultProvider akvProvider = new SQLServerColumnEncryptionAzureKeyVaultProvider(
- keyVaultColumnEncryptionProviderClientId, keyVaultColumnEncryptionProviderClientKey);
- Map keyStoreMap = new HashMap();
- keyStoreMap.put(akvProvider.getName(), akvProvider);
- registerColumnEncryptionKeyStoreProviders(keyStoreMap);
+
+ registerKeyVaultProvider(
+ keyVaultColumnEncryptionProviderClientId,
+ keyVaultColumnEncryptionProviderClientKey);
}
}
}
@@ -4435,11 +4446,11 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
while (true) {
if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryPassword.toString())) {
- if (!adalContextExists()) {
+ if (!msalContextExists()) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_ADALMissing"));
throw new SQLServerException(form.format(new Object[] {authenticationString}), null, 0, null);
}
- fedAuthToken = SQLServerADAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user,
+ fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user,
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
authenticationString);
@@ -4521,17 +4532,17 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
sleepInterval = sleepInterval * 2;
}
}
- // else choose ADAL4J for integrated authentication. This option is supported for both windows and unix,
+ // else choose MSAL4J for integrated authentication. This option is supported for both windows and unix,
// so we don't need to check the
// OS version here.
else {
- // Check if ADAL4J library is available
- if (!adalContextExists()) {
+ // Check if MSAL4J library is available
+ if (!msalContextExists()) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_DLLandADALMissing"));
Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString};
throw new SQLServerException(form.format(msgArgs), null, 0, null);
}
- fedAuthToken = SQLServerADAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString);
+ fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString);
}
// Break out of the retry loop in successful case.
break;
@@ -4541,9 +4552,9 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
return fedAuthToken;
}
- private boolean adalContextExists() {
+ private boolean msalContextExists() {
try {
- Class.forName("com.microsoft.aad.adal4j.AuthenticationContext");
+ Class.forName("com.microsoft.aad.msal4j.PublicClientApplication");
} catch (ClassNotFoundException e) {
return false;
}
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java
index d9f41510a..4eb192037 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java
@@ -363,6 +363,7 @@ enum SQLServerDriverStringProperty {
MSI_CLIENT_ID("msiClientId", ""),
KEY_VAULT_PROVIDER_CLIENT_ID("keyVaultProviderClientId", ""),
KEY_VAULT_PROVIDER_CLIENT_KEY("keyVaultProviderClientKey", ""),
+ KEY_VAULT_PROVIDER_TENANT_ID("keyVaultProviderTenantId", ""),
KEY_STORE_PRINCIPAL_ID("keyStorePrincipalId", ""),
CLIENT_CERTIFICATE("clientCertificate", ""),
CLIENT_KEY("clientKey", ""),
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerKeyVaultAuthenticationCallback.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerKeyVaultAuthenticationCallback.java
deleted file mode 100644
index c7ac13f5c..000000000
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerKeyVaultAuthenticationCallback.java
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
- * available under the terms of the MIT License. See the LICENSE file in the project root for more information.
- */
-
-package com.microsoft.sqlserver.jdbc;
-
-/**
- * Provides a callback delegate which is to be implemented by the client code
- *
- */
-public interface SQLServerKeyVaultAuthenticationCallback {
-
- /**
- * Returns the acesss token of the authentication request
- *
- * @param authority
- * - Identifier of the authority, a URL.
- * @param resource
- * - Identifier of the target resource that is the recipient of the requested token, a URL.
- * @param scope
- * - The scope of the authentication request.
- * @return access token
- */
- String getAccessToken(String authority, String resource, String scope);
-}
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
similarity index 62%
rename from src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java
rename to src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
index a94ca3cc5..e8a318375 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerADAL4JUtils.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java
@@ -8,37 +8,46 @@
import java.io.IOException;
import java.net.MalformedURLException;
import java.text.MessageFormat;
+import java.util.Collections;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
import java.util.logging.Level;
import javax.security.auth.kerberos.KerberosPrincipal;
-import com.microsoft.aad.adal4j.AuthenticationContext;
-import com.microsoft.aad.adal4j.AuthenticationException;
-import com.microsoft.aad.adal4j.AuthenticationResult;
+import com.microsoft.aad.msal4j.IAuthenticationResult;
+import com.microsoft.aad.msal4j.IntegratedWindowsAuthenticationParameters;
+import com.microsoft.aad.msal4j.PublicClientApplication;
+import com.microsoft.aad.msal4j.UserNamePasswordParameters;
import com.microsoft.sqlserver.jdbc.SQLServerConnection.ActiveDirectoryAuthentication;
import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo;
-class SQLServerADAL4JUtils {
+class SQLServerMSAL4JUtils {
- static final private java.util.logging.Logger adal4jLogger = java.util.logging.Logger
- .getLogger("com.microsoft.sqlserver.jdbc.internals.SQLServerADAL4JUtils");
+ static final private java.util.logging.Logger logger = java.util.logging.Logger
+ .getLogger("com.microsoft.sqlserver.jdbc.SQLServerMSAL4JUtils");
static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password,
String authenticationString) throws SQLServerException {
ExecutorService executorService = Executors.newFixedThreadPool(1);
try {
- AuthenticationContext context = new AuthenticationContext(fedAuthInfo.stsurl, false, executorService);
- Future future = context.acquireToken(fedAuthInfo.spn,
- ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID, user, password, null);
+ final PublicClientApplication clientApplication = PublicClientApplication
+ .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID)
+ .executorService(executorService)
+ .authority(fedAuthInfo.stsurl)
+ .build();
+ final CompletableFuture future = clientApplication.acquireToken(UserNamePasswordParameters.builder(
+ Collections.singleton(fedAuthInfo.spn + "/.default"),
+ user,
+ password.toCharArray()
+ ).build());
+
+ final IAuthenticationResult authenticationResult = future.get();
+ return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
- AuthenticationResult authenticationResult = future.get();
-
- return new SqlFedAuthToken(authenticationResult.getAccessToken(), authenticationResult.getExpiresOnDate());
} catch (MalformedURLException | InterruptedException e) {
throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
@@ -50,8 +59,7 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use
* correct format
*/
String correctedErrorMessage = e.getCause().getMessage().replaceAll("\\\\r\\\\n", "\r\n");
- AuthenticationException correctedAuthenticationException = new AuthenticationException(
- correctedErrorMessage);
+ RuntimeException correctedAuthenticationException = new RuntimeException(correctedErrorMessage);
/*
* SQLServerException is caused by ExecutionException, which is caused by AuthenticationException to match
@@ -75,19 +83,24 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
* principal_name@realm_name format
*/
KerberosPrincipal kerberosPrincipal = new KerberosPrincipal("username");
- String username = kerberosPrincipal.getName();
+ String user = kerberosPrincipal.getName();
- if (adal4jLogger.isLoggable(Level.FINE)) {
- adal4jLogger.fine(adal4jLogger.toString() + " realm name is:" + kerberosPrincipal.getRealm());
+ if (logger.isLoggable(Level.FINE)) {
+ logger.fine(logger.toString() + " realm name is:" + kerberosPrincipal.getRealm());
}
- AuthenticationContext context = new AuthenticationContext(fedAuthInfo.stsurl, false, executorService);
- Future future = context.acquireToken(fedAuthInfo.spn,
- ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID, username, null, null);
-
- AuthenticationResult authenticationResult = future.get();
-
- return new SqlFedAuthToken(authenticationResult.getAccessToken(), authenticationResult.getExpiresOnDate());
+ final PublicClientApplication clientApplication = PublicClientApplication
+ .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID)
+ .executorService(executorService)
+ .authority(fedAuthInfo.stsurl)
+ .build();
+ final CompletableFuture future = clientApplication.acquireToken(IntegratedWindowsAuthenticationParameters.builder(
+ Collections.singleton(fedAuthInfo.spn + "/.default"),
+ user
+ ).build());
+
+ final IAuthenticationResult authenticationResult = future.get();
+ return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
} catch (InterruptedException | IOException e) {
throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
@@ -103,8 +116,7 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
* correct format
*/
String correctedErrorMessage = e.getCause().getMessage().replaceAll("\\\\r\\\\n", "\r\n");
- AuthenticationException correctedAuthenticationException = new AuthenticationException(
- correctedErrorMessage);
+ RuntimeException correctedAuthenticationException = new RuntimeException(correctedErrorMessage);
/*
* SQLServerException is caused by ExecutionException, which is caused by AuthenticationException to
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java
index a361c0f59..23cb72514 100644
--- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java
@@ -537,6 +537,8 @@ protected Object[][] getContents() {
"Both \"keyStoreSecret\" and \"keyStoreLocation\" must be set, if \"keyStoreAuthentication=JavaKeyStorePassword\" has been specified in the connection string."},
{"R_keyStoreSecretNotSet",
"\"keyStoreSecret\" must be set, if \"keyStoreAuthentication=KeyVaultClientSecret\" has been specified in the connection string."},
+ {"R_keyVaultProviderTenantIdNotSet",
+ "\"keyVaultProviderTenantId\" must be set, if \"keyStoreAuthentication=KeyVaultClientSecret\" has been specified in the connection string."},
{"R_certificateStoreInvalidKeyword",
"Cannot set \"keyStoreSecret\", if \"keyStoreAuthentication=CertificateStore\" has been specified in the connection string."},
{"R_certificateStoreLocationNotSet",
diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ScopeTokenCache.java b/src/main/java/com/microsoft/sqlserver/jdbc/ScopeTokenCache.java
new file mode 100644
index 000000000..a3758fb79
--- /dev/null
+++ b/src/main/java/com/microsoft/sqlserver/jdbc/ScopeTokenCache.java
@@ -0,0 +1,57 @@
+package com.microsoft.sqlserver.jdbc;
+
+import com.azure.core.credential.AccessToken;
+import com.azure.core.credential.TokenRequestContext;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
+import reactor.core.publisher.FluxSink;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.ReplayProcessor;
+
+
+/**
+ * A token cache that supports caching a token and refreshing it.
+ */
+class ScopeTokenCache {
+
+ private final AtomicBoolean wip;
+ private AccessToken cache;
+ private final ReplayProcessor emitterProcessor = ReplayProcessor.create(1);
+ private final FluxSink sink = emitterProcessor.sink(FluxSink.OverflowStrategy.BUFFER);
+ private final Function> getNew;
+ private TokenRequestContext request;
+
+ /**
+ * Creates an instance of RefreshableTokenCredential with default scheme "Bearer".
+ *
+ * @param getNew a method to get a new token
+ */
+ ScopeTokenCache(Function> getNew) {
+ this.wip = new AtomicBoolean(false);
+ this.getNew = getNew;
+ }
+
+ void setRequest(TokenRequestContext request) {
+ this.request = request;
+ }
+
+ /**
+ * Asynchronously get a token from either the cache or replenish the cache with a new token.
+ * @return a Publisher that emits an AccessToken
+ */
+ Mono getToken() {
+ if (cache != null && !cache.isExpired()) {
+ return Mono.just(cache);
+ }
+ return Mono.defer(() -> {
+ if (!wip.getAndSet(true)) {
+ return getNew.apply(request).doOnNext(ac -> cache = ac)
+ .doOnNext(sink::next)
+ .doOnError(sink::error)
+ .doOnTerminate(() -> wip.set(false));
+ } else {
+ return emitterProcessor.next();
+ }
+ });
+ }
+}
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java
index ba4125f05..c2fe9e4f9 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java
@@ -4,6 +4,7 @@
*/
package com.microsoft.sqlserver.jdbc.AlwaysEncrypted;
+import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
@@ -17,25 +18,19 @@
import java.sql.Time;
import java.sql.Timestamp;
import java.util.LinkedList;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import com.azure.core.credential.TokenCredential;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
-import com.microsoft.aad.adal4j.AuthenticationContext;
-import com.microsoft.aad.adal4j.AuthenticationResult;
-import com.microsoft.aad.adal4j.ClientCredential;
import com.microsoft.sqlserver.jdbc.RandomData;
import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider;
import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionJavaKeyStoreProvider;
import com.microsoft.sqlserver.jdbc.SQLServerConnection;
import com.microsoft.sqlserver.jdbc.SQLServerException;
-import com.microsoft.sqlserver.jdbc.SQLServerKeyVaultAuthenticationCallback;
import com.microsoft.sqlserver.jdbc.SQLServerPreparedStatement;
import com.microsoft.sqlserver.jdbc.SQLServerResultSet;
import com.microsoft.sqlserver.jdbc.SQLServerStatement;
@@ -91,18 +86,15 @@ public void testJksName(String serverName, String url, String protocol) throws E
*/
@ParameterizedTest
@MethodSource("enclaveParams")
+ @Tag(Constants.reqExternalSetup)
public void testAkvName(String serverName, String url, String protocol) throws Exception {
setAEConnectionString(serverName, url, protocol);
- try {
- SQLServerColumnEncryptionAzureKeyVaultProvider akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(
- authenticationCallback);
- String keystoreName = "keystoreName";
- akv.setName(keystoreName);
- assertTrue(akv.getName().equals(keystoreName));
- } catch (SQLServerException e) {
- fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
- }
+ SQLServerColumnEncryptionAzureKeyVaultProvider akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(
+ applicationClientID, applicationKey);
+ String keystoreName = "keystoreName";
+ akv.setName(keystoreName);
+ assertTrue(akv.getName().equals(keystoreName));
}
/*
@@ -134,10 +126,10 @@ public void testBadAkv(String serverName, String url, String protocol) throws Ex
try {
SQLServerColumnEncryptionAzureKeyVaultProvider akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(
- (SQLServerKeyVaultAuthenticationCallback) null);
+ (TokenCredential) null);
fail(TestResource.getResource("R_expectedExceptionNotThrown"));
- } catch (SQLServerException e) {
- assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_NullValue")));
+ } catch (NullPointerException exception) {
+ assertNull(exception.getMessage());
}
}
@@ -185,11 +177,7 @@ public void testAkvBadEncryptColumnEncryptionKey(String serverName, String url,
setAEConnectionString(serverName, url, protocol);
SQLServerColumnEncryptionAzureKeyVaultProvider akv = null;
- try {
- akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(authenticationCallback);
- } catch (SQLServerException e) {
- fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
- }
+ akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(applicationClientID, applicationKey);
// null encryptedColumnEncryptionKey
try {
@@ -268,11 +256,7 @@ public void testAkvDecryptColumnEncryptionKey(String serverName, String url, Str
setAEConnectionString(serverName, url, protocol);
SQLServerColumnEncryptionAzureKeyVaultProvider akv = null;
- try {
- akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(authenticationCallback);
- } catch (SQLServerException e) {
- fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
- }
+ akv = new SQLServerColumnEncryptionAzureKeyVaultProvider(applicationClientID, applicationKey);
// null akvpath
try {
@@ -2243,23 +2227,4 @@ void testNumerics(SQLServerStatement stmt, String cekName, String[][] table, Str
testRichQuery(stmt, NUMERIC_TABLE_AE, table, values2);
}
}
-
- SQLServerKeyVaultAuthenticationCallback authenticationCallback = new SQLServerKeyVaultAuthenticationCallback() {
- // @Override
- ExecutorService service = Executors.newFixedThreadPool(2);
-
- public String getAccessToken(String authority, String resource, String scope) {
-
- AuthenticationResult result = null;
- try {
- AuthenticationContext context = new AuthenticationContext(authority, false, service);
- ClientCredential cred = new ClientCredential(applicationClientID, applicationKey);
- Future future = context.acquireToken(resource, cred, null);
- result = future.get();
- } catch (Exception e) {
- fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
- }
- return result.getAccessToken();
- }
- };
}
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java
index 552ef2012..b4c28a17c 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java
@@ -7,6 +7,7 @@
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import com.microsoft.aad.msal4j.MsalServiceException;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
@@ -299,7 +300,9 @@ public void testNumericAkvWithBadCred() throws SQLException {
testNumericAKV(connStr);
fail(TestResource.getResource("R_expectedFailPassed"));
} catch (Exception e) {
- assert (e.getMessage().contains("AuthenticationException"));
+ assertTrue(e.getCause() instanceof MsalServiceException);
+ // https://docs.microsoft.com/en-us/azure/active-directory/develop/reference-aadsts-error-codes
+ assertTrue(e.getMessage().contains("AADSTS700016"));
}
}
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java
index 10e4d82fc..0c85e7ea5 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java
@@ -307,11 +307,10 @@ public static void testVerifyCMKNoEnclave() {
}
try {
- SQLServerColumnEncryptionAzureKeyVaultProvider aksp = new SQLServerColumnEncryptionAzureKeyVaultProvider("",
- "");
- assertFalse(aksp.verifyColumnMasterKeyMetadata(null, false, null));
+ SQLServerColumnEncryptionAzureKeyVaultProvider aksp = new SQLServerColumnEncryptionAzureKeyVaultProvider(
+ "","");
} catch (SQLServerException e) {
- fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
+ assertEquals(e.getMessage(), "Client ID cannot be null.");
}
}
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java
index 675fa4341..30c696d3d 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java
@@ -7,22 +7,27 @@
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
+import com.microsoft.aad.msal4j.IAuthenticationResult;
+import com.microsoft.aad.msal4j.PublicClientApplication;
+import com.microsoft.aad.msal4j.UserNamePasswordParameters;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
+import java.util.Collections;
+import java.util.Date;
import java.util.Locale;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
import java.util.logging.LogManager;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
-import com.microsoft.aad.adal4j.AuthenticationContext;
-import com.microsoft.aad.adal4j.AuthenticationResult;
import com.microsoft.sqlserver.testframework.Constants;
import com.microsoft.sqlserver.jdbc.SQLServerException;
import com.microsoft.sqlserver.jdbc.TestResource;
@@ -142,11 +147,24 @@ public static void getConfigs() throws Exception {
*/
static void getFedauthInfo() {
try {
- AuthenticationContext context = new AuthenticationContext(stsurl, false, Executors.newFixedThreadPool(1));
- Future future = context.acquireToken(spn, fedauthClientId, azureUserName,
- azurePassword, null);
- secondsBeforeExpiration = future.get().getExpiresAfter();
- accessToken = future.get().getAccessToken();
+
+ final PublicClientApplication clientApplication = PublicClientApplication
+ .builder(fedauthClientId)
+ .executorService(Executors.newFixedThreadPool(1))
+ .authority(stsurl)
+ .build();
+ final CompletableFuture future = clientApplication.acquireToken(
+ UserNamePasswordParameters.builder(
+ Collections.singleton(spn + "/.default"),
+ azureUserName,
+ azurePassword.toCharArray()
+ ).build());
+
+ final IAuthenticationResult authenticationResult = future.get();
+
+ secondsBeforeExpiration = TimeUnit.MILLISECONDS
+ .toSeconds(authenticationResult.expiresOnDate().getTime() - new Date().getTime());
+ accessToken = authenticationResult.accessToken();
} catch (Exception e) {
fail(e.getMessage());
}
diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java
index 65f1c2ac0..08942715b 100644
--- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java
+++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthWithAE.java
@@ -15,9 +15,6 @@
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
@@ -26,9 +23,6 @@
import org.junit.platform.runner.JUnitPlatform;
import org.junit.runner.RunWith;
-import com.microsoft.aad.adal4j.AuthenticationContext;
-import com.microsoft.aad.adal4j.AuthenticationResult;
-import com.microsoft.aad.adal4j.ClientCredential;
import com.microsoft.sqlserver.jdbc.RandomUtil;
import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider;
import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionJavaKeyStoreProvider;
@@ -36,7 +30,6 @@
import com.microsoft.sqlserver.jdbc.SQLServerConnection;
import com.microsoft.sqlserver.jdbc.SQLServerDataSource;
import com.microsoft.sqlserver.jdbc.SQLServerException;
-import com.microsoft.sqlserver.jdbc.SQLServerKeyVaultAuthenticationCallback;
import com.microsoft.sqlserver.jdbc.TestUtils;
import com.microsoft.sqlserver.testframework.AbstractSQLGenerator;
import com.microsoft.sqlserver.testframework.Constants;
@@ -128,7 +121,7 @@ public void testFedAuthWithAE_AKV() throws SQLException {
dropCMK(stmt, cmkName3);
setupCMK_AKVOld(cmkName3, stmt);
- createCEK(cmkName3, setupKeyStoreProvider_AKVOld(), stmt, keyIDs[0]);
+ createCEK(cmkName3, setupKeyStoreProvider_AKVNew(), stmt, keyIDs[0]);
createCharTable(stmt, charTableOld);
populateCharNormalCase(charValues, connection, charTableOld);
@@ -307,27 +300,27 @@ private SQLServerColumnEncryptionKeyStoreProvider setupKeyStoreProvider_AKVNew()
new SQLServerColumnEncryptionAzureKeyVaultProvider(applicationClientID, applicationKey));
}
- private SQLServerColumnEncryptionKeyStoreProvider setupKeyStoreProvider_AKVOld() throws SQLServerException {
- ExecutorService service = Executors.newFixedThreadPool(2);
- SQLServerKeyVaultAuthenticationCallback authenticationCallback = new SQLServerKeyVaultAuthenticationCallback() {
- @Override
- public String getAccessToken(String authority, String resource, String scope) {
- AuthenticationResult result = null;
- try {
- AuthenticationContext context = new AuthenticationContext(authority, false, service);
- ClientCredential cred = new ClientCredential(applicationClientID, applicationKey);
-
- Future future = context.acquireToken(resource, cred, null);
- result = future.get();
- return result.getAccessToken();
- } catch (Exception e) {
- fail(e.getMessage());
- return null;
- }
- }
- };
- return new SQLServerColumnEncryptionAzureKeyVaultProvider(authenticationCallback);
- }
+// private SQLServerColumnEncryptionKeyStoreProvider setupKeyStoreProvider_AKVOld() throws SQLServerException {
+// ExecutorService service = Executors.newFixedThreadPool(2);
+// SQLServerKeyVaultAuthenticationCallback authenticationCallback = new SQLServerKeyVaultAuthenticationCallback() {
+// @Override
+// public String getAccessToken(String authority, String resource, String scope) {
+// AuthenticationResult result = null;
+// try {
+// AuthenticationContext context = new AuthenticationContext(authority, false, service);
+// ClientCredential cred = new ClientCredential(applicationClientID, applicationKey);
+//
+// Future future = context.acquireToken(resource, cred, null);
+// result = future.get();
+// return result.getAccessToken();
+// } catch (Exception e) {
+// fail(e.getMessage());
+// return null;
+// }
+// }
+// };
+// return new SQLServerColumnEncryptionAzureKeyVaultProvider(authenticationCallback);
+// }
private SQLServerColumnEncryptionKeyStoreProvider registerAKVProvider(
SQLServerColumnEncryptionKeyStoreProvider provider) throws SQLServerException {
diff --git a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java
index 914831a3c..56bd627f0 100644
--- a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java
+++ b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java
@@ -59,6 +59,7 @@ public abstract class AbstractTest {
protected static String applicationClientID = null;
protected static String applicationKey = null;
protected static String[] keyIDs = null;
+ protected static String tenantID = null;
protected static String[] enclaveServer = null;
protected static String[] enclaveAttestationUrl = null;
@@ -136,6 +137,7 @@ public static void setup() throws Exception {
javaKeyPath = TestUtils.getCurrentClassPath() + Constants.JKS_NAME;
keyIDs = getConfiguredProperty("keyID", "").split(Constants.SEMI_COLON);
+ tenantID = getConfiguredProperty("tenantID");
windowsKeyPath = getConfiguredProperty("windowsKeyPath");
String prop;