diff --git a/buildpackLocal.bat b/buildpackLocal.bat new file mode 100644 index 0000000000..44038b4f98 --- /dev/null +++ b/buildpackLocal.bat @@ -0,0 +1,3 @@ +dotnet clean Product.proj > clean.log +dotnet build /r Product.proj +dotnet pack --no-restore -o c:\localpackages --no-build Product.proj diff --git a/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs index 6183e64781..d03d3efc69 100644 --- a/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs @@ -208,15 +208,13 @@ public override bool Sign(ReadOnlySpan input, Span signature, out in catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } - } #endif @@ -248,12 +246,11 @@ public override byte[] Sign(byte[] input) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } } @@ -279,12 +276,11 @@ public override byte[] Sign(byte[] input, int offset, int count) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } } @@ -380,12 +376,11 @@ public override bool Verify(byte[] input, byte[] signature) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } } @@ -474,15 +469,14 @@ public override bool Verify(byte[] input, int inputOffset, int inputLength, byte } catch { - Dispose(true); + CryptoProviderCache?.TryRemove(this); throw; } finally { - if (!_disposed) + if (asym != null) _asymmetricAdapterObjectPool.Free(asym); } - } /// diff --git a/src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs b/src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs index a94551be69..d3d7213170 100644 --- a/src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs +++ b/src/Microsoft.IdentityModel.Tokens/CryptoProviderFactory.cs @@ -18,6 +18,8 @@ public class CryptoProviderFactory private static readonly ConcurrentDictionary _typeToAlgorithmMap = new ConcurrentDictionary(); private static readonly object _cacheLock = new object(); private static int _defaultSignatureProviderObjectPoolCacheSize = Environment.ProcessorCount * 4; + private static string _typeofAsymmetricSignatureProvider = typeof(AsymmetricSignatureProvider).ToString(); + private static string _typeofSymmetricSignatureProvider = typeof(SymmetricSignatureProvider).ToString(); private int _signatureProviderObjectPoolCacheSize = _defaultSignatureProviderObjectPoolCacheSize; /// @@ -513,7 +515,13 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori { signatureProvider = CustomCryptoProvider.Create(algorithm, key, willCreateSignatures) as SignatureProvider; if (signatureProvider == null) - throw LogHelper.LogExceptionMessage(new InvalidOperationException(LogHelper.FormatInvariant(LogMessages.IDX10646, LogHelper.MarkAsNonPII(algorithm), key, LogHelper.MarkAsNonPII(typeof(SignatureProvider))))); + throw LogHelper.LogExceptionMessage( + new InvalidOperationException( + LogHelper.FormatInvariant( + LogMessages.IDX10646, + LogHelper.MarkAsNonPII(algorithm), + key, + LogHelper.MarkAsNonPII(typeof(SignatureProvider))))); return signatureProvider; } @@ -523,7 +531,7 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori bool createAsymmetric = true; if (key is AsymmetricSecurityKey) { - typeofSignatureProvider = typeof(AsymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofAsymmetricSignatureProvider; } else if (key is JsonWebKey jsonWebKey) { @@ -533,11 +541,11 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori { if (convertedSecurityKey is AsymmetricSecurityKey) { - typeofSignatureProvider = typeof(AsymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofAsymmetricSignatureProvider; } else if (convertedSecurityKey is SymmetricSecurityKey) { - typeofSignatureProvider = typeof(SymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofSymmetricSignatureProvider; createAsymmetric = false; } } @@ -545,10 +553,10 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori else { if (jsonWebKey.Kty == JsonWebAlgorithmsKeyTypes.RSA || jsonWebKey.Kty == JsonWebAlgorithmsKeyTypes.EllipticCurve) - typeofSignatureProvider = typeof(AsymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofAsymmetricSignatureProvider; else if (jsonWebKey.Kty == JsonWebAlgorithmsKeyTypes.Octet) { - typeofSignatureProvider = typeof(SymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofSymmetricSignatureProvider; createAsymmetric = false; } } @@ -560,12 +568,20 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori } else if (key is SymmetricSecurityKey) { - typeofSignatureProvider = typeof(SymmetricSignatureProvider).ToString(); + typeofSignatureProvider = _typeofSymmetricSignatureProvider; createAsymmetric = false; } if (typeofSignatureProvider == null) - throw LogHelper.LogExceptionMessage(new NotSupportedException(LogHelper.FormatInvariant(LogMessages.IDX10621, LogHelper.MarkAsNonPII(typeof(SymmetricSignatureProvider)), LogHelper.MarkAsNonPII(typeof(SecurityKey)), LogHelper.MarkAsNonPII(typeof(AsymmetricSecurityKey)), LogHelper.MarkAsNonPII(typeof(SymmetricSecurityKey)), LogHelper.MarkAsNonPII(key.GetType())))); + throw LogHelper.LogExceptionMessage( + new NotSupportedException( + LogHelper.FormatInvariant( + LogMessages.IDX10621, + LogHelper.MarkAsNonPII(typeof(SymmetricSignatureProvider)), + LogHelper.MarkAsNonPII(typeof(SecurityKey)), + LogHelper.MarkAsNonPII(typeof(AsymmetricSecurityKey)), + LogHelper.MarkAsNonPII(typeof(SymmetricSecurityKey)), + LogHelper.MarkAsNonPII(key.GetType())))); if (CacheSignatureProviders && cacheProvider) { @@ -592,7 +608,7 @@ private SignatureProvider CreateSignatureProvider(SecurityKey key, string algori signatureProvider = new SymmetricSignatureProvider(key, algorithm, willCreateSignatures); if (ShouldCacheSignatureProvider(signatureProvider)) - CryptoProviderCache.TryAdd(signatureProvider); + signatureProvider.IsCached = CryptoProviderCache.TryAdd(signatureProvider); } } else @@ -737,7 +753,7 @@ public virtual void ReleaseSignatureProvider(SignatureProvider signatureProvider signatureProvider.Release(); if (CustomCryptoProvider != null && CustomCryptoProvider.IsSupportedAlgorithm(signatureProvider.Algorithm)) CustomCryptoProvider.Release(signatureProvider); - else if (signatureProvider.CryptoProviderCache == null && signatureProvider.RefCount == 0) + else if (signatureProvider.CryptoProviderCache == null && signatureProvider.RefCount == 0 && !signatureProvider.IsCached) signatureProvider.Dispose(); } } diff --git a/src/Microsoft.IdentityModel.Tokens/EventBasedLRUCache.cs b/src/Microsoft.IdentityModel.Tokens/EventBasedLRUCache.cs index 1a329252ee..64e3a70fed 100644 --- a/src/Microsoft.IdentityModel.Tokens/EventBasedLRUCache.cs +++ b/src/Microsoft.IdentityModel.Tokens/EventBasedLRUCache.cs @@ -5,7 +5,6 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; -using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using Microsoft.IdentityModel.Abstractions; @@ -29,63 +28,74 @@ namespace Microsoft.IdentityModel.Tokens /// The value type to be used by the cache internal class EventBasedLRUCache { + internal delegate void ItemCompacted(TValue Value); + internal delegate void ItemExpired(TValue Value); internal delegate void ItemRemoved(TValue Value); + internal delegate bool ShouldRemove(TValue Value); private readonly int _capacity; - + private List> _compactedItems = new List>(); // The percentage of the cache to be removed when _maxCapacityPercentage is reached. private readonly double _compactionPercentage = .20; private LinkedList> _doubleLinkedList = new LinkedList>(); private ConcurrentQueue _eventQueue = new ConcurrentQueue(); + private readonly TaskCreationOptions _options; + // if true, then items will be maintained in a LRU fashion, moving to front of list when accessed in the cache. + private readonly bool _maintainLRU; private ConcurrentDictionary> _map; - // When the current cache size gets to this percentage of _capacity, _compactionPercentage% of the cache will be removed. private readonly double _maxCapacityPercentage = .95; + private readonly int _compactIntervalInSeconds; // if true, expired values will not be added to the cache and clean-up of expired values will occur on a 5 minute interval private readonly bool _removeExpiredValues; private readonly int _removeExpiredValuesIntervalInSeconds; - // if true, then items will be maintained in a LRU fashion, moving to front of list when accessed in the cache. - private readonly bool _maintainLRU; - - private readonly TaskCreationOptions _options; - private DateTime _dueForExpiredValuesRemoval; - // for testing purpose only to verify the task count private int _taskCount = 0; + private DateTime _timeForNextExpiredValuesRemoval; + private DateTime _timeForNextCompaction; #region event queue - private int _eventQueuePollingInterval = 50; - // The idle timeout, the _eventQueueTask will end after being idle for the specified time interval (execution continues even if the queue is empty to reduce the task startup overhead), default to 120 seconds. // TODO: consider implementing a better algorithm that tracks and predicts the usage patterns and adjusts this value dynamically. private long _eventQueueTaskIdleTimeoutInSeconds = 120; - // The time when the _eventQueueTask should end. The intent is to reduce the overhead costs of starting/ending tasks too frequently // but at the same time keep the _eventQueueTask a short running task. // Since Task is based on thread pool the overhead should be reasonable. private DateTime _eventQueueTaskStopTime; - // task states used to ensure thread safety (Interlocked.CompareExchange) private const int EventQueueTaskStopped = 0; // task not started yet private const int EventQueueTaskRunning = 1; // task is running private const int EventQueueTaskDoNotStop = 2; // force the task to continue even it has past the _eventQueueTaskStopTime, see StartEventQueueTaskIfNotRunning() for more details. private int _eventQueueTaskState = EventQueueTaskStopped; - private const int CompactionNotQueued = 0; // compaction action not in the event queue - private const int CompactionQueuedOrRunning = 1; // compaction action in the event queue or currently in progress - private int _compactionState = CompactionNotQueued; + private const int ActionNotQueued = 0; // compaction action not in the event queue + private const int ActionQueuedOrRunning = 1; // compaction action in the event queue or currently in progress + private int _compactValuesState = ActionNotQueued; + private int _removeExpiredValuesState = ActionNotQueued; + private int _processCompactedValuesState = ActionNotQueued; // set to true when the AppDomain is to be unloaded or the default AppDomain process is ready to exit private bool _shouldStopImmediately = false; - internal ItemRemoved OnItemRemoved + internal ItemExpired OnItemExpired { get; set; } + + /// + /// For back compat any friend would be broken, this is the same as OnItemExpired. + /// + internal ItemExpired OnItemRemoved { - get; - set; + get { return OnItemExpired; } + set { OnItemExpired = value; } } + internal ItemCompacted OnItemMovedToCompactedList { get; set; } + + internal ItemRemoved OnItemRemovedFromCompactedList { get; set; } + + internal ShouldRemove OnShouldRemoveFromCompactedList { get; set; } + internal long EventQueueTaskIdleTimeoutInSeconds { get => _eventQueueTaskIdleTimeoutInSeconds; @@ -96,20 +106,6 @@ internal long EventQueueTaskIdleTimeoutInSeconds _eventQueueTaskIdleTimeoutInSeconds = value; } } - - // If the task operating on the _eventQueue has not timed out and the _eventQueue is empty, this polling interval will be used - // to determine how often the cache should be checked for the presence of a new action. - private int EventQueuePollingInterval - { - get => _eventQueuePollingInterval; - set - { - if (value <= 0) - throw new ArgumentOutOfRangeException(nameof(value), "EventQueuePollingInterval must be positive."); - _eventQueuePollingInterval = value; - } - } - #endregion /// @@ -121,22 +117,26 @@ private int EventQueuePollingInterval /// Whether or not to remove expired items. /// The period to wait to remove expired items, in seconds. /// Whether or not to maintain items in a LRU fashion, moving to front of list when accessed in the cache. + /// The period to wait to compact items, in seconds. internal EventBasedLRUCache( int capacity, TaskCreationOptions options = TaskCreationOptions.None, IEqualityComparer comparer = null, bool removeExpiredValues = false, int removeExpiredValuesIntervalInSeconds = 300, - bool maintainLRU = false) + bool maintainLRU = false, + int compactIntervalInSeconds = 20) { _capacity = capacity > 0 ? capacity : throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException(nameof(capacity))); _options = options; _map = new ConcurrentDictionary>(comparer ?? EqualityComparer.Default); _removeExpiredValuesIntervalInSeconds = removeExpiredValuesIntervalInSeconds; _removeExpiredValues = removeExpiredValues; + _compactIntervalInSeconds = compactIntervalInSeconds; + _timeForNextExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); + _timeForNextCompaction = DateTime.UtcNow.AddSeconds(_compactIntervalInSeconds); _eventQueueTaskStopTime = DateTime.UtcNow; _maintainLRU = maintainLRU; - _dueForExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); } /// @@ -168,6 +168,7 @@ internal EventBasedLRUCache( private void AddActionToEventQueue(Action action) { _eventQueue.Enqueue(action); + // start the event queue task if it is not running StartEventQueueTaskIfNotRunning(); } @@ -186,88 +187,97 @@ public bool Contains(TKey key) private void EventQueueTaskAction() { Interlocked.Increment(ref _taskCount); - // Keep running until the queue is empty or the AppDomain is about to be unloaded or the application is ready to exit. - while (!_shouldStopImmediately) + try { - // always set the state to EventQueueTaskRunning in case it was set to EventQueueTaskDoNotStop - Interlocked.Exchange(ref _eventQueueTaskState, EventQueueTaskRunning); - - try + // Keep running until the queue is empty or the AppDomain is about to be unloaded or the application is ready to exit. + while (!_shouldStopImmediately) { - // remove expired items if needed - if (_removeExpiredValues && DateTime.UtcNow >= _dueForExpiredValuesRemoval) - { - if (_maintainLRU) - RemoveExpiredValuesLRU(); - else - RemoveExpiredValues(); + // always set the state to EventQueueTaskRunning in case it was set to EventQueueTaskDoNotStop + Interlocked.Exchange(ref _eventQueueTaskState, EventQueueTaskRunning); - _dueForExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); - } - - // process all events in the queue and exit - if (_eventQueue.TryDequeue(out var action)) - { - action?.Invoke(); - } - else if (DateTime.UtcNow > _eventQueueTaskStopTime) // no more event to be processed, exit if expired + try { - // Setting _eventQueueTaskState = EventQueueTaskStopped if the _eventQueueTaskEndTime has past and _eventQueueTaskState == EventQueueTaskRunning. - // This means no other thread came in and it is safe to end this task. - // If another thread adds new events while this task is still running, it will set the _eventQueueTaskState = EventQueueTaskDoNotStop instead of starting a new task. - // The Interlocked.CompareExchange() call below will not succeed and the loop continues (until the event queue is empty and the _eventQueueTaskEndTime expires again). - // This should prevent a rare (but theoretically possible) scenario caused by context switching. - if (Interlocked.CompareExchange(ref _eventQueueTaskState, EventQueueTaskStopped, EventQueueTaskRunning) == EventQueueTaskRunning) - break; - + // remove expired items if needed + if (_removeExpiredValues && DateTime.UtcNow >= _timeForNextExpiredValuesRemoval) + { + if (Interlocked.CompareExchange(ref _removeExpiredValuesState, ActionNotQueued, ActionQueuedOrRunning) == ActionQueuedOrRunning) + { + if (_maintainLRU) + RemoveExpiredValuesLRU(); + else + RemoveExpiredValues(); + } + } + + // process all events in the queue and exit + if (_eventQueue.TryDequeue(out var action)) + { + action?.Invoke(); + } + else if (DateTime.UtcNow > _eventQueueTaskStopTime) // no more event to be processed, exit if expired + { + // Setting _eventQueueTaskState = EventQueueTaskStopped if the _eventQueueStopTime has past and _eventQueueTaskState == EventQueueTaskRunning. + // This means no other thread came in and it is safe to end this task. + // If another thread adds new events while this task is still running, it will set the _eventQueueTaskState = EventQueueTaskDoNotStop instead of starting a new task. + // The Interlocked.CompareExchange() call below will not succeed and the loop continues (until the event queue is empty and the _eventQueueTaskEndTime expires again). + // This should prevent a rare (but theoretically possible) scenario caused by context switching. + if (Interlocked.CompareExchange(ref _eventQueueTaskState, EventQueueTaskStopped, EventQueueTaskRunning) == EventQueueTaskRunning) + break; + + } + else // if empty, let the thread sleep for a specified number of milliseconds before attempting to retrieve another value from the queue + { + Thread.Sleep(_eventQueuePollingInterval); + } } - else // if empty, let the thread sleep for a specified number of milliseconds before attempting to retrieve another value from the queue + catch (Exception ex) { - Thread.Sleep(_eventQueuePollingInterval); + if (LogHelper.IsEnabled(EventLogLevel.Warning)) + LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10900, ex)); } } - catch (Exception ex) - { - if (LogHelper.IsEnabled(EventLogLevel.Warning)) - LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10900, ex)); - } } - - Interlocked.Decrement(ref _taskCount); + finally + { + Interlocked.Decrement(ref _taskCount); + Interlocked.Exchange(ref _eventQueueTaskState, EventQueueTaskStopped); + } } /// /// Remove all expired cache items from _doubleLinkedList and _map. /// /// Number of items removed. - internal int RemoveExpiredValuesLRU() + internal void RemoveExpiredValuesLRU() { - int numItemsRemoved = 0; +#pragma warning disable CA1031 // Do not catch general exception types try { - var node = _doubleLinkedList.First; + LinkedListNode> node = _doubleLinkedList.First; while (node != null) { - var nextNode = node.Next; + LinkedListNode> nextNode = node.Next; if (node.Value.ExpirationTime < DateTime.UtcNow) { _doubleLinkedList.Remove(node); - if (_map.TryRemove(node.Value.Key, out var cacheItem)) - OnItemRemoved?.Invoke(cacheItem.Value); - - numItemsRemoved++; + if (_map.TryRemove(node.Value.Key, out LRUCacheItem cacheItem)) + OnItemExpired?.Invoke(cacheItem.Value); } node = nextNode; } } - catch (ObjectDisposedException ex) + catch(Exception ex) { if (LogHelper.IsEnabled(EventLogLevel.Warning)) LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10902, LogHelper.MarkAsNonPII(nameof(RemoveExpiredValuesLRU)), ex)); } - - return numItemsRemoved; + finally + { + _removeExpiredValuesState = ActionNotQueued; + _timeForNextExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); + } +#pragma warning restore CA1031 // Do not catch general exception types } /// @@ -275,29 +285,64 @@ internal int RemoveExpiredValuesLRU() /// The enumerator returned from the dictionary is safe to use concurrently with reads and writes to the dictionary, according to the MS document. /// /// Number of items removed. - internal int RemoveExpiredValues() + internal void RemoveExpiredValues() { - int numItemsRemoved = 0; +#pragma warning disable CA1031 // Do not catch general exception types try { - foreach (var node in _map) + foreach (KeyValuePair> node in _map) { if (node.Value.ExpirationTime < DateTime.UtcNow) { if (_map.TryRemove(node.Value.Key, out var cacheItem)) - OnItemRemoved?.Invoke(cacheItem.Value); + OnItemExpired?.Invoke(cacheItem.Value); + } + } + } + catch(Exception ex) + { + if (LogHelper.IsEnabled(EventLogLevel.Warning)) + LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10902, LogHelper.MarkAsNonPII(nameof(ProcessCompactedValues)), ex)); + + } + finally + { + _removeExpiredValuesState = ActionNotQueued; + _timeForNextExpiredValuesRemoval = DateTime.UtcNow.AddSeconds(_removeExpiredValuesIntervalInSeconds); + } - numItemsRemoved++; +#pragma warning restore CA1031 // Do not catch general exception types + } + + /// + /// Remove all compacted items. + /// + internal void ProcessCompactedValues() + { +#pragma warning disable CA1031 // Do not catch general exception types + try + { + for (int i = _compactedItems.Count - 1; i >= 0; i--) + { + if ((OnShouldRemoveFromCompactedList == null) || OnShouldRemoveFromCompactedList(_compactedItems[i].Value)) + { + OnItemRemovedFromCompactedList?.Invoke(_compactedItems[i].Value); + _compactedItems.RemoveAt(i); } } } - catch (ObjectDisposedException ex) + catch(Exception ex) { if (LogHelper.IsEnabled(EventLogLevel.Warning)) - LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10902, LogHelper.MarkAsNonPII(nameof(RemoveExpiredValues)), ex)); + LogHelper.LogWarning(LogHelper.FormatInvariant(LogMessages.IDX10906, LogHelper.MarkAsNonPII(nameof(ProcessCompactedValues)), ex)); + } + finally + { + _processCompactedValuesState = ActionNotQueued; + _timeForNextCompaction = DateTime.UtcNow.AddSeconds(_compactIntervalInSeconds); } - return numItemsRemoved; +#pragma warning restore CA1031 // Do not catch general exception types } /// @@ -306,18 +351,23 @@ internal int RemoveExpiredValues() /// private void CompactLRU() { - var newCacheSize = CalculateNewCacheSize(); - while (_map.Count > newCacheSize && _doubleLinkedList.Count > 0) + try { - var lru = _doubleLinkedList.Last; - if (_map.TryRemove(lru.Value.Key, out var cacheItem)) - OnItemRemoved?.Invoke(cacheItem.Value); + int newCacheSize = CalculateNewCacheSize(); + while (_map.Count > newCacheSize && _doubleLinkedList.Count > 0) + { + LinkedListNode> node = _doubleLinkedList.Last; + if (_map.TryRemove(node.Value.Key, out LRUCacheItem cacheItem)) + OnItemMovedToCompactedList?.Invoke(cacheItem.Value); - _doubleLinkedList.RemoveLast(); + _compactedItems.Add(cacheItem); + _doubleLinkedList.RemoveLast(); + } + } + finally + { + _compactValuesState = ActionNotQueued; } - - // reset _compactionState so the compaction action can be queued again when needed - _compactionState = CompactionNotQueued; } /// @@ -326,21 +376,28 @@ private void CompactLRU() /// private void Compact() { - var newCacheSize = CalculateNewCacheSize(); - while (_map.Count > newCacheSize) + try { - // Since all items could have been removed by the public TryRemove() method, leaving the map empty, we need to check if a default value is returned. - // Remove the item from the map only if the returned item is NOT default value. - var item = _map.FirstOrDefault(); - if (!item.Equals(default)) + int newCacheSize = CalculateNewCacheSize(); + while (_map.Count > newCacheSize) { - if (_map.TryRemove(item.Key, out var cacheItem)) - OnItemRemoved?.Invoke(cacheItem.Value); + // Since all items could have been removed by the public TryRemove() method, leaving the map empty, we need to check if a default value is returned. + // Remove the item from the map only if the returned item is NOT default value. + KeyValuePair> item = _map.FirstOrDefault(); + if (!item.Equals(default)) + { + if (_map.TryRemove(item.Key, out LRUCacheItem cacheItem)) + { + OnItemMovedToCompactedList?.Invoke(cacheItem.Value); + _compactedItems.Add(cacheItem); + } + } } } - - // reset _compactionState so the compaction action can be queued again when needed - _compactionState = CompactionNotQueued; + finally + { + _compactValuesState = ActionNotQueued; + } } /// @@ -408,12 +465,20 @@ public bool SetValue(TKey key, TValue value, DateTime expirationTime) // if cache is at _maxCapacityPercentage, trim it by _compactionPercentage if ((double)_map.Count / _capacity >= _maxCapacityPercentage) { - if (Interlocked.CompareExchange(ref _compactionState, CompactionQueuedOrRunning, CompactionNotQueued) == CompactionNotQueued) + if (Interlocked.CompareExchange(ref _compactValuesState, ActionQueuedOrRunning, ActionNotQueued) == ActionNotQueued) { if (_maintainLRU) AddActionToEventQueue(CompactLRU); else AddActionToEventQueue(Compact); + + if (DateTime.UtcNow >= _timeForNextCompaction) + { + if (Interlocked.CompareExchange(ref _processCompactedValuesState, ActionQueuedOrRunning, ActionNotQueued) == ActionNotQueued) + { + _eventQueue.Enqueue(ProcessCompactedValues); + } + } } } @@ -476,8 +541,7 @@ private void StartEventQueueTaskIfNotRunning() // the caller's TaskScheduler (if there is one) as some custom TaskSchedulers might be single-threaded and its execution can be blocked. if (Interlocked.CompareExchange(ref _eventQueueTaskState, EventQueueTaskRunning, EventQueueTaskStopped) == EventQueueTaskStopped) { - // EventQueueTaskAction manages its own state. - _ = Task.Run(EventQueueTaskAction); + _ = Task.Run(EventQueueTaskAction); } } @@ -514,6 +578,23 @@ public bool TryGetValue(TKey key, out TValue value) return cacheItem != null; } + // These Try methods are not thread safe and they rely on the SignatureProviders to have logic to dispose of important objects. + // A better design would be to have TryRemove move the SignatureProvider to the compacted list. + // This would need a new action in LRUCache, AddItemToCompactedList. + + /// Removes a particular key from the cache. + public bool TryRemove(TKey key) + { + if (key == null) + throw LogHelper.LogArgumentNullException(nameof(key)); + + if (!_map.TryRemove(key, out var cacheItem)) + return false; + + OnItemMovedToCompactedList?.Invoke(cacheItem.Value); + return true; + } + /// Removes a particular key from the cache. public bool TryRemove(TKey key, out TValue value) { @@ -534,7 +615,7 @@ public bool TryRemove(TKey key, out TValue value) } value = cacheItem.Value; - OnItemRemoved?.Invoke(cacheItem.Value); + OnItemMovedToCompactedList?.Invoke(cacheItem.Value); return true; } @@ -579,7 +660,9 @@ public bool TryRemove(TKey key, out TValue value) /// internal void WaitForProcessing() { - while (!_eventQueue.IsEmpty); + while (!_eventQueue.IsEmpty) + { + }; } #endregion @@ -613,4 +696,3 @@ public override bool Equals(object obj) public override int GetHashCode() => 990326508 + EqualityComparer.Default.GetHashCode(Key); } } - diff --git a/src/Microsoft.IdentityModel.Tokens/InMemoryCryptoProviderCache.cs b/src/Microsoft.IdentityModel.Tokens/InMemoryCryptoProviderCache.cs index 00ee2c023f..0ac44300f8 100644 --- a/src/Microsoft.IdentityModel.Tokens/InMemoryCryptoProviderCache.cs +++ b/src/Microsoft.IdentityModel.Tokens/InMemoryCryptoProviderCache.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Globalization; using System.Threading.Tasks; using Microsoft.IdentityModel.Abstractions; using Microsoft.IdentityModel.Logging; @@ -14,7 +13,6 @@ namespace Microsoft.IdentityModel.Tokens /// Current support is limited to only. /// public class InMemoryCryptoProviderCache: CryptoProviderCache, IDisposable - { internal CryptoProviderCacheOptions _cryptoProviderCacheOptions; private bool _disposed = false; @@ -28,39 +26,56 @@ public InMemoryCryptoProviderCache() : this(new CryptoProviderCacheOptions()) { } - internal CryptoProviderFactory CryptoProviderFactory { get; set; } - /// /// Creates a new instance of using the specified . /// /// The options used to configure the . - public InMemoryCryptoProviderCache(CryptoProviderCacheOptions cryptoProviderCacheOptions) + public InMemoryCryptoProviderCache(CryptoProviderCacheOptions cryptoProviderCacheOptions) : this(cryptoProviderCacheOptions, TaskCreationOptions.None) { - if (cryptoProviderCacheOptions == null) - throw LogHelper.LogArgumentNullException(nameof(cryptoProviderCacheOptions)); - - _cryptoProviderCacheOptions = cryptoProviderCacheOptions; - _signingSignatureProviders = new EventBasedLRUCache(cryptoProviderCacheOptions.SizeLimit, removeExpiredValues: false, comparer: StringComparer.Ordinal) { OnItemRemoved = (SignatureProvider signatureProvider) => signatureProvider.CryptoProviderCache = null }; - _verifyingSignatureProviders = new EventBasedLRUCache(cryptoProviderCacheOptions.SizeLimit, removeExpiredValues: false, comparer: StringComparer.Ordinal) { OnItemRemoved = (SignatureProvider signatureProvider) => signatureProvider.CryptoProviderCache = null }; } - /// - /// Creates a new instance of using the specified . - /// - /// The options used to configure the . - /// Options used to create the event queue thread. - /// The time used in ms for the timeout interval of the event queue. Defaults to 500 ms. internal InMemoryCryptoProviderCache(CryptoProviderCacheOptions cryptoProviderCacheOptions, TaskCreationOptions options, int tryTakeTimeout = 500) { - if (cryptoProviderCacheOptions == null) - throw LogHelper.LogArgumentNullException(nameof(cryptoProviderCacheOptions)); - + _cryptoProviderCacheOptions = cryptoProviderCacheOptions ?? throw LogHelper.LogArgumentNullException(nameof(cryptoProviderCacheOptions)); if (tryTakeTimeout <= 0) throw LogHelper.LogArgumentException(nameof(tryTakeTimeout), $"{nameof(tryTakeTimeout)} must be greater than zero"); - _cryptoProviderCacheOptions = cryptoProviderCacheOptions; - _signingSignatureProviders = new EventBasedLRUCache(cryptoProviderCacheOptions.SizeLimit, options, StringComparer.Ordinal, false) { OnItemRemoved = (SignatureProvider signatureProvider) => signatureProvider.CryptoProviderCache = null }; - _verifyingSignatureProviders = new EventBasedLRUCache(cryptoProviderCacheOptions.SizeLimit, options, StringComparer.Ordinal, false) { OnItemRemoved = (SignatureProvider signatureProvider) => signatureProvider.CryptoProviderCache = null }; + _signingSignatureProviders = new EventBasedLRUCache( + cryptoProviderCacheOptions.SizeLimit, + options, + comparer: StringComparer.Ordinal) + { + OnItemMovedToCompactedList = SetCryptoProviderCacheToNull, + OnItemRemovedFromCompactedList = DisposeSignatureProvider, + OnShouldRemoveFromCompactedList = IsCacheNullAndRefCountZero + }; + + _verifyingSignatureProviders = new EventBasedLRUCache( + cryptoProviderCacheOptions.SizeLimit, + options, + comparer: StringComparer.Ordinal) + { + OnItemMovedToCompactedList = SetCryptoProviderCacheToNull, + OnItemRemovedFromCompactedList = DisposeSignatureProvider, + OnShouldRemoveFromCompactedList = IsCacheNullAndRefCountZero + }; + } + + internal CryptoProviderFactory CryptoProviderFactory { get; set; } + + private static void DisposeSignatureProvider(SignatureProvider signatureProvider) + { + signatureProvider.Dispose(); + } + + private void SetCryptoProviderCacheToNull(SignatureProvider signatureProvider) + { + signatureProvider.CryptoProviderCache = null; + } + + private static bool IsCacheNullAndRefCountZero(SignatureProvider signatureProvider) + { + return signatureProvider.CryptoProviderCache == null && signatureProvider.RefCount == 0; } /// @@ -195,7 +210,7 @@ public override bool TryRemove(SignatureProvider signatureProvider) try { - return signatureProviderCache.TryRemove(cacheKey, out SignatureProvider provider); + return signatureProviderCache.TryRemove(cacheKey); } catch (Exception ex) { diff --git a/src/Microsoft.IdentityModel.Tokens/LogMessages.cs b/src/Microsoft.IdentityModel.Tokens/LogMessages.cs index c644c0936f..94f41493d8 100644 --- a/src/Microsoft.IdentityModel.Tokens/LogMessages.cs +++ b/src/Microsoft.IdentityModel.Tokens/LogMessages.cs @@ -149,7 +149,7 @@ internal static class LogMessages public const string IDX10640 = "IDX10640: Algorithm is not supported: '{0}'."; // public const string IDX10641 = "IDX10641:"; public const string IDX10642 = "IDX10642: Creating signature using the input: '{0}'."; - public const string IDX10643 = "IDX10643: Comparing the signature created over the input with the token signature: '{0}'."; + // public const string IDX10643 = "IDX10643:"; // public const string IDX10644 = "IDX10644:"; public const string IDX10645 = "IDX10645: Elliptical Curve not supported for curveId: '{0}'"; public const string IDX10646 = "IDX10646: A CustomCryptoProvider was set and returned 'true' for IsSupportedAlgorithm(Algorithm: '{0}', Key: '{1}'), but Create.(algorithm, args) as '{2}' == NULL."; @@ -253,7 +253,8 @@ internal static class LogMessages //EventBasedLRUCache errors public const string IDX10900 = "IDX10900: EventBasedLRUCache._eventQueue encountered an error while processing a cache operation. Exception '{0}'."; public const string IDX10901 = "IDX10901: CryptoProviderCacheOptions.SizeLimit must be greater than 10. Value: '{0}'"; - public const string IDX10902 = "IDX10902: Object disposed exception in '{0}': '{1}'"; + public const string IDX10902 = "IDX10902: Exception caught while removing expired items: '{0}', Exception: '{1}'"; + public const string IDX10906 = "IDX10906: Exception caught while compacting items: '{0}', Exception: '{1}'"; // Crypto Errors public const string IDX11000 = "IDX11000: Cannot create EcdhKeyExchangeProvider. '{0}'\'s Curve '{1}' does not match with '{2}'\'s curve '{3}'."; diff --git a/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs index 1e80d5290c..fbc0fc727e 100644 --- a/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs @@ -68,6 +68,8 @@ public void Dispose() /// true, if called from Dispose(), false, if invoked inside a finalizer protected abstract void Dispose(bool disposing); + internal bool IsCached { get; set; } + /// /// Gets the . /// diff --git a/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs index e066b57f48..bb665cd9d0 100644 --- a/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs @@ -193,13 +193,11 @@ public override byte[] Sign(byte[] input) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } @@ -225,13 +223,11 @@ public override bool Sign(ReadOnlySpan input, Span signature, out in catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } #endif @@ -260,13 +256,11 @@ public override byte[] Sign(byte[] input, int offset, int count) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } @@ -301,9 +295,6 @@ public override bool Verify(byte[] input, byte[] signature) throw LogHelper.LogExceptionMessage(new ObjectDisposedException(GetType().ToString())); } - if (LogHelper.IsEnabled(EventLogLevel.Informational)) - LogHelper.LogInformation(LogMessages.IDX10643, input); - KeyedHashAlgorithm keyedHashAlgorithm = GetKeyedHashAlgorithm(GetKeyBytes(Key), Algorithm); try { @@ -312,13 +303,11 @@ public override bool Verify(byte[] input, byte[] signature) catch { CryptoProviderCache?.TryRemove(this); - Dispose(true); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } @@ -449,9 +438,6 @@ internal bool Verify(byte[] input, int inputOffset, int inputLength, byte[] sign throw LogHelper.LogExceptionMessage(new ObjectDisposedException(GetType().ToString())); } - if (LogHelper.IsEnabled(EventLogLevel.Informational)) - LogHelper.LogInformation(LogMessages.IDX10643, input); - KeyedHashAlgorithm keyedHashAlgorithm = null; try { @@ -465,18 +451,16 @@ internal bool Verify(byte[] input, int inputOffset, int inputLength, byte[] sign #else hash = keyedHashAlgorithm.ComputeHash(input, inputOffset, inputLength).AsSpan(); #endif - return Utility.AreEqual(signature, hash, signatureLength); } catch { - Dispose(true); + CryptoProviderCache?.TryRemove(this); throw; } finally { - if (!_disposed) - ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); } } diff --git a/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs b/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs index ec42035a02..42fbe8c5cf 100644 --- a/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs +++ b/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs @@ -47,13 +47,13 @@ public void CreateAndReleaseSignatureProviders(SignatureProviderTheoryData theor { var disposeCalled = GetSignatureProviderIsDisposedByReflect(signatureProvider); if (!disposeCalled) - context.Diffs.Add("Dispose wasn't called on the AsymmetricSignatureProvider."); + context.Diffs.Add("Dispose was supposed to be called on the AsymmetricSignatureProvider."); } else // signatureProvider.GetType().Equals(typeof(SymmetricSignatureProvider)) { var disposeCalled = GetSignatureProviderIsDisposedByReflect(signatureProvider); if (!disposeCalled) - context.Diffs.Add("Dispose wasn't called on the SymmetricSignatureProvider."); + context.Diffs.Add("Dispose was supposed to be called on the SymmetricSignatureProvider."); } } catch (Exception ex) @@ -917,8 +917,8 @@ public void ReferenceCountingTest_Caching() cryptoProviderFactory.ReleaseSignatureProvider(signing); - if (!GetSignatureProviderIsDisposedByReflect(signing)) - context.AddDiff($"{nameof(signing2)} should have been disposed"); + if (GetSignatureProviderIsDisposedByReflect(signing)) + context.AddDiff($"{nameof(signing)} should not have been disposed"); TestUtilities.AssertFailIfErrors(context); }