Skip to content

Commit

Permalink
Update nullability convention to new nullability metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
roji authored and dougbu committed Jul 2, 2019
1 parent 234d36a commit 0c4d97e
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ public virtual ConventionSet CreateConventionSet()
var databaseGeneratedAttributeConvention = new DatabaseGeneratedAttributeConvention(Dependencies);
var requiredPropertyAttributeConvention = new RequiredPropertyAttributeConvention(Dependencies);
var nonNullableReferencePropertyConvention = new NonNullableReferencePropertyConvention(Dependencies);
var nonNullableNavigationConvention = new NonNullableNavigationConvention(Dependencies);
var maxLengthAttributeConvention = new MaxLengthAttributeConvention(Dependencies);
var stringLengthAttributeConvention = new StringLengthAttributeConvention(Dependencies);
var timestampAttributeConvention = new TimestampAttributeConvention(Dependencies);
Expand Down Expand Up @@ -171,11 +172,14 @@ public virtual ConventionSet CreateConventionSet()
conventionSet.ModelFinalizedConventions.Add(foreignKeyIndexConvention);
conventionSet.ModelFinalizedConventions.Add(foreignKeyPropertyDiscoveryConvention);
conventionSet.ModelFinalizedConventions.Add(servicePropertyDiscoveryConvention);
conventionSet.ModelFinalizedConventions.Add(nonNullableReferencePropertyConvention);
conventionSet.ModelFinalizedConventions.Add(nonNullableNavigationConvention);
conventionSet.ModelFinalizedConventions.Add(new ValidatingConvention(Dependencies));
// Don't add any more conventions to ModelFinalizedConventions after ValidatingConvention

conventionSet.NavigationAddedConventions.Add(backingFieldConvention);
conventionSet.NavigationAddedConventions.Add(new RequiredNavigationAttributeConvention(Dependencies));
conventionSet.NavigationAddedConventions.Add(new NonNullableNavigationConvention(Dependencies));
conventionSet.NavigationAddedConventions.Add(nonNullableNavigationConvention);
conventionSet.NavigationAddedConventions.Add(inversePropertyAttributeConvention);
conventionSet.NavigationAddedConventions.Add(foreignKeyPropertyDiscoveryConvention);
conventionSet.NavigationAddedConventions.Add(relationshipDiscoveryConvention);
Expand Down
119 changes: 101 additions & 18 deletions src/EFCore/Metadata/Conventions/NonNullableConventionBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
Expand All @@ -13,11 +15,14 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
/// A base type for conventions that configure model aspects based on whether the member type
/// is a non-nullable reference type.
/// </summary>
public abstract class NonNullableConventionBase
public abstract class NonNullableConventionBase : IModelFinalizedConvention
{
// For the interpretation of nullability metadata, see
// https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-metadata.md

private const string StateAnnotationName = "NonNullableConventionState";
private const string NullableAttributeFullName = "System.Runtime.CompilerServices.NullableAttribute";
private Type _nullableAttrType;
private FieldInfo _nullableFlagsFieldInfo;
private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute";

/// <summary>
/// Creates a new instance of <see cref="NonNullableConventionBase" />.
Expand All @@ -33,40 +38,118 @@ protected NonNullableConventionBase([NotNull] ProviderConventionSetBuilderDepend
/// </summary>
protected virtual ProviderConventionSetBuilderDependencies Dependencies { get; }

private byte? GetNullabilityContextFlag(NonNullabilityConventionState state, Attribute[] attributes)
{
if (attributes.FirstOrDefault(a => a.GetType().FullName == NullableContextAttributeFullName) is Attribute attribute)
{
var attributeType = attribute.GetType();

if (attributeType != state.NullableContextAttrType)
{
state.NullableContextFlagFieldInfo = attributeType.GetField("Flag");
state.NullableContextAttrType = attributeType;
}

if (state.NullableContextFlagFieldInfo?.GetValue(attribute) is byte flag)
{
return flag;
}
}

return null;
}

/// <summary>
/// Returns a value indicating whether the member type is a non-nullable reference type.
/// </summary>
/// <param name="modelBuilder"> The model builder used to build the model. </param>
/// <param name="memberInfo"> The member info. </param>
/// <returns> <c>true</c> if the member type is a non-nullable reference type. </returns>
protected virtual bool IsNonNullable([NotNull] MemberInfo memberInfo)
protected virtual bool IsNonNullable(
[NotNull] IConventionModelBuilder modelBuilder,
[NotNull] MemberInfo memberInfo)
{
var state = GetOrInitializeState(modelBuilder);

// For C# 8.0 nullable types, the C# currently synthesizes a NullableAttribute that expresses nullability into assemblies
// it produces. If the model is spread across more than one assembly, there will be multiple versions of this attribute,
// so look for it by name, caching to avoid reflection on every check.
// Note that this may change - if https://github.com/dotnet/corefx/issues/36222 is done we can remove all of this.
if (!(Attribute.GetCustomAttributes(memberInfo, true)
.FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName)
is { } attribute))

// First look for NullableAttribute on the member itself
if (Attribute.GetCustomAttributes(memberInfo, true)
.FirstOrDefault(a => a.GetType().FullName == NullableAttributeFullName) is Attribute attribute)
{
var attributeType = attribute.GetType();

if (attributeType != state.NullableAttrType)
{
state.NullableFlagsFieldInfo = attributeType.GetField("NullableFlags");
state.NullableAttrType = attributeType;
}

if (state.NullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags
&& flags.FirstOrDefault() == 1)
{
return true;
}
}

// No attribute on the member, try to find a NullableContextAttribute on the declaring type
var type = memberInfo.DeclaringType;
if (type != null)
{
return false;
if (state.TypeNonNullabilityContextCache.TryGetValue(type, out var cachedTypeNonNullable))
{
return cachedTypeNonNullable;
}

var typeContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(type));
if (typeContextFlag.HasValue)
{
return state.TypeNonNullabilityContextCache[type] = typeContextFlag.Value == 1;
}
}

var attributeType = attribute.GetType();
if (attributeType != _nullableAttrType)
// Not found at the type level, try at the module level
var module = memberInfo.Module;
if (!state.ModuleNonNullabilityContextCache.TryGetValue(module, out var moduleNonNullable))
{
_nullableFlagsFieldInfo = attributeType.GetField("NullableFlags");
_nullableAttrType = attributeType;
var moduleContextFlag = GetNullabilityContextFlag(state, Attribute.GetCustomAttributes(memberInfo.Module));
moduleNonNullable = state.ModuleNonNullabilityContextCache[module] =
moduleContextFlag.HasValue && moduleContextFlag == 1;
}

// For the interpretation of NullableFlags, see
// https://github.com/dotnet/roslyn/blob/master/docs/features/nullable-reference-types.md#annotations
if (_nullableFlagsFieldInfo?.GetValue(attribute) is byte[] flags
&& flags.FirstOrDefault() == 1)
if (type != null)
{
return true;
state.TypeNonNullabilityContextCache[type] = moduleNonNullable;
}

return false;
return moduleNonNullable;
}

private NonNullabilityConventionState GetOrInitializeState(IConventionModelBuilder modelBuilder)
=> (NonNullabilityConventionState)(
modelBuilder.Metadata.FindAnnotation(StateAnnotationName) ??
modelBuilder.Metadata.AddAnnotation(StateAnnotationName, new NonNullabilityConventionState())
).Value;

/// <summary>
/// Called after a model is finalized. Removes the cached state annotation used by this convention.
/// </summary>
/// <param name="modelBuilder"> The builder for the model. </param>
/// <param name="context"> Additional information associated with convention execution. </param>
public virtual void ProcessModelFinalized(IConventionModelBuilder modelBuilder, IConventionContext<IConventionModelBuilder> context)
=> modelBuilder.Metadata.RemoveAnnotation(StateAnnotationName);

private class NonNullabilityConventionState
{
public Type NullableAttrType;
public Type NullableContextAttrType;
public FieldInfo NullableFlagsFieldInfo;
public FieldInfo NullableContextFlagFieldInfo;
public Dictionary<Type, bool> TypeNonNullabilityContextCache { get; } = new Dictionary<Type, bool>();
public Dictionary<Module, bool> ModuleNonNullabilityContextCache { get; } = new Dictionary<Module, bool>();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ public virtual void ProcessNavigationAdded(
Check.NotNull(relationshipBuilder, nameof(relationshipBuilder));
Check.NotNull(navigation, nameof(navigation));

if (!IsNonNullable(navigation)
|| navigation.IsCollection())
var modelBuilder = relationshipBuilder.ModelBuilder;

if (!IsNonNullable(modelBuilder, navigation) || navigation.IsCollection())
{
return;
}
Expand All @@ -51,7 +52,7 @@ public virtual void ProcessNavigationAdded(
var inverse = navigation.FindInverse();
if (inverse != null)
{
if (IsNonNullable(inverse))
if (IsNonNullable(modelBuilder, inverse))
{
Dependencies.Logger.NonNullableReferenceOnBothNavigations(navigation, inverse);
return;
Expand Down Expand Up @@ -82,9 +83,9 @@ public virtual void ProcessNavigationAdded(
context.StopProcessingIfChanged(relationshipBuilder.Metadata.DependentToPrincipal);
}

private bool IsNonNullable(IConventionNavigation navigation)
private bool IsNonNullable(IConventionModelBuilder modelBuilder, IConventionNavigation navigation)
=> navigation.DeclaringEntityType.HasClrType()
&& navigation.DeclaringEntityType.GetRuntimeProperties().Find(navigation.Name) is PropertyInfo propertyInfo
&& IsNonNullable(propertyInfo);
&& IsNonNullable(modelBuilder, propertyInfo);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private void Process(IConventionPropertyBuilder propertyBuilder)
// If the model is spread across multiple assemblies, it may contain different NullableAttribute types as
// the compiler synthesizes them for each assembly.
if (propertyBuilder.Metadata.GetIdentifyingMemberInfo() is MemberInfo memberInfo
&& IsNonNullable(memberInfo))
&& IsNonNullable(propertyBuilder.ModelBuilder, memberInfo))
{
propertyBuilder.IsRequired(true);
}
Expand Down

0 comments on commit 0c4d97e

Please sign in to comment.