Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix | Fix driver to not send expired token and refresh token first before sending it. #2273

Merged
merged 15 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");

if (_newDbConnectionPoolAuthenticationContext != null)
{
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
}
}
}
else if (!attemptRefreshTokenLocked)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2680,6 +2680,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");

if (_newDbConnectionPoolAuthenticationContext != null)
{
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
}
}
}
else if (!attemptRefreshTokenLocked)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
<Compile Include="ProviderAgnostic\MultipleResultsTest\MultipleResultsTest.cs" />
<Compile Include="ProviderAgnostic\ReaderTest\ReaderTest.cs" />
<Compile Include="TracingTests\EventSourceTest.cs" />
<Compile Include="SQL\AADFedAuthTokenRefreshTest\AADFedAuthTokenRefreshTest.cs" />
<Compile Include="SQL\ConnectionPoolTest\ConnectionPoolTest.cs" />
<Compile Include="SQL\ConnectionPoolTest\PoolBlockPeriodTest.cs" />
<Compile Include="SQL\InstanceNameTest\InstanceNameTest.cs" />
Expand Down Expand Up @@ -275,6 +276,7 @@
<Compile Include="SQL\Common\SystemDataInternals\ConnectionHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\ConnectionPoolHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\DataReaderHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\FedAuthTokenHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\TdsParserHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\TdsParserStateObjectHelper.cs" />
<Compile Include="SQL\ConnectionTestWithSSLCert\CertificateTest.cs" />
Expand All @@ -298,30 +300,19 @@
<ProjectReference Include="$(TestsPath)tools\TDS\TDS.EndPoint\TDS.EndPoint.csproj" />
<ProjectReference Include="$(TestsPath)tools\TDS\TDS.Servers\TDS.Servers.csproj" />
<ProjectReference Include="$(TestsPath)tools\TDS\TDS\TDS.csproj" />
<ProjectReference
Include="$(TestsPath)tools\Microsoft.Data.SqlClient.TestUtilities\Microsoft.Data.SqlClient.TestUtilities.csproj" />
<ProjectReference Condition="'$(TargetGroup)'=='netcoreapp' AND $(ReferenceType)=='Project'"
Include="$(NetCoreSource)src\Microsoft.Data.SqlClient.csproj" />
<ProjectReference Condition="'$(TargetGroup)'=='netfx' AND $(ReferenceType)=='Project'"
Include="$(NetFxSource)src\Microsoft.Data.SqlClient.csproj" />
<ProjectReference Condition="$(ReferenceType.Contains('NetStandard'))"
Include="$(TestsPath)NSLibrary\Microsoft.Data.SqlClient.NSLibrary.csproj" />
<ProjectReference Condition="!$(ReferenceType.Contains('Package'))"
Include="$(SqlServerSource)Microsoft.SqlServer.Server.csproj" />
<PackageReference Condition="$(ReferenceType.Contains('Package'))"
Include="Microsoft.Data.SqlClient" Version="$(TestMicrosoftDataSqlClientVersion)" />
<ProjectReference
Include="$(TestsPath)CustomConfigurableRetryLogic\CustomRetryLogicProvider.csproj" />
<ProjectReference Include="$(TestsPath)tools\Microsoft.Data.SqlClient.TestUtilities\Microsoft.Data.SqlClient.TestUtilities.csproj" />
<ProjectReference Condition="'$(TargetGroup)'=='netcoreapp' AND $(ReferenceType)=='Project'" Include="$(NetCoreSource)src\Microsoft.Data.SqlClient.csproj" />
<ProjectReference Condition="'$(TargetGroup)'=='netfx' AND $(ReferenceType)=='Project'" Include="$(NetFxSource)src\Microsoft.Data.SqlClient.csproj" />
<ProjectReference Condition="$(ReferenceType.Contains('NetStandard'))" Include="$(TestsPath)NSLibrary\Microsoft.Data.SqlClient.NSLibrary.csproj" />
<ProjectReference Condition="!$(ReferenceType.Contains('Package'))" Include="$(SqlServerSource)Microsoft.SqlServer.Server.csproj" />
<PackageReference Condition="$(ReferenceType.Contains('Package'))" Include="Microsoft.Data.SqlClient" Version="$(TestMicrosoftDataSqlClientVersion)" />
<ProjectReference Include="$(TestsPath)CustomConfigurableRetryLogic\CustomRetryLogicProvider.csproj" />
</ItemGroup>
<!-- XUnit and XUnit extensions -->
<ItemGroup>
<PackageReference Condition="$(TargetGroup) == 'netfx'"
Include="System.Runtime.InteropServices.RuntimeInformation"
Version="$(SystemRuntimeInteropServicesRuntimeInformationVersion)" />
<PackageReference Condition="$(TargetGroup) == 'netfx'" Include="System.Runtime.InteropServices.RuntimeInformation" Version="$(SystemRuntimeInteropServicesRuntimeInformationVersion)" />
<PackageReference Include="xunit" Version="$(XunitVersion)" />
<PackageReference Include="Microsoft.NETFramework.ReferenceAssemblies"
Version="$(MicrosoftNETFrameworkReferenceAssembliesVersion)"
Condition="'$(TargetGroup)' == 'netfx'">
<PackageReference Include="Microsoft.NETFramework.ReferenceAssemblies" Version="$(MicrosoftNETFrameworkReferenceAssembliesVersion)" Condition="'$(TargetGroup)' == 'netfx'">
arellegue marked this conversation as resolved.
Show resolved Hide resolved
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
Expand All @@ -334,8 +325,7 @@
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="xunit.runner.utility" Version="$(XunitVersion)" />
<PackageReference Include="Microsoft.DotNet.XUnitExtensions"
Version="$(MicrosoftDotNetXUnitExtensionsVersion)" />
<PackageReference Include="Microsoft.DotNet.XUnitExtensions" Version="$(MicrosoftDotNetXUnitExtensionsVersion)" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonVersion)" />
Expand All @@ -352,7 +342,7 @@
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="$(SystemIdentityModelTokensJwtVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netfx'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersionNet)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
<PackageReference Condition="'$(TargetGroup)'!='netfx'" Include="System.ServiceProcess.ServiceController" Version="$(SystemServiceProcessServiceControllerVersion)" />
</ItemGroup>
<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
DavoudEshtehari marked this conversation as resolved.
Show resolved Hide resolved
using Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public class AADFedAuthTokenRefreshTest
{
private readonly ITestOutputHelper _testOutputHelper;

public AADFedAuthTokenRefreshTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup))]
public void FedAuthTokenRefreshTest()
{
string connectionString = DataTestUtility.AADPasswordConnectionString;

using (SqlConnection connection = new SqlConnection(connectionString))
{
connection.Open();

string oldTokenHash = "";
DateTime? oldExpiryDateTime = FedAuthTokenHelper.SetTokenExpiryDateTime(connection, minutesToExpire: 1, out oldTokenHash);
Assert.True(oldExpiryDateTime != null, "Failed to make token expiry to expire in one minute.");

// Convert and display the old expiry into local time which should be in 1 minute from now
DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiryDateTime, TimeZoneInfo.Local);
LogInfo($"Token: {oldTokenHash} Old Expiry: {oldLocalExpiryTime}");
TimeSpan timeDiff = oldLocalExpiryTime - DateTime.Now;
Assert.InRange(timeDiff.TotalSeconds, 0, 60);

// Check if connection is still alive to continue further testing
string result = "";
SqlCommand cmd = connection.CreateCommand();
cmd.CommandText = "select @@version";
result = $"{cmd.ExecuteScalar()}";
Assert.True(result != string.Empty, "The connection's command must return a value");

// The new connection will use the same FedAuthToken but will refresh it first as it will expire in 1 minute.
using (SqlConnection connection2 = new SqlConnection(connectionString))
{
connection2.Open();

// Check if connection is alive
cmd = connection2.CreateCommand();
cmd.CommandText = "select 1";
result = $"{cmd.ExecuteScalar()}";
Assert.True(result != string.Empty, "The connection's command must return a value after a token refresh.");

string newTokenHash = "";
DateTime? newExpiryDateTime = FedAuthTokenHelper.GetTokenExpiryDateTime(connection2, out newTokenHash);
DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiryDateTime, TimeZoneInfo.Local);
LogInfo($"Token: {newTokenHash} New Expiry: {newLocalExpiryTime}");

Assert.True(oldTokenHash == newTokenHash, "The token's hash before and after token refresh must be identical.");
Assert.True(newLocalExpiryTime > oldLocalExpiryTime, "The refreshed token must have a new or later expiry time.");
}
}
}

private void LogInfo(string message)
{
_testOutputHelper.WriteLine(message);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
DavoudEshtehari marked this conversation as resolved.
Show resolved Hide resolved
using System.Collections;
using System.Linq;
using System.Reflection;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals
{
internal static class FedAuthTokenHelper
{
internal static DateTime? GetTokenExpiryDateTime(SqlConnection connection, out string tokenHash)
{
try
{
object authenticationContextValueObj = GetAuthenticationContextValue(connection);

DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

tokenHash = GetTokenHash(authenticationContextValueObj);

return expirationTimeProperty;
}
catch (Exception)
{
tokenHash = "";
return null;
}
}

internal static DateTime? SetTokenExpiryDateTime(SqlConnection connection, int minutesToExpire, out string tokenHash)
{
try
{
object authenticationContextValueObj = GetAuthenticationContextValue(connection);

DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

expirationTimeProperty = DateTime.UtcNow.AddMinutes(minutesToExpire);

FieldInfo expirationTimeInfo = authenticationContextValueObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
expirationTimeInfo.SetValue(authenticationContextValueObj, expirationTimeProperty);

tokenHash = GetTokenHash(authenticationContextValueObj);

return expirationTimeProperty;
}
catch (Exception)
{
tokenHash = "";
return null;
}
}

internal static string GetTokenHash(object authenticationContextValueObj)
{
try
{
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));

Type sqlFedAuthTokenType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SqlFedAuthToken");

Type[] sqlFedAuthTokenTypeArray = new Type[] { sqlFedAuthTokenType };

ConstructorInfo sqlFedAuthTokenConstructorInfo = sqlFedAuthTokenType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

Type activeDirectoryAuthenticationTimeoutRetryHelperType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.ActiveDirectoryAuthenticationTimeoutRetryHelper");

ConstructorInfo activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo = activeDirectoryAuthenticationTimeoutRetryHelperType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

object activeDirectoryAuthenticationTimeoutRetryHelperObj = activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo.Invoke(new object[] { });

MethodInfo tokenHashInfo = activeDirectoryAuthenticationTimeoutRetryHelperObj.GetType().GetMethod("GetTokenHash", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, sqlFedAuthTokenTypeArray, null);

byte[] tokenBytes = (byte[])authenticationContextValueObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

object sqlFedAuthTokenObj = sqlFedAuthTokenConstructorInfo.Invoke(new object[] { });
FieldInfo accessTokenInfo = sqlFedAuthTokenObj.GetType().GetField("accessToken", BindingFlags.NonPublic | BindingFlags.Instance);
accessTokenInfo.SetValue(sqlFedAuthTokenObj, tokenBytes);

string tokenHash = (string)tokenHashInfo.Invoke(activeDirectoryAuthenticationTimeoutRetryHelperObj, new object[] { sqlFedAuthTokenObj });

return tokenHash;
}
catch (Exception)
{
return "";
}
}

internal static object GetAuthenticationContextValue(SqlConnection connection)
{
object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);

object databaseConnectionPoolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);

IEnumerable authenticationContexts = (IEnumerable)databaseConnectionPoolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(databaseConnectionPoolObj, null);

object authenticationContextObj = authenticationContexts.Cast<object>().FirstOrDefault();

object authenticationContextValueObj = authenticationContextObj.GetType().GetProperty("Value").GetValue(authenticationContextObj, null);

return authenticationContextValueObj;
}
}
}
Loading