Skip to content

Commit

Permalink
Print refresh token cache keys to logs (#4375)
Browse files Browse the repository at this point in the history
* Print RTs to cache.

* Add tests.
  • Loading branch information
pmaytak authored Oct 16, 2023
1 parent 8bdf699 commit a8b93d4
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ internal void InitCacheKey()
iOSCacheKeyLazy = new Lazy<IiOSKey>(() => InitiOSKey());
}

internal string ToLogString(bool piiEnabled = false)
{
return MsalCacheKeys.GetCredentialKey(
piiEnabled ? HomeAccountId : HomeAccountId?.GetHashCode().ToString(),
Environment,
StorageJsonValues.CredentialTypeRefreshToken,
ClientId,
tenantId: null,
scopes: null);
}

#region iOS
private IiOSKey InitiOSKey()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ async Task<Tuple<MsalAccessTokenCacheItem, MsalIdTokenCacheItem, Account>> IToke
await tokenCacheInternal.OnAfterAccessAsync(args).ConfigureAwait(false);
requestParams.RequestContext.ApiEvent.DurationInCacheInMs += sw.ElapsedMilliseconds;

DumpCacheToLogs(requestParams);
LogCacheContents(requestParams);
}

#pragma warning disable CS0618 // Type or member is obsolete
Expand Down Expand Up @@ -294,27 +294,34 @@ private bool ShouldCacheAccessToken(MsalAccessTokenCacheItem msalAccessTokenCach

//This method pulls all of the access and refresh tokens from the cache and can therefore be very impactful on performance.
//This will run on a background thread to mitigate this.
private void DumpCacheToLogs(AuthenticationRequestParameters requestParameters)
private void LogCacheContents(AuthenticationRequestParameters requestParameters)
{

if (requestParameters.RequestContext.Logger.IsLoggingEnabled(LogLevel.Verbose))
{
var accessTokenCacheItems = Accessor.GetAllAccessTokens();
var refreshTokenCacheItems = Accessor.GetAllRefreshTokens();
var accessTokenCacheItemsSubset = accessTokenCacheItems.Take(10).Select(item => item).ToList();
var accessTokenCacheItemsSubset = accessTokenCacheItems.Take(10).ToList();
var refreshTokenCacheItemsSubset = refreshTokenCacheItems.Take(10).ToList();

StringBuilder tokenCacheKeyDump = new StringBuilder();
StringBuilder tokenCacheKeyLog = new StringBuilder();

tokenCacheKeyDump.AppendLine($"Total number of access tokens in cache: {accessTokenCacheItems.Count}");
tokenCacheKeyDump.AppendLine($"Total number of refresh tokens in cache: {refreshTokenCacheItems.Count}");
tokenCacheKeyLog.AppendLine($"Total number of access tokens in the cache: {accessTokenCacheItems.Count}");
tokenCacheKeyLog.AppendLine($"Total number of refresh tokens in the cache: {refreshTokenCacheItems.Count}");

tokenCacheKeyDump.AppendLine($"Token cache dump of the first {accessTokenCacheItemsSubset.Count} cache keys.");
tokenCacheKeyLog.AppendLine($"First {accessTokenCacheItemsSubset.Count} access token cache keys:");
foreach (var cacheItem in accessTokenCacheItemsSubset)
{
tokenCacheKeyDump.AppendLine($"AT Cache Key: {cacheItem.ToLogString(requestParameters.RequestContext.Logger.PiiLoggingEnabled)}");
tokenCacheKeyLog.AppendLine($"AT Cache Key: {cacheItem.ToLogString(requestParameters.RequestContext.Logger.PiiLoggingEnabled)}");
}

requestParameters.RequestContext.Logger.Verbose(() => tokenCacheKeyDump.ToString());
tokenCacheKeyLog.AppendLine($"First {refreshTokenCacheItemsSubset.Count} refresh token cache keys:");
foreach (var cacheItem in refreshTokenCacheItemsSubset)
{
tokenCacheKeyLog.AppendLine($"RT Cache Key: {cacheItem.ToLogString(requestParameters.RequestContext.Logger.PiiLoggingEnabled)}");
}

requestParameters.RequestContext.Logger.Verbose(() => tokenCacheKeyLog.ToString());
}
}

Expand Down
108 changes: 64 additions & 44 deletions tests/Microsoft.Identity.Test.Unit/CacheTests/TokenCacheTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Castle.Core.Internal;
using Microsoft.Identity.Client;
Expand All @@ -25,7 +24,6 @@
using Microsoft.Identity.Test.Common.Core.Mocks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NSubstitute;
using NSubstitute.Extensions;

namespace Microsoft.Identity.Test.Unit.CacheTests
{
Expand Down Expand Up @@ -497,7 +495,7 @@ public void ExpiryNoTokens()
var userAccessorExpiration = TokenCache.CalculateSuggestedCacheExpiry(userTokenCache.Accessor, logger);

// Assert
Assert.IsNull(appAccessorExpiration );
Assert.IsNull(appAccessorExpiration);
Assert.IsNull(userAccessorExpiration);
Assert.IsFalse(appTokenCache.Accessor.HasAccessOrRefreshTokens());
Assert.IsFalse(userTokenCache.Accessor.HasAccessOrRefreshTokens());
Expand Down Expand Up @@ -1371,7 +1369,7 @@ public void TestIsFociMember_EnvAlias()
"1"));

// Act
bool? result = cache.IsFociMemberAsync(requestParams, "1").Result; //requst params uses ProductionPrefEnvAlias
bool? result = cache.IsFociMemberAsync(requestParams, "1").Result; //request params uses ProductionPrefEnvAlias

// Assert
Assert.AreEqual(true, result.Value);
Expand All @@ -1380,58 +1378,80 @@ public void TestIsFociMember_EnvAlias()

[TestMethod]
[TestCategory(TestCategories.TokenCacheTests)]
public async Task ValidateTokenCacheIsDumpedToLogsTestAsync()
public async Task ValidateTokenCacheContentsAreLogged_TestAsync()
{
using (MockHttpAndServiceBundle harness = CreateTestHarness())
{
//Arrange
harness.HttpManager.AddInstanceDiscoveryMockHandler();
using MockHttpAndServiceBundle harness = CreateTestHarness();

string dump = string.Empty;
LogCallback callback = (LogLevel level, string message, bool containsPii) =>
//Arrange
harness.HttpManager.AddInstanceDiscoveryMockHandler();

string logs = string.Empty;
LogCallback logCallback = (LogLevel level, string message, bool containsPii) =>
{
if (level == LogLevel.Verbose)
{
if (level == LogLevel.Verbose)
{
dump += $"MSAL Test: {message}\n";
}
};
logs += message;
}
};

var serviceBundle = TestCommon.CreateServiceBundleWithCustomHttpManager(harness.HttpManager, logCallback: callback);
ITokenCacheInternal cache = new TokenCache(serviceBundle, false);
cache.SetAfterAccess((args) => { return; });
var serviceBundle = TestCommon.CreateServiceBundleWithCustomHttpManager(harness.HttpManager, logCallback: logCallback);
ITokenCacheInternal cache = new TokenCache(serviceBundle, false);
cache.SetAfterAccess((args) => { return; });

var ex = TokenCacheHelper.PopulateCacheWithAccessTokens(cache.Accessor, 19);
TokenCacheHelper.PopulateCacheWithAccessTokens(cache.Accessor, 11);

var requestParams = TestCommon.CreateAuthenticationRequestParameters(
serviceBundle,
scopes: new HashSet<string>());
requestParams.Account = TestConstants.s_user;
requestParams.RequestContext.ApiEvent = new ApiEvent(Guid.NewGuid());
var requestParams = TestCommon.CreateAuthenticationRequestParameters(
serviceBundle,
scopes: new HashSet<string>());
requestParams.Account = TestConstants.s_user;
requestParams.RequestContext.ApiEvent = new ApiEvent(Guid.NewGuid());

var response = TokenCacheHelper.CreateMsalTokenResponse(true);
var response = TokenCacheHelper.CreateMsalTokenResponse(true);

//Act
await cache.SaveTokenResponseAsync(requestParams, response).ConfigureAwait(false);
//Act
await cache.SaveTokenResponseAsync(requestParams, response).ConfigureAwait(false);

//Assert
Assert.IsTrue(dump != string.Empty);
Assert.IsTrue(dump.Contains("Total number of access tokens in cache: 20"));
Assert.IsTrue(dump.Contains("Total number of refresh tokens in cache: 20"));
Assert.IsTrue(dump.Contains("Token cache dump of the first 10 cache keys"));

var accessTokens = cache.Accessor.GetAllAccessTokens().ToList();
foreach (MsalAccessTokenCacheItem item in accessTokens)
{
Assert.IsTrue(dump.Contains(item.ToLogString()));
if (accessTokens.IndexOf(item) >= 9)
{
break;
}
}
//Assert
Assert.IsTrue(logs != string.Empty);
Assert.IsTrue(logs.Contains("Total number of access tokens in the cache: 12"));
Assert.IsTrue(logs.Contains("Total number of refresh tokens in the cache: 12"));
Assert.IsTrue(logs.Contains("First 10 access token cache keys:"));
Assert.IsTrue(logs.Contains("First 10 refresh token cache keys:"));

var accessTokens = cache.Accessor.GetAllAccessTokens().ToList();
var refreshTokens = cache.Accessor.GetAllRefreshTokens().ToList();
for (int i = 0; i < 10; i++)
{
Assert.IsTrue(logs.Contains(accessTokens[i].ToLogString()));
Assert.IsTrue(logs.Contains(refreshTokens[i].ToLogString()));
}
}


[DataTestMethod]
[DataRow(true)]
[DataRow(false)]
public void AccessTokenCacheItem_ToLogString_UsesPiiFlag_Test(bool enablePii)
{
var accessTokenCacheItem = TokenCacheHelper.CreateAccessTokenItem();

var log = accessTokenCacheItem.ToLogString(enablePii);

Assert.AreEqual(enablePii, log.Contains(accessTokenCacheItem.HomeAccountId));
Assert.AreNotEqual(enablePii, log.Contains(accessTokenCacheItem.HomeAccountId.GetHashCode().ToString()));
}

[DataTestMethod]
[DataRow(true)]
[DataRow(false)]
public void RefreshTokenCacheItem_ToLogString_UsesPiiFlag_Test(bool enablePii)
{
var refreshTokenCacheItem = TokenCacheHelper.CreateRefreshTokenItem();

var log = refreshTokenCacheItem.ToLogString(enablePii);

Assert.AreEqual(enablePii, log.Contains(refreshTokenCacheItem.HomeAccountId));
Assert.AreNotEqual(enablePii, log.Contains(refreshTokenCacheItem.HomeAccountId.GetHashCode().ToString()));
}

private void ValidateIsFociMember(
ITokenCacheInternal cache,
Expand Down

0 comments on commit a8b93d4

Please sign in to comment.