Skip to content

Commit

Permalink
API | AccessTokenCallback support (#1260)
Browse files Browse the repository at this point in the history
  • Loading branch information
christothes authored Jun 27, 2023
1 parent 2b31810 commit 8fad4a4
Show file tree
Hide file tree
Showing 23 changed files with 653 additions and 85 deletions.
35 changes: 35 additions & 0 deletions doc/samples/SqlConnection_AccessTokenCallback.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System;
using System.Data;
// <Snippet1>
using Microsoft.Data.SqlClient;
using Azure.Identity;

class Program
{
static void Main()
{
OpenSqlConnection();
Console.ReadLine();
}

private static void OpenSqlConnection()
{
string connectionString = GetConnectionString();
using (SqlConnection connection = new SqlConnection("Data Source=contoso.database.windows.net;Initial Catalog=AdventureWorks;")
{
AccessTokenCallback = async (authParams, cancellationToken) =>
{
var cred = new DefaultAzureCredential();
string scope = authParams.Resource.EndsWith(s_defaultScopeSuffix) ? authParams.Resource : authParams.Resource + s_defaultScopeSuffix;
var token = await cred.GetTokenAsync(new TokenRequestContext(new[] { scope }), cancellationToken);
return new SqlAuthenticationToken(token.Token, token.ExpiresOn);
}
})
{
connection.Open();
Console.WriteLine("ServerVersion: {0}", connection.ServerVersion);
Console.WriteLine("State: {0}", connection.State);
}
}
}
// </Snippet1>
16 changes: 16 additions & 0 deletions doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ using (SqlConnection connection = new SqlConnection(connectionString))
<value>The access token for the connection.</value>
<remarks>To be added.</remarks>
</AccessToken>
<AccessTokenCallback>
<summary>Gets or sets the access token callback for the connection.</summary>
<value>
The Func that takes a <see cref="SqlAuthenticationParameters" /> and <see cref="System.Threading.CancellationToken" /> and returns a <see cref="SqlAuthenticationToken" />.</value>
<remarks>
<format type="text/markdown"><![CDATA[
## Examples
The following example demonstrates how to define and set an <xref:Microsoft.Data.SqlClient.AccessTokenCallback>.
[!code-csharp[SqlConnection_AccessTokenCallback Example#1](~/../sqlclient/doc/samples/SqlConnection_AccessTokenCallback.cs#1)]
]]></format>
</remarks>
<exception cref="T:System.InvalidOperationException">The AccessTokenCallback is combined with other conflicting authentication configurations.</exception>
</AccessTokenCallback>
<BeginDbTransaction>
<param name="isolationLevel">To be added.</param>
<summary>To be added.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available.
// New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future.

[assembly: System.CLSCompliant(true)]
namespace Microsoft.Data
{
Expand Down Expand Up @@ -839,6 +840,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ClientConnectionId/*'/>
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public System.Guid ClientConnectionId { get { throw null; } }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessTokenCallback/*' />
public System.Func<SqlAuthenticationParameters, System.Threading.CancellationToken, System.Threading.Tasks.Task<SqlAuthenticationToken>> AccessTokenCallback { get { throw null; } set { } }

///
/// for internal test only
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
/// Instance-level list of custom key store providers. It can be set more than once by the user.
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;

private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;

internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;

Expand Down Expand Up @@ -272,7 +274,7 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProvidersNames()
}

/// <summary>
/// This function returns a list of the names of the custom providers currently registered. If the
/// This function returns a list of the names of the custom providers currently registered. If the
/// instance-level cache is not empty, that cache is used, else the global cache is used.
/// </summary>
/// <returns>Combined list of provider names</returns>
Expand Down Expand Up @@ -344,7 +346,7 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<st
new(customProviders, StringComparer.OrdinalIgnoreCase);

// Set the dictionary to the ReadOnly dictionary.
// This method can be called more than once. Re-registering a new collection will replace the
// This method can be called more than once. Re-registering a new collection will replace the
// old collection of providers.
_customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}
Expand Down Expand Up @@ -584,7 +586,7 @@ public override string ConnectionString
}
set
{
if (_credential != null || _accessToken != null)
if (_credential != null || _accessToken != null || _accessTokenCallback != null)
{
SqlConnectionString connectionOptions = new SqlConnectionString(value);
if (_credential != null)
Expand Down Expand Up @@ -620,12 +622,18 @@ public override string ConnectionString

CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
}
else if (_accessToken != null)

if (_accessToken != null)
{
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(connectionOptions);
}

if (_accessTokenCallback != null)
{
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
}
}
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken));
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
_connectionString = value; // Change _connectionString value only after value is validated
CacheConnectionStringProperties();
}
Expand Down Expand Up @@ -685,11 +693,34 @@ public string AccessToken
}

// Need to call ConnectionString_Set to do proper pool group check
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null));
_accessToken = value;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessTokenCallback/*' />
public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> AccessTokenCallback
{
get { return _accessTokenCallback; }
set
{
// If a connection is connecting or is ever opened, AccessToken callback cannot be set
if (!InnerConnection.AllowSetConnectionString)
{
throw ADP.OpenConnectionPropertySet(nameof(AccessTokenCallback), InnerConnection.State);
}

if (value != null)
{
// Check if the usage of AccessToken has any conflict with the keys used in connection string and credential
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
}

ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value));
_accessTokenCallback = value;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
Expand Down Expand Up @@ -970,6 +1001,7 @@ public SqlCredential Credential
}

CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);

if (_accessToken != null)
{
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
Expand All @@ -979,7 +1011,7 @@ public SqlCredential Credential
_credential = value;

// Need to call ConnectionString_Set to do proper pool group check
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
}
}

Expand Down Expand Up @@ -1026,6 +1058,33 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
{
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
}

if(_accessTokenCallback != null)
{
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
}
}

// CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback: check if the usage of AccessTokenCallback has any conflict
// with the keys used in connection string and credential
// If there is any conflict, it throws InvalidOperationException
// This is to be used setter of ConnectionString and AccessTokenCallback properties
private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(SqlConnectionString connectionOptions)
{
if (UsesIntegratedSecurity(connectionOptions))
{
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity();
}

if (UsesAuthentication(connectionOptions))
{
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication();
}

if(_accessToken != null)
{
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/DbProviderFactory/*' />
Expand Down Expand Up @@ -2128,7 +2187,7 @@ public static void ChangePassword(string connectionString, string newPassword)
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
}

SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);

SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
if (connectionOptions.IntegratedSecurity)
Expand Down Expand Up @@ -2177,7 +2236,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
}

SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);

SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);

Expand Down Expand Up @@ -2216,7 +2275,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
if (con != null)
con.Dispose();
}
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);

SqlConnectionFactory.SingletonInstance.ClearPool(key);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback);
}

protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
Expand Down
Loading

0 comments on commit 8fad4a4

Please sign in to comment.