diff --git a/src/EFCore/Metadata/Conventions/Infrastructure/ProviderConventionSetBuilder.cs b/src/EFCore/Metadata/Conventions/Infrastructure/ProviderConventionSetBuilder.cs
index 5faed428dab..f4df70c8389 100644
--- a/src/EFCore/Metadata/Conventions/Infrastructure/ProviderConventionSetBuilder.cs
+++ b/src/EFCore/Metadata/Conventions/Infrastructure/ProviderConventionSetBuilder.cs
@@ -104,6 +104,7 @@ public virtual ConventionSet CreateConventionSet()
var databaseGeneratedAttributeConvention = new DatabaseGeneratedAttributeConvention(Dependencies);
var requiredPropertyAttributeConvention = new RequiredPropertyAttributeConvention(Dependencies);
var nonNullableReferencePropertyConvention = new NonNullableReferencePropertyConvention(Dependencies);
+ var nonNullableNavigationConvention = new NonNullableNavigationConvention(Dependencies);
var maxLengthAttributeConvention = new MaxLengthAttributeConvention(Dependencies);
var stringLengthAttributeConvention = new StringLengthAttributeConvention(Dependencies);
var timestampAttributeConvention = new TimestampAttributeConvention(Dependencies);
@@ -171,11 +172,14 @@ public virtual ConventionSet CreateConventionSet()
conventionSet.ModelFinalizedConventions.Add(foreignKeyIndexConvention);
conventionSet.ModelFinalizedConventions.Add(foreignKeyPropertyDiscoveryConvention);
conventionSet.ModelFinalizedConventions.Add(servicePropertyDiscoveryConvention);
+ conventionSet.ModelFinalizedConventions.Add(nonNullableReferencePropertyConvention);
+ conventionSet.ModelFinalizedConventions.Add(nonNullableNavigationConvention);
conventionSet.ModelFinalizedConventions.Add(new ValidatingConvention(Dependencies));
+ // Don't add any more conventions to ModelFinalizedConventions after ValidatingConvention
conventionSet.NavigationAddedConventions.Add(backingFieldConvention);
conventionSet.NavigationAddedConventions.Add(new RequiredNavigationAttributeConvention(Dependencies));
- conventionSet.NavigationAddedConventions.Add(new NonNullableNavigationConvention(Dependencies));
+ conventionSet.NavigationAddedConventions.Add(nonNullableNavigationConvention);
conventionSet.NavigationAddedConventions.Add(inversePropertyAttributeConvention);
conventionSet.NavigationAddedConventions.Add(foreignKeyPropertyDiscoveryConvention);
conventionSet.NavigationAddedConventions.Add(relationshipDiscoveryConvention);
diff --git a/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs b/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
index 832aca36cc1..b8910fa42c8 100644
--- a/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
+++ b/src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
@@ -2,9 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using JetBrains.Annotations;
+using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
@@ -13,11 +15,14 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
/// A base type for conventions that configure model aspects based on whether the member type
/// is a non-nullable reference type.
///
- public abstract class NonNullableConventionBase
+ public abstract class NonNullableConventionBase : IModelFinalizedConvention
{
+ // For the interpretation of nullability metadata, see
+ // https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-metadata.md
+
+ private const string StateAnnotationName = "NonNullableConventionState";
private const string NullableAttributeFullName = "System.Runtime.CompilerServices.NullableAttribute";
- private Type _nullableAttrType;
- private FieldInfo _nullableFlagsFieldInfo;
+ private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute";
///
/// Creates a new instance of .
@@ -33,40 +38,118 @@ protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDepend
///
protected virtual ProviderConventionSetBuilderDependencies Dependencies { get; }
+ private byte? GetNullabilityContextFlag(NonNullabilityConventionState state, Attribute[] attributes)
+ {
+ if (attributes.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute attribute)
+ {
+ var attributeType = attribute.GetType();
+
+ if (attributeType != state.NullableContextAttrType)
+ {
+ state.NullableContextFlagFieldInfo = attributeType.GetField("Flag");
+ state.NullableContextAttrType = attributeType;
+ }
+
+ if (state.NullableContextFlagFieldInfo?.GetValue(attribute) is byte flag)
+ {
+ return flag;
+ }
+ }
+
+ return null;
+ }
+
///
/// Returns a value indicating whether the member type is a non-nullable reference type.
///
+ /// The model builder used to build the model.
/// The member info.
/// true if the member type is a non-nullable reference type.
- protected virtual bool IsNonNullable([NotNull] MemberInfo memberInfo)
+ protected virtual bool IsNonNullable(
+ [NotNull] IConventionModelBuilder modelBuilder,
+ [NotNull] MemberInfo memberInfo)
{
+ var state = GetOrInitializeState(modelBuilder);
+
// For C# 8.0 nullable types, the C# currently synthesizes a NullableAttribute that expresses nullability into assemblies
// it produces. If the model is spread across more than one assembly, there will be multiple versions of this attribute,
// so look for it by name, caching to avoid reflection on every check.
// Note that this may change - if https://github.com/dotnet/corefx/issues/36222 is done we can remove all of this.
- if (!(Attribute.GetCustomAttributes(memberInfo, true)
- .FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName)
- is { } attribute))
+
+ // First look for NullableAttribute on the member itself
+ if (Attribute.GetCustomAttributes(memberInfo, true)
+ .FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute)
+ {
+ var attributeType = attribute.GetType();
+
+ if (attributeType != state.NullableAttrType)
+ {
+ state.NullableFlagsFieldInfo = attributeType.GetField("NullableFlags");
+ state.NullableAttrType = attributeType;
+ }
+
+ if (state.NullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags
+ && flags.FirstOrDefault() == 1)
+ {
+ return true;
+ }
+ }
+
+ // No attribute on the member, try to find a NullableContextAttribute on the declaring type
+ var type = memberInfo.DeclaringType;
+ if (type != null)
{
- return false;
+ if (state.TypeNonNullabilityContextCache.TryGetValue(type, out var cachedTypeNonNullable))
+ {
+ return cachedTypeNonNullable;
+ }
+
+ var typeContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(type));
+ if (typeContextFlag.HasValue)
+ {
+ return state.TypeNonNullabilityContextCache[type] = typeContextFlag.Value == 1;
+ }
}
- var attributeType = attribute.GetType();
- if (attributeType != _nullableAttrType)
+ // Not found at the type level, try at the module level
+ var module = memberInfo.Module;
+ if (!state.ModuleNonNullabilityContextCache.TryGetValue(module, out var moduleNonNullable))
{
- _nullableFlagsFieldInfo = attributeType.GetField("NullableFlags");
- _nullableAttrType = attributeType;
+ var moduleContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(memberInfo.Module));
+ moduleNonNullable = state.ModuleNonNullabilityContextCache[module] =
+ moduleContextFlag.HasValue && moduleContextFlag == 1;
}
- // For the interpretation of NullableFlags, see
- // https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-reference-types.md#annotations
- if (_nullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags
- && flags.FirstOrDefault() == 1)
+ if (type != null)
{
- return true;
+ state.TypeNonNullabilityContextCache[type] = moduleNonNullable;
}
- return false;
+ return moduleNonNullable;
+ }
+
+ private NonNullabilityConventionState GetOrInitializeState(IConventionModelBuilder modelBuilder)
+ => (NonNullabilityConventionState)(
+ modelBuilder.Metadata.FindAnnotation(StateAnnotationName) ??
+ modelBuilder.Metadata.AddAnnotation(StateAnnotationName, new NonNullabilityConventionState())
+ ).Value;
+
+ ///
+ /// Called after a model is finalized. Removes the cached state annotation used by this convention.
+ ///
+ /// The builder for the model.
+ /// Additional information associated with convention execution.
+ public virtual void ProcessModelFinalized(IConventionModelBuilder modelBuilder, IConventionContext context)
+ => modelBuilder.Metadata.RemoveAnnotation(StateAnnotationName);
+
+ private class NonNullabilityConventionState
+ {
+ public Type NullableAttrType;
+ public Type NullableContextAttrType;
+ public FieldInfo NullableFlagsFieldInfo;
+ public FieldInfo NullableContextFlagFieldInfo;
+ public Dictionary TypeNonNullabilityContextCache { get; } = new Dictionary();
+ public Dictionary ModuleNonNullabilityContextCache { get; } = new Dictionary();
}
}
}
diff --git a/src/EFCore/Metadata/Conventions/NonNullableNavigationConvention.cs b/src/EFCore/Metadata/Conventions/NonNullableNavigationConvention.cs
index b3c90579715..7e104bbbb33 100644
--- a/src/EFCore/Metadata/Conventions/NonNullableNavigationConvention.cs
+++ b/src/EFCore/Metadata/Conventions/NonNullableNavigationConvention.cs
@@ -40,8 +40,9 @@ public virtual void ProcessNavigationAdded(
Check.NotNull(relationshipBuilder, nameof(relationshipBuilder));
Check.NotNull(navigation, nameof(navigation));
- if (!IsNonNullable(navigation)
- || navigation.IsCollection())
+ var modelBuilder = relationshipBuilder.ModelBuilder;
+
+ if (!IsNonNullable(modelBuilder, navigation) || navigation.IsCollection())
{
return;
}
@@ -51,7 +52,7 @@ public virtual void ProcessNavigationAdded(
var inverse = navigation.FindInverse();
if (inverse != null)
{
- if (IsNonNullable(inverse))
+ if (IsNonNullable(modelBuilder, inverse))
{
Dependencies.Logger.NonNullableReferenceOnBothNavigations(navigation, inverse);
return;
@@ -82,9 +83,9 @@ public virtual void ProcessNavigationAdded(
context.StopProcessingIfChanged(relationshipBuilder.Metadata.DependentToPrincipal);
}
- private bool IsNonNullable(IConventionNavigation navigation)
+ private bool IsNonNullable(IConventionModelBuilder modelBuilder, IConventionNavigation navigation)
=> navigation.DeclaringEntityType.HasClrType()
&& navigation.DeclaringEntityType.GetRuntimeProperties().Find(navigation.Name) is PropertyInfo propertyInfo
- && IsNonNullable(propertyInfo);
+ && IsNonNullable(modelBuilder, propertyInfo);
}
}
diff --git a/src/EFCore/Metadata/Conventions/NonNullableReferencePropertyConvention.cs b/src/EFCore/Metadata/Conventions/NonNullableReferencePropertyConvention.cs
index aa364b5d811..8fc712e8682 100644
--- a/src/EFCore/Metadata/Conventions/NonNullableReferencePropertyConvention.cs
+++ b/src/EFCore/Metadata/Conventions/NonNullableReferencePropertyConvention.cs
@@ -29,7 +29,7 @@ private void Process(IConventionPropertyBuilder propertyBuilder)
// If the model is spread across multiple assemblies, it may contain different NullableAttribute types as
// the compiler synthesizes them for each assembly.
if (propertyBuilder.Metadata.GetIdentifyingMemberInfo() is MemberInfo memberInfo
- && IsNonNullable(memberInfo))
+ && IsNonNullable(propertyBuilder.ModelBuilder, memberInfo))
{
propertyBuilder.IsRequired(true);
}