From 644631a783cd196a6d193761646bc9dc3be4e9ba Mon Sep 17 00:00:00 2001 From: Matt Kotsenas Date: Fri, 25 Oct 2024 15:58:20 -0700 Subject: [PATCH] Avoid closure capture in ConcurrentDictionary.GetOrAdd`3 (#242) * Rewrite ConcurrentDictionary.GetOrAdd`3 to avoid closure allocation * Use static modifier on existing ConcurrentDictionary.GetOrAdd`3 calls to prevent accidental capture --- api_list.include.md | 4 +- src/Consume/Consume.cs | 2 +- .../Nullability/NullabilityInfoExtensions.cs | 8 ++-- src/Polyfill/Polyfill_ConcurrentDictionary.cs | 42 +++++++++++++++++-- .../PolyfillTests_ConcurrentDictionary.cs | 2 +- 5 files changed, 46 insertions(+), 12 deletions(-) diff --git a/api_list.include.md b/api_list.include.md index e571ddbf..561a9fd7 100644 --- a/api_list.include.md +++ b/api_list.include.md @@ -22,9 +22,9 @@ * `Task CancelAsync(CancellationTokenSource)` [reference](https://learn.microsoft.com/en-us/dotnet/api/system.threading.cancellationtokensource.cancelasync) -#### ConcurrentDictionary +#### ConcurrentDictionary - * `TValue GetOrAdd(ConcurrentDictionary, TKey, Func, TArg) where TKey : notnull` [reference](https://learn.microsoft.com/en-us/dotnet/api/system.collections.concurrent.concurrentdictionary-2.getoradd#system-collections-concurrent-concurrentdictionary-2-getoradd-1(-0-system-func((-0-0-1))-0)) + * `TValue GetOrAdd(ConcurrentDictionary, TKey, Func, TArg) where TKey : notnull` [reference](https://learn.microsoft.com/en-us/dotnet/api/system.collections.concurrent.concurrentdictionary-2.getoradd#system-collections-concurrent-concurrentdictionary-2-getoradd-1(-0-system-func((-0-0-1))-0)) #### DateOnly diff --git a/src/Consume/Consume.cs b/src/Consume/Consume.cs index fb059b63..54a52d4c 100644 --- a/src/Consume/Consume.cs +++ b/src/Consume/Consume.cs @@ -218,7 +218,7 @@ void Type_GetMethod() void ConcurrentDictionary_Methods() { var dict = new ConcurrentDictionary(); - var value = dict.GetOrAdd("Hello", (_, arg) => arg.Length, "World"); + var value = dict.GetOrAdd("Hello", static (_, arg) => arg.Length, "World"); } void Dictionary_Methods() diff --git a/src/Polyfill/Nullability/NullabilityInfoExtensions.cs b/src/Polyfill/Nullability/NullabilityInfoExtensions.cs index e8a720fd..5f96427e 100644 --- a/src/Polyfill/Nullability/NullabilityInfoExtensions.cs +++ b/src/Polyfill/Nullability/NullabilityInfoExtensions.cs @@ -52,7 +52,7 @@ public static bool IsNullable(this MemberInfo info) } public static NullabilityInfo GetNullabilityInfo(this FieldInfo info) => - fieldCache.GetOrAdd(info, inner => + fieldCache.GetOrAdd(info, static inner => { var context = new NullabilityInfoContext(); return context.Create(inner); @@ -68,7 +68,7 @@ public static bool IsNullable(this FieldInfo info) } public static NullabilityInfo GetNullabilityInfo(this EventInfo info) => - eventCache.GetOrAdd(info, inner => + eventCache.GetOrAdd(info, static inner => { var context = new NullabilityInfoContext(); return context.Create(inner); @@ -86,7 +86,7 @@ public static bool IsNullable(this EventInfo info) public static NullabilityInfo GetNullabilityInfo(this PropertyInfo info) => propertyCache.GetOrAdd( info, - inner => + static inner => { var context = new NullabilityInfoContext(); return context.Create(inner); @@ -102,7 +102,7 @@ public static bool IsNullable(this PropertyInfo info) } public static NullabilityInfo GetNullabilityInfo(this ParameterInfo info) => - parameterCache.GetOrAdd(info, inner => + parameterCache.GetOrAdd(info, static inner => { var context = new NullabilityInfoContext(); return context.Create(inner); diff --git a/src/Polyfill/Polyfill_ConcurrentDictionary.cs b/src/Polyfill/Polyfill_ConcurrentDictionary.cs index 4a3b7131..4c6c7008 100644 --- a/src/Polyfill/Polyfill_ConcurrentDictionary.cs +++ b/src/Polyfill/Polyfill_ConcurrentDictionary.cs @@ -25,12 +25,46 @@ static partial class Polyfill /// (Nothing in Visual Basic). /// The dictionary contains too many /// elements. - /// The value for the key. This will be either the existing value for the key if the + /// The value for the key. This will be either the existing value for the key if the /// key is already in the dictionary, or the new value for the key as returned by valueFactory /// if the key was not in the dictionary. [Link("https://learn.microsoft.com/en-us/dotnet/api/system.collections.concurrent.concurrentdictionary-2.getoradd#system-collections-concurrent-concurrentdictionary-2-getoradd-1(-0-system-func((-0-0-1))-0)")] - public static TValue GetOrAdd(this ConcurrentDictionary target, TKey key, Func valueFactory, TArg factoryArgument) - where TKey : notnull => - target.GetOrAdd(key, _ => valueFactory(_, factoryArgument)); + public static TValue GetOrAdd(this ConcurrentDictionary target, TKey key, Func valueFactory, TArg factoryArgument) + where TKey : notnull + { + // Implementation based on https://github.com/dotnet/runtime/issues/13978#issuecomment-69494764. + // Because this API is intended to be used in high performance scenarios where avoiding allocations + // is important, we can't delegate to the existing `GetOrAdd`2`, as that would allocate a closure + // over `factoryArgument`. + + if (target is null) + { + throw new ArgumentNullException(nameof(target)); + } + + if (key is null) + { + throw new ArgumentNullException(nameof(target)); + } + if (valueFactory is null) + { + throw new ArgumentNullException(nameof(valueFactory)); + } + + while (true) + { + TValue value; + if (target.TryGetValue(key, out value)) + { + return value; + } + + value = valueFactory(key, factoryArgument); + if (target.TryAdd(key, value)) + { + return value; + } + } + } } #endif \ No newline at end of file diff --git a/src/Tests/PolyfillTests_ConcurrentDictionary.cs b/src/Tests/PolyfillTests_ConcurrentDictionary.cs index 4cda3be7..bc9f5af9 100644 --- a/src/Tests/PolyfillTests_ConcurrentDictionary.cs +++ b/src/Tests/PolyfillTests_ConcurrentDictionary.cs @@ -5,7 +5,7 @@ public void ConcurrentDictionaryGetOrAddFunc() { var dictionary = new ConcurrentDictionary(); - Func valueFactory = (key, arg) => arg.Length; + Func valueFactory = static (key, arg) => arg.Length; var value = dictionary.GetOrAdd("Hello", valueFactory, "World");