Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample code improvements around token caching #2821

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions doc/samples/AzureKeyVaultProviderExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ public static void Main(string[] args)
}
}

// Maintain an instance of the ClientCredential object to take advantage of underlying token caching
private static ClientCredential clientCredential = new ClientCredential(s_clientId, s_clientSecret);

public static async Task<string> AzureActiveDirectoryAuthenticationCallback(string authority, string resource, string scope)
{
var authContext = new AuthenticationContext(authority);
ClientCredential clientCred = new ClientCredential(s_clientId, s_clientSecret);
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCred);
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCredential);
if (result == null)
{
throw new InvalidOperationException($"Failed to retrieve an access token for {resource}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ static void Main(string[] args)
}
}

// Maintain an instance of the ClientCredential object to take advantage of underlying token caching
private static ClientCredential clientCredential = new ClientCredential(s_clientId, s_clientSecret);

public static async Task<string> AzureActiveDirectoryAuthenticationCallback(string authority, string resource, string scope)
{
var authContext = new AuthenticationContext(authority);
ClientCredential clientCred = new ClientCredential(s_clientId, s_clientSecret);
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCred);
AuthenticationResult result = await authContext.AcquireTokenAsync(resource, clientCredential);
if (result == null)
{
throw new InvalidOperationException($"Failed to retrieve an access token for {resource}");
Expand Down
44 changes: 32 additions & 12 deletions doc/samples/CustomDeviceCodeFlowAzureAuthenticationProvider.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//<Snippet1>
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
using Microsoft.Data.SqlClient;
using Microsoft.Identity.Client;

namespace CustomAuthenticationProviderExamples
{
Expand All @@ -12,28 +14,46 @@ namespace CustomAuthenticationProviderExamples
/// </summary>
public class CustomDeviceCodeFlowAzureAuthenticationProvider : SqlAuthenticationProvider
{
private const string clientId = "my-client-id";
private const string clientName = "My Application Name";
private const string s_defaultScopeSuffix = "/.default";

// Maintain a copy of the PublicClientApplication object to cache the underlying access tokens it provides
private static IPublicClientApplication pcApplication;

public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters)
{
string clientId = "my-client-id";
string clientName = "My Application Name";
string s_defaultScopeSuffix = "/.default";

string[] scopes = new string[] { parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix };

IPublicClientApplication app = PublicClientApplicationBuilder.Create(clientId)
.WithAuthority(parameters.Authority)
.WithClientName(clientName)
.WithRedirectUri("https://login.microsoftonline.com/common/oauth2/nativeclient")
IPublicClientApplication app = pcApplication;
if (app == null)
{
pcApplication = app = PublicClientApplicationBuilder.Create(clientId)
.WithAuthority(parameters.Authority)
.WithClientName(clientName)
.WithRedirectUri("https://login.microsoftonline.com/common/oauth2/nativeclient")
.Build();
}

AuthenticationResult result;

try
{
IEnumerable<IAccount> accounts = await app.GetAccountsAsync();
result = await app.AcquireTokenSilent(scopes, accounts.FirstOrDefault()).ExecuteAsync();
}
catch (MsalUiRequiredException)
{
result = await app.AcquireTokenWithDeviceCode(scopes,
deviceCodeResult => CustomDeviceFlowCallback(deviceCodeResult)).ExecuteAsync();
}

AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
deviceCodeResult => CustomDeviceFlowCallback(deviceCodeResult)).ExecuteAsync();
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
}

public override bool IsSupported(SqlAuthenticationMethod authenticationMethod) => authenticationMethod.Equals(SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow);

private Task CustomDeviceFlowCallback(DeviceCodeResult result)
private static Task<int> CustomDeviceFlowCallback(DeviceCodeResult result)
{
Console.WriteLine(result.Message);
return Task.FromResult(0);
Expand Down
53 changes: 36 additions & 17 deletions doc/samples/SqlConnection_AccessTokenCallback.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using System;
using System.Data;
// <Snippet1>
using Microsoft.Data.SqlClient;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Identity;
using Microsoft.Data.SqlClient;

class Program
{
Expand All @@ -12,25 +15,41 @@ static void Main()
Console.ReadLine();
}

const string defaultScopeSuffix = "/.default";

// Reuse credential objects to take advantage of underlying token caches
private static ConcurrentDictionary<string, DefaultAzureCredential> credentials = new ConcurrentDictionary<string, DefaultAzureCredential>();

// Use a shared callback function for connections that should be in the same connection pool
private static Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> myAccessTokenCallback =
async (authParams, cancellationToken) =>
{
string scope = authParams.Resource.EndsWith(defaultScopeSuffix)
? authParams.Resource
: $"{authParams.Resource}{defaultScopeSuffix}";

DefaultAzureCredentialOptions options = new DefaultAzureCredentialOptions();
options.ManagedIdentityClientId = authParams.UserId;

// Reuse the same credential object if we are using the same MI Client Id
AccessToken token = await credentials.GetOrAdd(authParams.UserId, new DefaultAzureCredential(options)).GetTokenAsync(
new TokenRequestContext(new string[] { scope }),
cancellationToken);

return new SqlAuthenticationToken(token.Token, token.ExpiresOn);
};

private static void OpenSqlConnection()
{
const string defaultScopeSuffix = "/.default";
string connectionString = GetConnectionString();
DefaultAzureCredential credential = new();
// (Optional) Pass a User-Assigned Managed Identity Client ID.
// This will ensure different MI Client IDs are in different connection pools.
string connectionString = "Server=myServer.database.windows.net;Encrypt=Mandatory;UserId=<ManagedIdentitityClientId>;";

using (SqlConnection connection = new(connectionString)
using (SqlConnection connection = new SqlConnection(connectionString)
{
AccessTokenCallback = async (authParams, cancellationToken) =>
{
string scope = authParams.Resource.EndsWith(defaultScopeSuffix)
? authParams.Resource
: $"{authParams.Resource}{defaultScopeSuffix}";
AccessToken token = await credential.GetTokenAsync(
new TokenRequestContext([scope]),
cancellationToken);

return new SqlAuthenticationToken(token.Token, token.ExpiresOn);
}
// The callback function is part of the connection pool key. Using a static callback function
// ensures connections will not create a new pool per connection just for the callback.
AccessTokenCallback = myAccessTokenCallback
})
{
connection.Open();
Expand Down