diff --git a/src/EFCore/Metadata/Conventions/Infrastructure/ProviderConventionSetBuilder.cs b/src/EFCore/Metadata/Conventions/Infrastructure/ProviderConventionSetBuilder.cs index 5faed428dab..f4df70c8389 100644 --- a/src/EFCore/Metadata/Conventions/Infrastructure/ProviderConventionSetBuilder.cs +++ b/src/EFCore/Metadata/Conventions/Infrastructure/ProviderConventionSetBuilder.cs @@ -104,6 +104,7 @@ public virtual ConventionSet CreateConventionSet() var databaseGeneratedAttributeConvention = new DatabaseGeneratedAttributeConvention(Dependencies); var requiredPropertyAttributeConvention = new RequiredPropertyAttributeConvention(Dependencies); var nonNullableReferencePropertyConvention = new NonNullableReferencePropertyConvention(Dependencies); + var nonNullableNavigationConvention = new NonNullableNavigationConvention(Dependencies); var maxLengthAttributeConvention = new MaxLengthAttributeConvention(Dependencies); var stringLengthAttributeConvention = new StringLengthAttributeConvention(Dependencies); var timestampAttributeConvention = new TimestampAttributeConvention(Dependencies); @@ -171,11 +172,14 @@ public virtual ConventionSet CreateConventionSet() conventionSet.ModelFinalizedConventions.Add(foreignKeyIndexConvention); conventionSet.ModelFinalizedConventions.Add(foreignKeyPropertyDiscoveryConvention); conventionSet.ModelFinalizedConventions.Add(servicePropertyDiscoveryConvention); + conventionSet.ModelFinalizedConventions.Add(nonNullableReferencePropertyConvention); + conventionSet.ModelFinalizedConventions.Add(nonNullableNavigationConvention); conventionSet.ModelFinalizedConventions.Add(new ValidatingConvention(Dependencies)); + // Don't add any more conventions to ModelFinalizedConventions after ValidatingConvention conventionSet.NavigationAddedConventions.Add(backingFieldConvention); conventionSet.NavigationAddedConventions.Add(new RequiredNavigationAttributeConvention(Dependencies)); - conventionSet.NavigationAddedConventions.Add(new NonNullableNavigationConvention(Dependencies)); + conventionSet.NavigationAddedConventions.Add(nonNullableNavigationConvention); conventionSet.NavigationAddedConventions.Add(inversePropertyAttributeConvention); conventionSet.NavigationAddedConventions.Add(foreignKeyPropertyDiscoveryConvention); conventionSet.NavigationAddedConventions.Add(relationshipDiscoveryConvention); diff --git a/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs b/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs index 832aca36cc1..b8910fa42c8 100644 --- a/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs +++ b/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs @@ -2,9 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Linq; using System.Reflection; using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Metadata.Builders; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; namespace Microsoft.EntityFrameworkCore.Metadata.Conventions @@ -13,11 +15,14 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Conventions /// A base type for conventions that configure model aspects based on whether the member type /// is a non-nullable reference type. /// - public abstract class NonNullableConventionBase + public abstract class NonNullableConventionBase : IModelFinalizedConvention { + // For the interpretation of nullability metadata, see + // https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-metadata.md + + private const string StateAnnotationName = "NonNullableConventionState"; private const string NullableAttributeFullName = "System.Runtime.CompilerServices.NullableAttribute"; - private Type _nullableAttrType; - private FieldInfo _nullableFlagsFieldInfo; + private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute"; /// /// Creates a new instance of . @@ -33,40 +38,118 @@ protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDepend /// protected virtual ProviderConventionSetBuilderDependencies Dependencies { get; } + private byte? GetNullabilityContextFlag(NonNullabilityConventionState state, Attribute[] attributes) + { + if (attributes.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute attribute) + { + var attributeType = attribute.GetType(); + + if (attributeType != state.NullableContextAttrType) + { + state.NullableContextFlagFieldInfo = attributeType.GetField("Flag"); + state.NullableContextAttrType = attributeType; + } + + if (state.NullableContextFlagFieldInfo?.GetValue(attribute) is byte flag) + { + return flag; + } + } + + return null; + } + /// /// Returns a value indicating whether the member type is a non-nullable reference type. /// + /// The model builder used to build the model. /// The member info. /// true if the member type is a non-nullable reference type. - protected virtual bool IsNonNullable([NotNull] MemberInfo memberInfo) + protected virtual bool IsNonNullable( + [NotNull] IConventionModelBuilder modelBuilder, + [NotNull] MemberInfo memberInfo) { + var state = GetOrInitializeState(modelBuilder); + // For C# 8.0 nullable types, the C# currently synthesizes a NullableAttribute that expresses nullability into assemblies // it produces. If the model is spread across more than one assembly, there will be multiple versions of this attribute, // so look for it by name, caching to avoid reflection on every check. // Note that this may change - if https://github.com/dotnet/corefx/issues/36222 is done we can remove all of this. - if (!(Attribute.GetCustomAttributes(memberInfo, true) - .FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) - is { } attribute)) + + // First look for NullableAttribute on the member itself + if (Attribute.GetCustomAttributes(memberInfo, true) + .FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute) + { + var attributeType = attribute.GetType(); + + if (attributeType != state.NullableAttrType) + { + state.NullableFlagsFieldInfo = attributeType.GetField("NullableFlags"); + state.NullableAttrType = attributeType; + } + + if (state.NullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags + && flags.FirstOrDefault() == 1) + { + return true; + } + } + + // No attribute on the member, try to find a NullableContextAttribute on the declaring type + var type = memberInfo.DeclaringType; + if (type != null) { - return false; + if (state.TypeNonNullabilityContextCache.TryGetValue(type, out var cachedTypeNonNullable)) + { + return cachedTypeNonNullable; + } + + var typeContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(type)); + if (typeContextFlag.HasValue) + { + return state.TypeNonNullabilityContextCache[type] = typeContextFlag.Value == 1; + } } - var attributeType = attribute.GetType(); - if (attributeType != _nullableAttrType) + // Not found at the type level, try at the module level + var module = memberInfo.Module; + if (!state.ModuleNonNullabilityContextCache.TryGetValue(module, out var moduleNonNullable)) { - _nullableFlagsFieldInfo = attributeType.GetField("NullableFlags"); - _nullableAttrType = attributeType; + var moduleContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(memberInfo.Module)); + moduleNonNullable = state.ModuleNonNullabilityContextCache[module] = + moduleContextFlag.HasValue && moduleContextFlag == 1; } - // For the interpretation of NullableFlags, see - // https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-reference-types.md#annotations - if (_nullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags - && flags.FirstOrDefault() == 1) + if (type != null) { - return true; + state.TypeNonNullabilityContextCache[type] = moduleNonNullable; } - return false; + return moduleNonNullable; + } + + private NonNullabilityConventionState GetOrInitializeState(IConventionModelBuilder modelBuilder) + => (NonNullabilityConventionState)( + modelBuilder.Metadata.FindAnnotation(StateAnnotationName) ?? + modelBuilder.Metadata.AddAnnotation(StateAnnotationName, new NonNullabilityConventionState()) + ).Value; + + /// + /// Called after a model is finalized. Removes the cached state annotation used by this convention. + /// + /// The builder for the model. + /// Additional information associated with convention execution. + public virtual void ProcessModelFinalized(IConventionModelBuilder modelBuilder, IConventionContext context) + => modelBuilder.Metadata.RemoveAnnotation(StateAnnotationName); + + private class NonNullabilityConventionState + { + public Type NullableAttrType; + public Type NullableContextAttrType; + public FieldInfo NullableFlagsFieldInfo; + public FieldInfo NullableContextFlagFieldInfo; + public Dictionary TypeNonNullabilityContextCache { get; } = new Dictionary(); + public Dictionary ModuleNonNullabilityContextCache { get; } = new Dictionary(); } } } diff --git a/src/EFCore/Metadata/Conventions/NonNullableNavigationConvention.cs b/src/EFCore/Metadata/Conventions/NonNullableNavigationConvention.cs index b3c90579715..7e104bbbb33 100644 --- a/src/EFCore/Metadata/Conventions/NonNullableNavigationConvention.cs +++ b/src/EFCore/Metadata/Conventions/NonNullableNavigationConvention.cs @@ -40,8 +40,9 @@ public virtual void ProcessNavigationAdded( Check.NotNull(relationshipBuilder, nameof(relationshipBuilder)); Check.NotNull(navigation, nameof(navigation)); - if (!IsNonNullable(navigation) - || navigation.IsCollection()) + var modelBuilder = relationshipBuilder.ModelBuilder; + + if (!IsNonNullable(modelBuilder, navigation) || navigation.IsCollection()) { return; } @@ -51,7 +52,7 @@ public virtual void ProcessNavigationAdded( var inverse = navigation.FindInverse(); if (inverse != null) { - if (IsNonNullable(inverse)) + if (IsNonNullable(modelBuilder, inverse)) { Dependencies.Logger.NonNullableReferenceOnBothNavigations(navigation, inverse); return; @@ -82,9 +83,9 @@ public virtual void ProcessNavigationAdded( context.StopProcessingIfChanged(relationshipBuilder.Metadata.DependentToPrincipal); } - private bool IsNonNullable(IConventionNavigation navigation) + private bool IsNonNullable(IConventionModelBuilder modelBuilder, IConventionNavigation navigation) => navigation.DeclaringEntityType.HasClrType() && navigation.DeclaringEntityType.GetRuntimeProperties().Find(navigation.Name) is PropertyInfo propertyInfo - && IsNonNullable(propertyInfo); + && IsNonNullable(modelBuilder, propertyInfo); } } diff --git a/src/EFCore/Metadata/Conventions/NonNullableReferencePropertyConvention.cs b/src/EFCore/Metadata/Conventions/NonNullableReferencePropertyConvention.cs index aa364b5d811..8fc712e8682 100644 --- a/src/EFCore/Metadata/Conventions/NonNullableReferencePropertyConvention.cs +++ b/src/EFCore/Metadata/Conventions/NonNullableReferencePropertyConvention.cs @@ -29,7 +29,7 @@ private void Process(IConventionPropertyBuilder propertyBuilder) // If the model is spread across multiple assemblies, it may contain different NullableAttribute types as // the compiler synthesizes them for each assembly. if (propertyBuilder.Metadata.GetIdentifyingMemberInfo() is MemberInfo memberInfo - && IsNonNullable(memberInfo)) + && IsNonNullable(propertyBuilder.ModelBuilder, memberInfo)) { propertyBuilder.IsRequired(true); }