diff --git a/CommunityToolkit.Mvvm/Messaging/Internals/System/Collections.Generic/Dictionary2.cs b/CommunityToolkit.Mvvm/Messaging/Internals/System/Collections.Generic/Dictionary2.cs index 3f1719975..5e5916716 100644 --- a/CommunityToolkit.Mvvm/Messaging/Internals/System/Collections.Generic/Dictionary2.cs +++ b/CommunityToolkit.Mvvm/Messaging/Internals/System/Collections.Generic/Dictionary2.cs @@ -321,19 +321,19 @@ public bool MoveNext() /// /// Gets the current key. /// - public TKey Key + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TKey GetKey() { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => this.entries[this.index - 1].Key; + return this.entries[this.index - 1].Key; } /// /// Gets the current value. /// - public TValue Value + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TValue GetValue() { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => this.entries[this.index - 1].Value!; + return this.entries[this.index - 1].Value!; } } diff --git a/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.Proxy.cs b/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.Proxy.cs new file mode 100644 index 000000000..95970a139 --- /dev/null +++ b/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.Proxy.cs @@ -0,0 +1,108 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#if NETSTANDARD2_1 + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; + +namespace System.Runtime.CompilerServices; + +/// +/// A wrapper for with a custom enumerator. +/// +/// Tke key of items to store in the table. +/// The values to store in the table. +internal sealed class ConditionalWeakTable2 + where TKey : class + where TValue : class? +{ + /// + /// The underlying instance. + /// + private readonly ConditionalWeakTable table = new(); + + /// + public bool TryGetValue(TKey key, [NotNullWhen(true)] out TValue? value) + { + return this.table.TryGetValue(key, out value); + } + + /// + public bool TryAdd(TKey key, TValue value) + { + return this.table.TryAdd(key, value); + } + + /// + public TValue GetValue(TKey key, ConditionalWeakTable.CreateValueCallback createValueCallback) + { + return this.table.GetValue(key, createValueCallback); + } + + /// + public bool Remove(TKey key) + { + return this.table.Remove(key); + } + + /// + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Enumerator GetEnumerator() => new(this); + + /// + /// A custom enumerator that traverses items in a instance. + /// + public ref struct Enumerator + { + /// + /// The wrapped instance for the enumerator. + /// + private readonly IEnumerator> enumerator; + + /// + /// Initializes a new instance of the struct. + /// + /// The owner instance for the enumerator. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Enumerator(ConditionalWeakTable2 owner) + { + this.enumerator = ((IEnumerable>)owner.table).GetEnumerator(); + } + + /// + public void Dispose() + { + this.enumerator.Dispose(); + } + + /// + public bool MoveNext() + { + return this.enumerator.MoveNext(); + } + + /// + /// Gets the current key. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TKey GetKey() + { + return this.enumerator.Current.Key; + } + + /// + /// Gets the current value. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TValue GetValue() + { + return this.enumerator.Current.Value; + } + } +} + +#endif \ No newline at end of file diff --git a/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.ZeroAlloc.cs b/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.ZeroAlloc.cs index e2a7f686a..698a9db5f 100644 --- a/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.ZeroAlloc.cs +++ b/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.ZeroAlloc.cs @@ -157,9 +157,14 @@ public ref struct Enumerator private int currentIndex; /// - /// The current entry set by MoveNext and returned from . + /// The current key, if available. /// - private KeyValuePair current; + private TKey? key; + + /// + /// The current value, if available. + /// + private TValue? value; /// /// Initializes a new instance of the class. @@ -170,22 +175,22 @@ public Enumerator(ConditionalWeakTable2 table) // Store a reference to the parent table and increase its active enumerator count this.table = table; - Container c = table.container; + Container container = table.container; - if (c is null || c.FirstFreeEntry == 0) + if (container is null || container.FirstFreeEntry == 0) { // The max index is the same as the current to prevent enumeration this.maxIndexInclusive = -1; - this.currentIndex = -1; - this.current = default; } else { // Store the max index to be enumerated - this.maxIndexInclusive = table.container.FirstFreeEntry - 1; - this.currentIndex = -1; - this.current = default; + this.maxIndexInclusive = container.FirstFreeEntry - 1; } + + this.currentIndex = -1; + this.key = null; + this.value = null; } /// @@ -197,7 +202,8 @@ public void Dispose() this.table = null!; // Ensure we don't keep the last current alive unnecessarily - this.current = default; + this.key = null; + this.value = null; } /// @@ -209,30 +215,44 @@ public bool MoveNext() // container at the time) has already been finalized, this will be null. Container c = this.table.container; - if (c != null) + int currentIndex = this.currentIndex; + int maxIndexInclusive = this.maxIndexInclusive; + + // We have the container. Find the next entry to return, if there is one. We need to loop as we + // may try to get an entry that's already been removed or collected, in which case we try again. + while (currentIndex < maxIndexInclusive) { - // We have the container. Find the next entry to return, if there is one. We need to loop as we - // may try to get an entry that's already been removed or collected, in which case we try again. - while (this.currentIndex < this.maxIndexInclusive) - { - this.currentIndex++; + currentIndex++; - if (c.TryGetEntry(this.currentIndex, out TKey? key, out TValue? value)) - { - this.current = new KeyValuePair(key, value); + if (c.TryGetEntry(currentIndex, out this.key, out this.value)) + { + this.currentIndex = currentIndex; - return true; - } + return true; } } + this.currentIndex = currentIndex; + return false; } - public KeyValuePair Current + /// + /// Gets the current key. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TKey GetKey() + { + return this.key!; + } + + /// + /// Gets the current value. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TValue GetValue() { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => this.current; + return this.value!; } } diff --git a/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.cs b/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.cs index 7e1dd16a0..7b3904498 100644 --- a/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.cs +++ b/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTable2{TKey,TValue}.cs @@ -37,11 +37,7 @@ public bool TryGetValue(TKey key, [NotNullWhen(true)] out TValue? value) return this.table.TryGetValue(key, out value); } - /// - /// Tries to add a new pair to the table. - /// - /// The key to add. - /// The value to associate with key. + /// public bool TryAdd(TKey key, TValue value) { if (!this.table.TryAdd(key, value)) @@ -116,9 +112,14 @@ public ref struct Enumerator private LinkedListNode>? node; /// - /// The current to return. + /// The current key, if available. + /// + private TKey? key; + + /// + /// The current value, if available. /// - private KeyValuePair current; + private TValue? value; /// /// Indicates whether or not has been called at least once. @@ -134,10 +135,16 @@ public Enumerator(ConditionalWeakTable2 owner) { this.owner = owner; this.node = null; - this.current = default; + this.key = null; + this.value = null; this.isFirstMoveNextPending = true; } + /// + public void Dispose() + { + } + /// public bool MoveNext() { @@ -163,7 +170,8 @@ public bool MoveNext() this.owner.table.TryGetValue(target!, out TValue? value)) { this.node = node; - this.current = new KeyValuePair(target, value); + this.key = target; + this.value = value; return true; } @@ -179,11 +187,22 @@ public bool MoveNext() return false; } - /// - public readonly KeyValuePair Current + /// + /// Gets the current key. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TKey GetKey() + { + return this.key!; + } + + /// + /// Gets the current value. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TValue GetValue() { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => this.current; + return this.value!; } } } diff --git a/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTableExtensions.cs b/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTableExtensions.cs index ab426fd15..adbe1d076 100644 --- a/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTableExtensions.cs +++ b/CommunityToolkit.Mvvm/Messaging/Internals/System/Runtime.CompilerServices/ConditionalWeakTableExtensions.cs @@ -23,7 +23,7 @@ public static bool TryAdd(this ConditionalWeakTable where TKey : class where TValue : class? { - // There is no way to do this on .NET Standard 2.0 without exception handling + // There is no way to do this on .NET Standard 2.0 or 2.1 without exception handling try { table.Add(key, value); diff --git a/CommunityToolkit.Mvvm/Messaging/StrongReferenceMessenger.cs b/CommunityToolkit.Mvvm/Messaging/StrongReferenceMessenger.cs index 1cc581a24..b81fd5ee1 100644 --- a/CommunityToolkit.Mvvm/Messaging/StrongReferenceMessenger.cs +++ b/CommunityToolkit.Mvvm/Messaging/StrongReferenceMessenger.cs @@ -392,7 +392,7 @@ public TMessage Send(TMessage message, TToken token) while (mappingEnumerator.MoveNext()) { // Pick the target handler, if the token is a match for the recipient - if (mappingEnumerator.Value.TryGetValue(token, out object? handler)) + if (mappingEnumerator.GetValue().TryGetValue(token, out object? handler)) { // This span access should always guaranteed to be valid due to the size of the // array being set according to the current total number of registered handlers, @@ -400,7 +400,7 @@ public TMessage Send(TMessage message, TToken token) // We're still using a checked span accesses here though to make sure an out of // bounds write can never happen even if an error was present in the logic above. pairs[2 * i] = handler; - pairs[(2 * i) + 1] = mappingEnumerator.Key.Target; + pairs[(2 * i) + 1] = mappingEnumerator.GetKey().Target; i++; } } diff --git a/CommunityToolkit.Mvvm/Messaging/WeakReferenceMessenger.cs b/CommunityToolkit.Mvvm/Messaging/WeakReferenceMessenger.cs index afc11f977..1825aa430 100644 --- a/CommunityToolkit.Mvvm/Messaging/WeakReferenceMessenger.cs +++ b/CommunityToolkit.Mvvm/Messaging/WeakReferenceMessenger.cs @@ -8,11 +8,7 @@ using System.Runtime.InteropServices; using System.Threading; using CommunityToolkit.Mvvm.Messaging.Internals; -#if NETSTANDARD2_0 || NET6_0_OR_GREATER using RecipientsTable = System.Runtime.CompilerServices.ConditionalWeakTable2; -#else -using RecipientsTable = System.Runtime.CompilerServices.ConditionalWeakTable; -#endif namespace CommunityToolkit.Mvvm.Messaging; @@ -204,7 +200,7 @@ public void UnregisterAll(object recipient) // as that is responsibility of a separate method defined below. while (enumerator.MoveNext()) { - _ = enumerator.Value.Remove(recipient); + _ = enumerator.GetValue().Remove(recipient); } } } @@ -222,13 +218,13 @@ public void UnregisterAll(object recipient, TToken token) // only try to remove handlers with a matching token, if any. while (enumerator.MoveNext()) { - if (enumerator.Key.TToken == typeof(TToken)) + if (enumerator.GetKey().TToken == typeof(TToken)) { if (typeof(TToken) == typeof(Unit)) { - _ = enumerator.Value.Remove(recipient); + _ = enumerator.GetValue().Remove(recipient); } - else if (enumerator.Value.TryGetValue(recipient, out object? mapping)) + else if (enumerator.GetValue().TryGetValue(recipient, out object? mapping)) { _ = Unsafe.As>(mapping).TryRemove(token); } @@ -288,22 +284,24 @@ public TMessage Send(TMessage message, TToken token) // to enumerate all the existing recipients for the token and message types pair // corresponding to the generic arguments for this invocation, and then track the // handlers with a matching token, and their corresponding recipients. - foreach (KeyValuePair pair in table) + using RecipientsTable.Enumerator enumerator = table.GetEnumerator(); + + while (enumerator.MoveNext()) { if (typeof(TToken) == typeof(Unit)) { - bufferWriter.Add(pair.Value); - bufferWriter.Add(pair.Key); + bufferWriter.Add(enumerator.GetValue()); + bufferWriter.Add(enumerator.GetKey()); i++; } else { - Dictionary2? map = Unsafe.As>(pair.Value); + Dictionary2? map = Unsafe.As>(enumerator.GetValue()); if (map.TryGetValue(token, out object? handler)) { bufferWriter.Add(handler); - bufferWriter.Add(pair.Key); + bufferWriter.Add(enumerator.GetKey()); i++; } } @@ -441,23 +439,25 @@ private void CleanupWithoutLock() using ArrayPoolBufferWriter type2s = ArrayPoolBufferWriter.Create(); using ArrayPoolBufferWriter emptyRecipients = ArrayPoolBufferWriter.Create(); - Dictionary2.Enumerator enumerator = this.recipientsMap.GetEnumerator(); + Dictionary2.Enumerator type2Enumerator = this.recipientsMap.GetEnumerator(); // First, we go through all the currently registered pairs of token and message types. // These represents all the combinations of generic arguments with at least one registered // handler, with the exception of those with recipients that have already been collected. - while (enumerator.MoveNext()) + while (type2Enumerator.MoveNext()) { emptyRecipients.Reset(); bool hasAtLeastOneHandler = false; - if (enumerator.Key.TToken == typeof(Unit)) + if (type2Enumerator.GetKey().TToken == typeof(Unit)) { // When the token type is unit, there can be no registered recipients with no handlers, // as when the single handler is unsubscribed the recipient is also removed immediately. // Therefore, we need to check that there exists at least one recipient for the message. - foreach (KeyValuePair pair in enumerator.Value) + using RecipientsTable.Enumerator recipientsEnumerator = type2Enumerator.GetValue().GetEnumerator(); + + while (recipientsEnumerator.MoveNext()) { hasAtLeastOneHandler = true; @@ -468,29 +468,32 @@ private void CleanupWithoutLock() { // Go through the currently alive recipients to look for those with no handlers left. We track // the ones we find to remove them outside of the loop (can't modify during enumeration). - foreach (KeyValuePair pair in enumerator.Value) + using (RecipientsTable.Enumerator recipientsEnumerator = type2Enumerator.GetValue().GetEnumerator()) { - if (Unsafe.As(pair.Value).Count == 0) - { - emptyRecipients.Add(pair.Key); - } - else + while (recipientsEnumerator.MoveNext()) { - hasAtLeastOneHandler = true; + if (Unsafe.As(recipientsEnumerator.GetValue()).Count == 0) + { + emptyRecipients.Add(recipientsEnumerator.GetKey()); + } + else + { + hasAtLeastOneHandler = true; + } } } // Remove the handler maps for recipients that are still alive but with no handlers foreach (object recipient in emptyRecipients.Span) { - _ = enumerator.Value.Remove(recipient); + _ = type2Enumerator.GetValue().Remove(recipient); } } // Track the type combinations with no recipients or handlers left if (!hasAtLeastOneHandler) { - type2s.Add(enumerator.Key); + type2s.Add(type2Enumerator.GetKey()); } } diff --git a/tests/CommunityToolkit.Mvvm.UnitTests/Internals/Test_ConditionalWeakTable2.cs b/tests/CommunityToolkit.Mvvm.UnitTests/Internals/Test_ConditionalWeakTable2.cs index 6a1ec42f8..9603bd936 100644 --- a/tests/CommunityToolkit.Mvvm.UnitTests/Internals/Test_ConditionalWeakTable2.cs +++ b/tests/CommunityToolkit.Mvvm.UnitTests/Internals/Test_ConditionalWeakTable2.cs @@ -362,12 +362,15 @@ public void RemoveAll_AddMany_RemoveAll_AllItemsRemoved() int count = 0; - foreach (KeyValuePair pair in cwt) + using (ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator()) { - Assert.AreSame(pair.Key, keys[count]); - Assert.AreSame(pair.Value, values[count]); + while (enumerator.MoveNext()) + { + Assert.AreSame(enumerator.GetKey(), keys[count]); + Assert.AreSame(enumerator.GetValue(), values[count]); - count++; + count++; + } } Assert.AreEqual(keys.Length, count); @@ -377,9 +380,12 @@ public void RemoveAll_AddMany_RemoveAll_AllItemsRemoved() Assert.IsTrue(cwt.Remove(key)); } - foreach (KeyValuePair pair in cwt) + using (ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator()) { - Assert.Fail(); + while (enumerator.MoveNext()) + { + Assert.Fail(); + } } GC.KeepAlive(keys); @@ -391,7 +397,9 @@ public void GetEnumerator_Empty_ReturnsEmptyEnumerator() { ConditionalWeakTable2 cwt = new(); - foreach (KeyValuePair _ in cwt) + using ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator(); + + while (enumerator.MoveNext()) { Assert.Fail(); } @@ -411,9 +419,12 @@ public void GetEnumerator_AddedAndRemovedItems_AppropriatelyShowUpInEnumeration( int count = 0; - foreach (KeyValuePair pair in cwt) + using (ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator()) { - count++; + while (enumerator.MoveNext()) + { + count++; + } } Assert.AreEqual(1, count); @@ -422,21 +433,24 @@ public void GetEnumerator_AddedAndRemovedItems_AppropriatelyShowUpInEnumeration( KeyValuePair? first = null; - foreach (KeyValuePair pair in cwt) + using (ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator()) { - if (first is not null) + while (enumerator.MoveNext()) { - Assert.Fail(); - } + if (first is not null) + { + Assert.Fail(); + } - first = pair; + first = new KeyValuePair(enumerator.GetKey(), enumerator.GetValue()); - if (count > 0) - { - Assert.Fail(); - } + if (count > 0) + { + Assert.Fail(); + } - count++; + count++; + } } Assert.AreEqual(new KeyValuePair(key1, value1), first); @@ -445,9 +459,12 @@ public void GetEnumerator_AddedAndRemovedItems_AppropriatelyShowUpInEnumeration( count = 0; - foreach (KeyValuePair pair in cwt) + using (ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator()) { - count++; + while (enumerator.MoveNext()) + { + count++; + } } Assert.AreEqual(0, count); @@ -462,7 +479,7 @@ public void GetEnumerator_CollectedItemsNotEnumerated() { ConditionalWeakTable2 cwt = new(); - using ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator(); + using ConditionalWeakTable2.Enumerator enumerator1 = cwt.GetEnumerator(); static void addItem(ConditionalWeakTable2 t) => t.GetValue(new object(), _ => new object()); @@ -475,9 +492,12 @@ public void GetEnumerator_CollectedItemsNotEnumerated() int count = 0; - foreach (KeyValuePair _ in cwt) + using (ConditionalWeakTable2.Enumerator enumerator2 = cwt.GetEnumerator()) { - count++; + while (enumerator2.MoveNext()) + { + count++; + } } Assert.AreEqual(0, count); @@ -502,7 +522,8 @@ public void GetEnumerator_MultipleEnumeratorsReturnSameResults() while (enumerator1.MoveNext()) { Assert.IsTrue(enumerator2.MoveNext()); - Assert.AreEqual(enumerator1.Current, enumerator2.Current); + Assert.AreEqual(enumerator1.GetKey(), enumerator2.GetKey()); + Assert.AreEqual(enumerator1.GetValue(), enumerator2.GetValue()); } Assert.IsFalse(enumerator2.MoveNext()); @@ -531,18 +552,24 @@ public void GetEnumerator_RemovedItems_RemovedFromResults() { count = 0; - foreach (KeyValuePair _ in cwt) + using (ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator()) { - count++; + while (enumerator.MoveNext()) + { + count++; + } } Assert.AreEqual(keys.Length - i, count); List> pairs = new(); - foreach (KeyValuePair pair in cwt) + using (ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator()) { - pairs.Add(pair); + while (enumerator.MoveNext()) + { + pairs.Add(new KeyValuePair(enumerator.GetKey(), enumerator.GetValue())); + } } CollectionAssert.AreEqual( @@ -554,9 +581,12 @@ public void GetEnumerator_RemovedItems_RemovedFromResults() count = 0; - foreach (KeyValuePair _ in cwt) + using (ConditionalWeakTable2.Enumerator enumerator = cwt.GetEnumerator()) { - count++; + while (enumerator.MoveNext()) + { + count++; + } } Assert.AreEqual(0, count);