From f29ed14e4d6f098d2d825a5e9c4b06c30497d751 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 1 Jul 2019 22:33:52 +0200 Subject: [PATCH] Make convention thread-safe --- .../Conventions/NonNullableConventionBase.cs | 54 +++++++++++-------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs b/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs index 1bd8d000d9e..036589bd33e 100644 --- a/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs +++ b/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; +using System.Collections.Concurrent; using System.Linq; using System.Reflection; using JetBrains.Annotations; @@ -21,13 +21,15 @@ public abstract class NonNullableConventionBase private const string NullableAttributeFullName = "System.Runtime.CompilerServices.NullableAttribute"; private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute"; + private readonly object _nullableAttrLock = new object(); + private readonly object _nullableContextAttrLock = new object(); private Type _nullableAttrType; private Type _nullableContextAttrType; private FieldInfo _nullableFlagsFieldInfo; private FieldInfo _nullableContextFlagFieldInfo; - private readonly Dictionary _typeNonNullabilityContextCache; - private readonly Dictionary _moduleNonNullabilityContextCache; + private readonly ConcurrentDictionary _typeNonNullabilityContextCache; + private readonly ConcurrentDictionary _moduleNonNullabilityContextCache; /// /// Creates a new instance of . @@ -37,8 +39,8 @@ protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDepend { Dependencies = dependencies; - _typeNonNullabilityContextCache = new Dictionary(); - _moduleNonNullabilityContextCache = new Dictionary(); + _typeNonNullabilityContextCache = new ConcurrentDictionary(); + _moduleNonNullabilityContextCache = new ConcurrentDictionary(); } /// @@ -48,18 +50,22 @@ protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDepend private byte? GetNullabilityContextFlag(Attribute[] attributes) { - if (attributes.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is object attribute) + if (attributes.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute attribute) { var attributeType = attribute.GetType(); - if (attributeType != _nullableContextAttrType) - { - _nullableContextFlagFieldInfo = attributeType.GetField("Flag"); - _nullableContextAttrType = attributeType; - } - if (_nullableContextFlagFieldInfo?.GetValue(attribute) is byte flag) + lock (_nullableContextAttrLock) { - return flag; + if (attributeType != _nullableContextAttrType) + { + _nullableContextFlagFieldInfo = attributeType.GetField("Flag"); + _nullableContextAttrType = attributeType; + } + + if (_nullableContextFlagFieldInfo?.GetValue(attribute) is byte flag) + { + return flag; + } } } @@ -80,19 +86,23 @@ protected virtual bool IsNonNullable([NotNull] MemberInfo memberInfo) // First look for NullableAttribute on the member itself if (Attribute.GetCustomAttributes(memberInfo, true) - .FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is { } attribute) + .FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute) { var attributeType = attribute.GetType(); - if (attributeType != _nullableAttrType) - { - _nullableFlagsFieldInfo = attributeType.GetField("NullableFlags"); - _nullableAttrType = attributeType; - } - if (_nullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags - && flags.FirstOrDefault() == 1) + lock (_nullableAttrLock) { - return true; + if (attributeType != _nullableAttrType) + { + _nullableFlagsFieldInfo = attributeType.GetField("NullableFlags"); + _nullableAttrType = attributeType; + } + + if (_nullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags + && flags.FirstOrDefault() == 1) + { + return true; + } } }