diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
index 6a0ee2e0e0..fb9d62c7e6 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
@@ -2254,6 +2254,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");
+
+ if (_newDbConnectionPoolAuthenticationContext != null)
+ {
+ _dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
+ }
}
}
else if (!attemptRefreshTokenLocked)
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
index 068b37dc71..294f9c8b7e 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
@@ -2680,6 +2680,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");
+
+ if (_newDbConnectionPoolAuthenticationContext != null)
+ {
+ _dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
+ }
}
}
else if (!attemptRefreshTokenLocked)
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj
index 407dd9af1a..24c84ed402 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj
@@ -170,6 +170,7 @@
+
@@ -275,6 +276,7 @@
+
@@ -298,30 +300,19 @@
-
-
-
-
-
-
-
+
+
+
+
+
+
+
-
+
-
+
runtime; build; native; contentfiles; analyzers; buildtransitive
all
@@ -334,8 +325,7 @@
all
-
+
@@ -352,7 +342,7 @@
-
+
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AADFedAuthTokenRefreshTest/AADFedAuthTokenRefreshTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AADFedAuthTokenRefreshTest/AADFedAuthTokenRefreshTest.cs
new file mode 100644
index 0000000000..875d755a59
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AADFedAuthTokenRefreshTest/AADFedAuthTokenRefreshTest.cs
@@ -0,0 +1,74 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.Data.SqlClient.ManualTesting.Tests
+{
+ public class AADFedAuthTokenRefreshTest
+ {
+ private readonly ITestOutputHelper _testOutputHelper;
+
+ public AADFedAuthTokenRefreshTest(ITestOutputHelper testOutputHelper)
+ {
+ _testOutputHelper = testOutputHelper;
+ }
+
+ [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup))]
+ public void FedAuthTokenRefreshTest()
+ {
+ string connectionString = DataTestUtility.AADPasswordConnectionString;
+
+ using (SqlConnection connection = new SqlConnection(connectionString))
+ {
+ connection.Open();
+
+ string oldTokenHash = "";
+ DateTime? oldExpiryDateTime = FedAuthTokenHelper.SetTokenExpiryDateTime(connection, minutesToExpire: 1, out oldTokenHash);
+ Assert.True(oldExpiryDateTime != null, "Failed to make token expiry to expire in one minute.");
+
+ // Convert and display the old expiry into local time which should be in 1 minute from now
+ DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiryDateTime, TimeZoneInfo.Local);
+ LogInfo($"Token: {oldTokenHash} Old Expiry: {oldLocalExpiryTime}");
+ TimeSpan timeDiff = oldLocalExpiryTime - DateTime.Now;
+ Assert.InRange(timeDiff.TotalSeconds, 0, 60);
+
+ // Check if connection is still alive to continue further testing
+ string result = "";
+ SqlCommand cmd = connection.CreateCommand();
+ cmd.CommandText = "select @@version";
+ result = $"{cmd.ExecuteScalar()}";
+ Assert.True(result != string.Empty, "The connection's command must return a value");
+
+ // The new connection will use the same FedAuthToken but will refresh it first as it will expire in 1 minute.
+ using (SqlConnection connection2 = new SqlConnection(connectionString))
+ {
+ connection2.Open();
+
+ // Check if connection is alive
+ cmd = connection2.CreateCommand();
+ cmd.CommandText = "select 1";
+ result = $"{cmd.ExecuteScalar()}";
+ Assert.True(result != string.Empty, "The connection's command must return a value after a token refresh.");
+
+ string newTokenHash = "";
+ DateTime? newExpiryDateTime = FedAuthTokenHelper.GetTokenExpiryDateTime(connection2, out newTokenHash);
+ DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiryDateTime, TimeZoneInfo.Local);
+ LogInfo($"Token: {newTokenHash} New Expiry: {newLocalExpiryTime}");
+
+ Assert.True(oldTokenHash == newTokenHash, "The token's hash before and after token refresh must be identical.");
+ Assert.True(newLocalExpiryTime > oldLocalExpiryTime, "The refreshed token must have a new or later expiry time.");
+ }
+ }
+ }
+
+ private void LogInfo(string message)
+ {
+ _testOutputHelper.WriteLine(message);
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/FedAuthTokenHelper.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/FedAuthTokenHelper.cs
new file mode 100644
index 0000000000..26d2477b9b
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/FedAuthTokenHelper.cs
@@ -0,0 +1,108 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Linq;
+using System.Reflection;
+
+namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals
+{
+ internal static class FedAuthTokenHelper
+ {
+ internal static DateTime? GetTokenExpiryDateTime(SqlConnection connection, out string tokenHash)
+ {
+ try
+ {
+ object authenticationContextValueObj = GetAuthenticationContextValue(connection);
+
+ DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
+
+ tokenHash = GetTokenHash(authenticationContextValueObj);
+
+ return expirationTimeProperty;
+ }
+ catch (Exception)
+ {
+ tokenHash = "";
+ return null;
+ }
+ }
+
+ internal static DateTime? SetTokenExpiryDateTime(SqlConnection connection, int minutesToExpire, out string tokenHash)
+ {
+ try
+ {
+ object authenticationContextValueObj = GetAuthenticationContextValue(connection);
+
+ DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
+
+ expirationTimeProperty = DateTime.UtcNow.AddMinutes(minutesToExpire);
+
+ FieldInfo expirationTimeInfo = authenticationContextValueObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
+ expirationTimeInfo.SetValue(authenticationContextValueObj, expirationTimeProperty);
+
+ tokenHash = GetTokenHash(authenticationContextValueObj);
+
+ return expirationTimeProperty;
+ }
+ catch (Exception)
+ {
+ tokenHash = "";
+ return null;
+ }
+ }
+
+ internal static string GetTokenHash(object authenticationContextValueObj)
+ {
+ try
+ {
+ Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));
+
+ Type sqlFedAuthTokenType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SqlFedAuthToken");
+
+ Type[] sqlFedAuthTokenTypeArray = new Type[] { sqlFedAuthTokenType };
+
+ ConstructorInfo sqlFedAuthTokenConstructorInfo = sqlFedAuthTokenType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
+
+ Type activeDirectoryAuthenticationTimeoutRetryHelperType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.ActiveDirectoryAuthenticationTimeoutRetryHelper");
+
+ ConstructorInfo activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo = activeDirectoryAuthenticationTimeoutRetryHelperType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
+
+ object activeDirectoryAuthenticationTimeoutRetryHelperObj = activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo.Invoke(new object[] { });
+
+ MethodInfo tokenHashInfo = activeDirectoryAuthenticationTimeoutRetryHelperObj.GetType().GetMethod("GetTokenHash", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, sqlFedAuthTokenTypeArray, null);
+
+ byte[] tokenBytes = (byte[])authenticationContextValueObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
+
+ object sqlFedAuthTokenObj = sqlFedAuthTokenConstructorInfo.Invoke(new object[] { });
+ FieldInfo accessTokenInfo = sqlFedAuthTokenObj.GetType().GetField("accessToken", BindingFlags.NonPublic | BindingFlags.Instance);
+ accessTokenInfo.SetValue(sqlFedAuthTokenObj, tokenBytes);
+
+ string tokenHash = (string)tokenHashInfo.Invoke(activeDirectoryAuthenticationTimeoutRetryHelperObj, new object[] { sqlFedAuthTokenObj });
+
+ return tokenHash;
+ }
+ catch (Exception)
+ {
+ return "";
+ }
+ }
+
+ internal static object GetAuthenticationContextValue(SqlConnection connection)
+ {
+ object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);
+
+ object databaseConnectionPoolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);
+
+ IEnumerable authenticationContexts = (IEnumerable)databaseConnectionPoolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(databaseConnectionPoolObj, null);
+
+ object authenticationContextObj = authenticationContexts.Cast