Skip to content

Commit

Permalink
Stop marking the inverse navigation as loaded when loading a many-to-…
Browse files Browse the repository at this point in the history
…many navigation

Fixes #23475

This is a targeted patch fix which special cases many-to-many loading. I will file an issue for a more general solution using an appropriate general update to query.
  • Loading branch information
ajcvickers committed Dec 11, 2020
1 parent a87b7a2 commit b706ce6
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ private void AddInclude(
Expression.Constant(navigation),
Expression.Constant(inverseNavigation, typeof(INavigation)),
Expression.Constant(fixup),
Expression.Constant(initialize, typeof(Action<>).MakeGenericType(includingClrType))));
Expression.Constant(initialize, typeof(Action<>).MakeGenericType(includingClrType)),
Expression.Constant(includeExpression.SetLoaded)));
}

private static readonly MethodInfo _includeReferenceMethodInfo
Expand All @@ -409,7 +410,8 @@ private static void IncludeReference<TIncludingEntity, TIncludedEntity>(
INavigation navigation,
INavigation inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
Action<TIncludingEntity> _)
Action<TIncludingEntity> _,
bool __)
{
if (entity == null
|| !navigation.DeclaringEntityType.IsAssignableFrom(entityType))
Expand Down Expand Up @@ -454,7 +456,8 @@ private static void IncludeCollection<TIncludingEntity, TIncludedEntity>(
INavigation navigation,
INavigation inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
Action<TIncludingEntity> initialize)
Action<TIncludingEntity> initialize,
bool setLoaded)
{
if (entity == null
|| !navigation.DeclaringEntityType.IsAssignableFrom(entityType))
Expand Down Expand Up @@ -485,9 +488,13 @@ private static void IncludeCollection<TIncludingEntity, TIncludedEntity>(
}
else
{
if (setLoaded)
{
#pragma warning disable EF1001 // Internal EF Core API usage.
entry.SetIsLoaded(navigation);
entry.SetIsLoaded(navigation);
#pragma warning restore EF1001 // Internal EF Core API usage.
}

if (relatedEntities != null)
{
using var enumerator = relatedEntities.GetEnumerator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ private static void IncludeCollection<TEntity, TIncludingEntity, TIncludedEntity
INavigationBase navigation,
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
bool trackingQuery,
bool setLoaded)
where TIncludingEntity : class, TEntity
where TEntity : class
where TIncludedEntity : class
Expand All @@ -97,13 +98,16 @@ private static void IncludeCollection<TEntity, TIncludingEntity, TIncludedEntity
var collectionAccessor = navigation.GetCollectionAccessor();
collectionAccessor.GetOrCreate(includingEntity, forMaterialization: true);

if (trackingQuery)
{
queryContext.SetNavigationIsLoaded(entity, navigation);
}
else
if (setLoaded)
{
navigation.SetIsLoadedWhenNoTracking(entity);
if (trackingQuery)
{
queryContext.SetNavigationIsLoaded(entity, navigation);
}
else
{
navigation.SetIsLoadedWhenNoTracking(entity);
}
}

foreach (var valueBuffer in innerValueBuffers)
Expand Down Expand Up @@ -178,7 +182,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
Expression.Constant(
GenerateFixup(
includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()),
Expression.Constant(_tracking));
Expression.Constant(_tracking),
Expression.Constant(includeExpression.SetLoaded));
}

return Expression.Call(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
Expression.Constant(outerIdentifierLambda.Compile()),
Expression.Constant(navigation),
Expression.Constant(navigation.GetCollectionAccessor()),
Expression.Constant(_isTracking)));
Expression.Constant(_isTracking),
Expression.Constant(includeExpression.SetLoaded)));

var relatedEntityType = innerShaper.ReturnType;
var inverseNavigation = navigation.Inverse;
Expand Down Expand Up @@ -629,7 +630,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
Expression.Constant(parentIdentifierLambda.Compile()),
Expression.Constant(navigation),
Expression.Constant(navigation.GetCollectionAccessor()),
Expression.Constant(_isTracking)));
Expression.Constant(_isTracking),
Expression.Constant(includeExpression.SetLoaded)));

var relatedEntityType = innerShaper.ReturnType;
var inverseNavigation = navigation.Inverse;
Expand Down Expand Up @@ -1207,20 +1209,24 @@ private static void InitializeIncludeCollection<TParent, TNavigationEntity>(
Func<QueryContext, DbDataReader, object[]> outerIdentifier,
INavigationBase navigation,
IClrCollectionAccessor clrCollectionAccessor,
bool trackingQuery)
bool trackingQuery,
bool setLoaded)
where TParent : class
where TNavigationEntity : class, TParent
{
object collection = null;
if (entity is TNavigationEntity)
{
if (trackingQuery)
{
queryContext.SetNavigationIsLoaded(entity, navigation);
}
else
if (setLoaded)
{
navigation.SetIsLoadedWhenNoTracking(entity);
if (trackingQuery)
{
queryContext.SetNavigationIsLoaded(entity, navigation);
}
else
{
navigation.SetIsLoadedWhenNoTracking(entity);
}
}

collection = clrCollectionAccessor.GetOrCreate(entity, forMaterialization: true);
Expand Down Expand Up @@ -1361,20 +1367,24 @@ private static void InitializeSplitIncludeCollection<TParent, TNavigationEntity>
Func<QueryContext, DbDataReader, object[]> parentIdentifier,
INavigationBase navigation,
IClrCollectionAccessor clrCollectionAccessor,
bool trackingQuery)
bool trackingQuery,
bool setLoaded)
where TParent : class
where TNavigationEntity : class, TParent
{
object collection = null;
if (entity is TNavigationEntity)
{
if (trackingQuery)
{
queryContext.SetNavigationIsLoaded(entity, navigation);
}
else
if (setLoaded)
{
navigation.SetIsLoadedWhenNoTracking(entity);
if (trackingQuery)
{
queryContext.SetNavigationIsLoaded(entity, navigation);
}
else
{
navigation.SetIsLoadedWhenNoTracking(entity);
}
}

collection = clrCollectionAccessor.GetOrCreate(entity, forMaterialization: true);
Expand Down
25 changes: 25 additions & 0 deletions src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2376,6 +2376,15 @@ internal static readonly MethodInfo IncludeMethodInfo
&& mi.GetParameters().Any(
pi => pi.Name == "navigationPropertyPath" && pi.ParameterType != typeof(string)));

internal static readonly MethodInfo NotQuiteIncludeMethodInfo
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(NotQuiteInclude))
.Single(
mi =>
mi.GetGenericArguments().Count() == 2
&& mi.GetParameters().Any(
pi => pi.Name == "navigationPropertyPath" && pi.ParameterType != typeof(string)));

/// <summary>
/// Specifies related entities to include in the query results. The navigation property to be included is specified starting with the
/// type of entity being queried (<typeparamref name="TEntity" />). If you wish to include additional types based on the navigation
Expand Down Expand Up @@ -2443,6 +2452,22 @@ source.Provider is EntityQueryProvider
: source);
}

// A version of Include that doesn't set the navigation as loaded
internal static IIncludableQueryable<TEntity, TProperty> NotQuiteInclude<TEntity, TProperty>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] Expression<Func<TEntity, TProperty>> navigationPropertyPath)
where TEntity : class
{
return new IncludableQueryable<TEntity, TProperty>(
source.Provider is EntityQueryProvider
? source.Provider.CreateQuery<TEntity>(
Expression.Call(
instance: null,
method: NotQuiteIncludeMethodInfo.MakeGenericMethod(typeof(TEntity), typeof(TProperty)),
arguments: new[] { source.Expression, Expression.Quote(navigationPropertyPath) }))
: source);
}

internal static readonly MethodInfo ThenIncludeAfterEnumerableMethodInfo
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(ThenInclude))
Expand Down
4 changes: 2 additions & 2 deletions src/EFCore/Internal/ManyToManyLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ private IQueryable<TEntity> Query(
// .AsTracking()
// .Where(e => e.Id == left.Id)
// .SelectMany(e => e.TwoSkip)
// .Include(e => e.OneSkip.Where(e => e.Id == left.Id));
// .NotQuiteInclude(e => e.OneSkip.Where(e => e.Id == left.Id));

var queryRoot = _skipNavigation.DeclaringEntityType.HasSharedClrType
? context.Set<TSourceEntity>(_skipNavigation.DeclaringEntityType.Name)
Expand All @@ -157,7 +157,7 @@ private IQueryable<TEntity> Query(
.AsTracking()
.Where(BuildWhereLambda(loadProperties, new ValueBuffer(keyValues)))
.SelectMany(BuildSelectManyLambda(_skipNavigation))
.Include(BuildIncludeLambda(_skipNavigation.Inverse, loadProperties, new ValueBuffer(keyValues)))
.NotQuiteInclude(BuildIncludeLambda(_skipNavigation.Inverse, loadProperties, new ValueBuffer(keyValues)))
.AsQueryable();
}

Expand Down
27 changes: 25 additions & 2 deletions src/EFCore/Query/IncludeExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ namespace Microsoft.EntityFrameworkCore.Query
public class IncludeExpression : Expression, IPrintableExpression
{
/// <summary>
/// Creates a new instance of the <see cref="IncludeExpression" /> class.
/// Creates a new instance of the <see cref="IncludeExpression" /> class. The navigation will be set
/// as loaded after completing the Include.
/// </summary>
/// <param name="entityExpression"> An expression to get entity which is performing include. </param>
/// <param name="navigationExpression"> An expression to get included navigation element. </param>
Expand All @@ -30,6 +31,22 @@ public IncludeExpression(
[NotNull] Expression entityExpression,
[NotNull] Expression navigationExpression,
[NotNull] INavigationBase navigation)
: this(entityExpression, navigationExpression, navigation, setLoaded: true)
{
}

/// <summary>
/// Creates a new instance of the <see cref="IncludeExpression" /> class.
/// </summary>
/// <param name="entityExpression"> An expression to get entity which is performing include. </param>
/// <param name="navigationExpression"> An expression to get included navigation element. </param>
/// <param name="navigation"> The navigation for this include operation. </param>
/// <param name="setLoaded"> True if the navigation will be marked as loaded. </param>
public IncludeExpression(
[NotNull] Expression entityExpression,
[NotNull] Expression navigationExpression,
[NotNull] INavigationBase navigation,
bool setLoaded)
{
Check.NotNull(entityExpression, nameof(entityExpression));
Check.NotNull(navigationExpression, nameof(navigationExpression));
Expand All @@ -39,6 +56,7 @@ public IncludeExpression(
NavigationExpression = navigationExpression;
Navigation = navigation;
Type = EntityExpression.Type;
SetLoaded = setLoaded;
}

/// <summary>
Expand All @@ -56,6 +74,11 @@ public IncludeExpression(
/// </summary>
public virtual INavigationBase Navigation { get; }

/// <summary>
/// True if the navigation will be marked as loaded.
/// </summary>
public virtual bool SetLoaded { get; }

/// <inheritdoc />
public sealed override ExpressionType NodeType
=> ExpressionType.Extension;
Expand Down Expand Up @@ -87,7 +110,7 @@ public virtual IncludeExpression Update([NotNull] Expression entityExpression, [
Check.NotNull(navigationExpression, nameof(navigationExpression));

return entityExpression != EntityExpression || navigationExpression != NavigationExpression
? new IncludeExpression(entityExpression, navigationExpression, Navigation)
? new IncludeExpression(entityExpression, navigationExpression, Navigation, SetLoaded)
: this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ private Expression ExpandIncludesHelper(Expression root, EntityReference entityR
}
}

result = new IncludeExpression(result, included, navigationBase);
result = new IncludeExpression(result, included, navigationBase, entityReference.SetLoaded);
}

return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public EntityReference(IEntityType entityType)
public bool IsOptional { get; private set; }
public IncludeTreeNode IncludePaths { get; private set; }
public IncludeTreeNode LastIncludeTreeNode { get; private set; }
public bool SetLoaded { get; private set; } = true;

public override ExpressionType NodeType
=> ExpressionType.Extension;
Expand Down Expand Up @@ -57,6 +58,9 @@ public void SetLastInclude(IncludeTreeNode lastIncludeTree)
public void MarkAsOptional()
=> IsOptional = true;

public void SuppressSettingLoaded()
=> SetLoaded = false;

void IPrintableExpression.Print(ExpressionPrinter expressionPrinter)
{
Check.NotNull(expressionPrinter, nameof(expressionPrinter));
Expand Down
26 changes: 22 additions & 4 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,25 @@ when QueryableMethods.IsSumWithSelector(method):
methodCallExpression.Type.TryGetSequenceType());

case nameof(EntityFrameworkQueryableExtensions.Include):
return ProcessInclude(
source,
methodCallExpression.Arguments[1],
thenInclude: false,
setLoaded: true);

case nameof(EntityFrameworkQueryableExtensions.ThenInclude):
return ProcessInclude(
source,
methodCallExpression.Arguments[1],
string.Equals(
method.Name,
nameof(EntityFrameworkQueryableExtensions.ThenInclude)));
thenInclude: true,
setLoaded: true);

case nameof(EntityFrameworkQueryableExtensions.NotQuiteInclude):
return ProcessInclude(
source,
methodCallExpression.Arguments[1],
thenInclude: false,
setLoaded: false);

case nameof(Queryable.GroupBy)
when genericMethod == QueryableMethods.GroupByWithKeySelector:
Expand Down Expand Up @@ -823,7 +835,8 @@ private NavigationExpansionExpression ProcessGroupBy(
return new NavigationExpansionExpression(result, navigationTree, navigationTree, parameterName);
}

private NavigationExpansionExpression ProcessInclude(NavigationExpansionExpression source, Expression expression, bool thenInclude)
private NavigationExpansionExpression ProcessInclude(
NavigationExpansionExpression source, Expression expression, bool thenInclude, bool setLoaded)
{
if (source.PendingSelector is NavigationTreeExpression navigationTree
&& navigationTree.Value is EntityReference entityReference)
Expand Down Expand Up @@ -890,6 +903,11 @@ private NavigationExpansionExpression ProcessInclude(NavigationExpansionExpressi
}

entityReference.SetLastInclude(lastIncludeTree);

if (!setLoaded)
{
entityReference.SuppressSettingLoaded();
}
}

return source;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& method.GetGenericMethodDefinition() is MethodInfo genericMethod
&& (genericMethod == EntityFrameworkQueryableExtensions.IncludeMethodInfo
|| genericMethod == EntityFrameworkQueryableExtensions.ThenIncludeAfterEnumerableMethodInfo
|| genericMethod == EntityFrameworkQueryableExtensions.ThenIncludeAfterReferenceMethodInfo))
|| genericMethod == EntityFrameworkQueryableExtensions.ThenIncludeAfterReferenceMethodInfo
|| genericMethod == EntityFrameworkQueryableExtensions.NotQuiteIncludeMethodInfo))
{
var includeLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
if (includeLambda.ReturnType.IsGenericType
Expand Down
Loading

0 comments on commit b706ce6

Please sign in to comment.