diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 6a0ee2e0e0..fb9d62c7e6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -2254,6 +2254,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo) { // GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext. Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null."); + + if (_newDbConnectionPoolAuthenticationContext != null) + { + _dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext); + } } } else if (!attemptRefreshTokenLocked) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 068b37dc71..294f9c8b7e 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -2680,6 +2680,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo) { // GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext. Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null."); + + if (_newDbConnectionPoolAuthenticationContext != null) + { + _dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext); + } } } else if (!attemptRefreshTokenLocked) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 407dd9af1a..24c84ed402 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -170,6 +170,7 @@ + @@ -275,6 +276,7 @@ + @@ -298,30 +300,19 @@ - - - - - - - + + + + + + + - + - + runtime; build; native; contentfiles; analyzers; buildtransitive all @@ -334,8 +325,7 @@ all - + @@ -352,7 +342,7 @@ - + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AADFedAuthTokenRefreshTest/AADFedAuthTokenRefreshTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AADFedAuthTokenRefreshTest/AADFedAuthTokenRefreshTest.cs new file mode 100644 index 0000000000..875d755a59 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AADFedAuthTokenRefreshTest/AADFedAuthTokenRefreshTest.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public class AADFedAuthTokenRefreshTest + { + private readonly ITestOutputHelper _testOutputHelper; + + public AADFedAuthTokenRefreshTest(ITestOutputHelper testOutputHelper) + { + _testOutputHelper = testOutputHelper; + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup))] + public void FedAuthTokenRefreshTest() + { + string connectionString = DataTestUtility.AADPasswordConnectionString; + + using (SqlConnection connection = new SqlConnection(connectionString)) + { + connection.Open(); + + string oldTokenHash = ""; + DateTime? oldExpiryDateTime = FedAuthTokenHelper.SetTokenExpiryDateTime(connection, minutesToExpire: 1, out oldTokenHash); + Assert.True(oldExpiryDateTime != null, "Failed to make token expiry to expire in one minute."); + + // Convert and display the old expiry into local time which should be in 1 minute from now + DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiryDateTime, TimeZoneInfo.Local); + LogInfo($"Token: {oldTokenHash} Old Expiry: {oldLocalExpiryTime}"); + TimeSpan timeDiff = oldLocalExpiryTime - DateTime.Now; + Assert.InRange(timeDiff.TotalSeconds, 0, 60); + + // Check if connection is still alive to continue further testing + string result = ""; + SqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = "select @@version"; + result = $"{cmd.ExecuteScalar()}"; + Assert.True(result != string.Empty, "The connection's command must return a value"); + + // The new connection will use the same FedAuthToken but will refresh it first as it will expire in 1 minute. + using (SqlConnection connection2 = new SqlConnection(connectionString)) + { + connection2.Open(); + + // Check if connection is alive + cmd = connection2.CreateCommand(); + cmd.CommandText = "select 1"; + result = $"{cmd.ExecuteScalar()}"; + Assert.True(result != string.Empty, "The connection's command must return a value after a token refresh."); + + string newTokenHash = ""; + DateTime? newExpiryDateTime = FedAuthTokenHelper.GetTokenExpiryDateTime(connection2, out newTokenHash); + DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiryDateTime, TimeZoneInfo.Local); + LogInfo($"Token: {newTokenHash} New Expiry: {newLocalExpiryTime}"); + + Assert.True(oldTokenHash == newTokenHash, "The token's hash before and after token refresh must be identical."); + Assert.True(newLocalExpiryTime > oldLocalExpiryTime, "The refreshed token must have a new or later expiry time."); + } + } + } + + private void LogInfo(string message) + { + _testOutputHelper.WriteLine(message); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/FedAuthTokenHelper.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/FedAuthTokenHelper.cs new file mode 100644 index 0000000000..26d2477b9b --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/FedAuthTokenHelper.cs @@ -0,0 +1,108 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Linq; +using System.Reflection; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals +{ + internal static class FedAuthTokenHelper + { + internal static DateTime? GetTokenExpiryDateTime(SqlConnection connection, out string tokenHash) + { + try + { + object authenticationContextValueObj = GetAuthenticationContextValue(connection); + + DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null); + + tokenHash = GetTokenHash(authenticationContextValueObj); + + return expirationTimeProperty; + } + catch (Exception) + { + tokenHash = ""; + return null; + } + } + + internal static DateTime? SetTokenExpiryDateTime(SqlConnection connection, int minutesToExpire, out string tokenHash) + { + try + { + object authenticationContextValueObj = GetAuthenticationContextValue(connection); + + DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null); + + expirationTimeProperty = DateTime.UtcNow.AddMinutes(minutesToExpire); + + FieldInfo expirationTimeInfo = authenticationContextValueObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance); + expirationTimeInfo.SetValue(authenticationContextValueObj, expirationTimeProperty); + + tokenHash = GetTokenHash(authenticationContextValueObj); + + return expirationTimeProperty; + } + catch (Exception) + { + tokenHash = ""; + return null; + } + } + + internal static string GetTokenHash(object authenticationContextValueObj) + { + try + { + Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection)); + + Type sqlFedAuthTokenType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SqlFedAuthToken"); + + Type[] sqlFedAuthTokenTypeArray = new Type[] { sqlFedAuthTokenType }; + + ConstructorInfo sqlFedAuthTokenConstructorInfo = sqlFedAuthTokenType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); + + Type activeDirectoryAuthenticationTimeoutRetryHelperType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.ActiveDirectoryAuthenticationTimeoutRetryHelper"); + + ConstructorInfo activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo = activeDirectoryAuthenticationTimeoutRetryHelperType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); + + object activeDirectoryAuthenticationTimeoutRetryHelperObj = activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo.Invoke(new object[] { }); + + MethodInfo tokenHashInfo = activeDirectoryAuthenticationTimeoutRetryHelperObj.GetType().GetMethod("GetTokenHash", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, sqlFedAuthTokenTypeArray, null); + + byte[] tokenBytes = (byte[])authenticationContextValueObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null); + + object sqlFedAuthTokenObj = sqlFedAuthTokenConstructorInfo.Invoke(new object[] { }); + FieldInfo accessTokenInfo = sqlFedAuthTokenObj.GetType().GetField("accessToken", BindingFlags.NonPublic | BindingFlags.Instance); + accessTokenInfo.SetValue(sqlFedAuthTokenObj, tokenBytes); + + string tokenHash = (string)tokenHashInfo.Invoke(activeDirectoryAuthenticationTimeoutRetryHelperObj, new object[] { sqlFedAuthTokenObj }); + + return tokenHash; + } + catch (Exception) + { + return ""; + } + } + + internal static object GetAuthenticationContextValue(SqlConnection connection) + { + object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection); + + object databaseConnectionPoolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj); + + IEnumerable authenticationContexts = (IEnumerable)databaseConnectionPoolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(databaseConnectionPoolObj, null); + + object authenticationContextObj = authenticationContexts.Cast().FirstOrDefault(); + + object authenticationContextValueObj = authenticationContextObj.GetType().GetProperty("Value").GetValue(authenticationContextObj, null); + + return authenticationContextValueObj; + } + } +}