Skip to content

Commit

Permalink
Reference Include for in-memory provider
Browse files Browse the repository at this point in the history
Part of #16963
  • Loading branch information
ajcvickers committed Aug 7, 2019
1 parent fb61241 commit 06461a1
Show file tree
Hide file tree
Showing 8 changed files with 764 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public partial class InMemoryShapedQueryCompilingExpressionVisitor
{
private class CustomShaperCompilingExpressionVisitor : ExpressionVisitor
{
private readonly bool _tracking;

public CustomShaperCompilingExpressionVisitor(bool tracking)
{
_tracking = tracking;
}

private static readonly MethodInfo _includeReferenceMethodInfo
= typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(IncludeReference));

private static void IncludeReference<TEntity, TIncludingEntity, TIncludedEntity>(
QueryContext queryContext,
TEntity entity,
TIncludedEntity relatedEntity,
INavigation navigation,
INavigation inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : TEntity
{
if (entity is TIncludingEntity includingEntity)
{
if (trackingQuery)
{
// For non-null relatedEntity StateManager will set the flag
if (relatedEntity == null)
{
queryContext.StateManager.TryGetEntry(includingEntity).SetIsLoaded(navigation);
}
}
else
{
SetIsLoadedNoTracking(includingEntity, navigation);
if (relatedEntity != null)
{
fixup(includingEntity, relatedEntity);
if (inverseNavigation != null
&& !inverseNavigation.IsCollection())
{
SetIsLoadedNoTracking(relatedEntity, inverseNavigation);
}
}
}
}
}

private static void SetIsLoadedNoTracking(object entity, INavigation navigation)
=> ((ILazyLoader)(navigation
.DeclaringEntityType
.GetServiceProperties()
.FirstOrDefault(p => p.ClrType == typeof(ILazyLoader)))
?.GetGetter().GetClrValue(entity))
?.SetLoaded(entity, navigation.Name);

protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is IncludeExpression includeExpression)
{
Debug.Assert(
!includeExpression.Navigation.IsCollection(),
"Only reference include should be present in tree");

var entityClrType = includeExpression.EntityExpression.Type;
var includingClrType = includeExpression.Navigation.DeclaringEntityType.ClrType;
var inverseNavigation = includeExpression.Navigation.FindInverse();
var relatedEntityClrType = includeExpression.Navigation.GetTargetType().ClrType;
if (includingClrType != entityClrType
&& includingClrType.IsAssignableFrom(entityClrType))
{
includingClrType = entityClrType;
}

return Expression.Call(
_includeReferenceMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType),
QueryCompilationContext.QueryContextParameter,
// We don't need to visit entityExpression since it is supposed to be a parameterExpression only
includeExpression.EntityExpression,
includeExpression.NavigationExpression,
Expression.Constant(includeExpression.Navigation),
Expression.Constant(inverseNavigation, typeof(INavigation)),
Expression.Constant(
GenerateFixup(
includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()),
Expression.Constant(_tracking));
}

return base.VisitExtension(extensionExpression);
}

private static LambdaExpression GenerateFixup(
Type entityType,
Type relatedEntityType,
INavigation navigation,
INavigation inverseNavigation)
{
var entityParameter = Expression.Parameter(entityType);
var relatedEntityParameter = Expression.Parameter(relatedEntityType);
var expressions = new List<Expression>
{
navigation.IsCollection()
? AddToCollectionNavigation(entityParameter, relatedEntityParameter, navigation)
: AssignReferenceNavigation(entityParameter, relatedEntityParameter, navigation)
};

if (inverseNavigation != null)
{
expressions.Add(
inverseNavigation.IsCollection()
? AddToCollectionNavigation(relatedEntityParameter, entityParameter, inverseNavigation)
: AssignReferenceNavigation(relatedEntityParameter, entityParameter, inverseNavigation));

}

return Expression.Lambda(Expression.Block(typeof(void), expressions), entityParameter, relatedEntityParameter);
}

private static Expression AssignReferenceNavigation(
ParameterExpression entity,
ParameterExpression relatedEntity,
INavigation navigation)
{
return entity.MakeMemberAccess(navigation.GetMemberInfo(forMaterialization: true, forSet: true)).Assign(relatedEntity);
}

private static Expression AddToCollectionNavigation(
ParameterExpression entity,
ParameterExpression relatedEntity,
INavigation navigation)
=> Expression.Call(
Expression.Constant(navigation.GetCollectionAccessor()),
_collectionAccessorAddMethodInfo,
entity,
relatedEntity,
Expression.Constant(true));

private static readonly MethodInfo _collectionAccessorAddMethodInfo
= typeof(IClrCollectionAccessor).GetTypeInfo()
.GetDeclaredMethod(nameof(IClrCollectionAccessor.Add));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ protected override Expression VisitExtension(Expression extensionExpression)
}
}

if (extensionExpression is IncludeExpression includeExpression)
{
return _clientEval
? base.VisitExtension(includeExpression)
: null;
}

throw new InvalidOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class InMemoryShapedQueryCompilingExpressionVisitor : ShapedQueryCompilingExpressionVisitor
public partial class InMemoryShapedQueryCompilingExpressionVisitor : ShapedQueryCompilingExpressionVisitor
{
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;
private static readonly ConstructorInfo _valueBufferConstructor
= typeof(ValueBuffer).GetConstructors().Single(ci => ci.GetParameters().Length == 1);

public InMemoryShapedQueryCompilingExpressionVisitor(
QueryCompilationContext queryCompilationContext,
Expand Down Expand Up @@ -53,19 +51,20 @@ protected override Expression VisitExtension(Expression extensionExpression)

protected override Expression VisitShapedQueryExpression(ShapedQueryExpression shapedQueryExpression)
{
var shaperBody = InjectEntityMaterializers(shapedQueryExpression.ShaperExpression);
var inMemoryQueryExpression = (InMemoryQueryExpression)shapedQueryExpression.QueryExpression;

var innerEnumerable = Visit(shapedQueryExpression.QueryExpression);
var shaper = new ShaperExpressionProcessingExpressionVisitor(inMemoryQueryExpression)
.Inject(shapedQueryExpression.ShaperExpression);

var inMemoryQueryExpression = (InMemoryQueryExpression)shapedQueryExpression.QueryExpression;
shaper = InjectEntityMaterializers(shaper);

var newBody = new InMemoryProjectionBindingRemovingExpressionVisitor(inMemoryQueryExpression)
.Visit(shaperBody);
var innerEnumerable = Visit(inMemoryQueryExpression);

var shaperLambda = Expression.Lambda(
newBody,
QueryCompilationContext.QueryContextParameter,
inMemoryQueryExpression.ValueBufferParameter);
shaper = new InMemoryProjectionBindingRemovingExpressionVisitor(inMemoryQueryExpression).Visit(shaper);

shaper = new CustomShaperCompilingExpressionVisitor(IsTracking).Visit(shaper);

var shaperLambda = (LambdaExpression)shaper;

return Expression.New(
(IsAsync
Expand Down Expand Up @@ -263,6 +262,7 @@ public ValueTask DisposeAsync()
private class InMemoryProjectionBindingRemovingExpressionVisitor : ExpressionVisitor
{
private readonly InMemoryQueryExpression _queryExpression;

private readonly IDictionary<ParameterExpression, IDictionary<IProperty, int>> _materializationContextBindings
= new Dictionary<ParameterExpression, IDictionary<IProperty, int>>();

Expand All @@ -284,7 +284,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
_materializationContextBindings[parameterExpression]
= (IDictionary<IProperty, int>)GetProjectionIndex(projectionBindingExpression);

var updatedExpression = Expression.New(newExpression.Constructor,
var updatedExpression = Expression.New(
newExpression.Constructor,
Expression.Constant(ValueBuffer.Empty),
newExpression.Arguments[1]);

Expand All @@ -300,7 +301,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod)
{
var property = (IProperty)((ConstantExpression)methodCallExpression.Arguments[2]).Value;
var indexMap = _materializationContextBindings[(ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object];
var indexMap =
_materializationContextBindings[
(ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object];

return Expression.Call(
methodCallExpression.Method,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class ShaperExpressionProcessingExpressionVisitor : ExpressionVisitor
{
private readonly InMemoryQueryExpression _queryExpression;

private readonly IDictionary<Expression, ParameterExpression> _mapping = new Dictionary<Expression, ParameterExpression>();
private readonly List<ParameterExpression> _variables = new List<ParameterExpression>();
private readonly List<Expression> _expressions = new List<Expression>();

public ShaperExpressionProcessingExpressionVisitor(
InMemoryQueryExpression queryExpression)
{
_queryExpression = queryExpression;
}

public virtual Expression Inject(Expression expression)
{
var result = Visit(expression);

if (_expressions.All(e => e.NodeType == ExpressionType.Assign))
{
result = new ReplacingExpressionVisitor(_expressions.Cast<BinaryExpression>()
.ToDictionary(e => e.Left, e => e.Right)).Visit(result);
}
else
{
_expressions.Add(result);
result = Expression.Block(_variables, _expressions);
}

return ConvertToLambda(result, Expression.Parameter(result.Type, "result"));
}

private LambdaExpression ConvertToLambda(Expression result, ParameterExpression resultParameter)
=> Expression.Lambda(
result,
QueryCompilationContext.QueryContextParameter,
_queryExpression.ValueBufferParameter);

protected override Expression VisitExtension(Expression extensionExpression)
{
switch (extensionExpression)
{
case EntityShaperExpression entityShaperExpression:
{
var key = GenerateKey((ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression);
if (!_mapping.TryGetValue(key, out var variable))
{
variable = Expression.Parameter(entityShaperExpression.EntityType.ClrType);
_variables.Add(variable);
_expressions.Add(Expression.Assign(variable, entityShaperExpression));
_mapping[key] = variable;
}

return variable;
}

case ProjectionBindingExpression projectionBindingExpression:
{
var key = GenerateKey(projectionBindingExpression);
if (!_mapping.TryGetValue(key, out var variable))
{
variable = Expression.Parameter(projectionBindingExpression.Type);
_variables.Add(variable);
_expressions.Add(Expression.Assign(variable, projectionBindingExpression));
_mapping[key] = variable;
}

return variable;
}

case IncludeExpression includeExpression:
{
var entity = Visit(includeExpression.EntityExpression);
_expressions.Add(
includeExpression.Update(
entity,
Visit(includeExpression.NavigationExpression)));

return entity;
}
}

return base.VisitExtension(extensionExpression);
}

private Expression GenerateKey(ProjectionBindingExpression projectionBindingExpression)
=> projectionBindingExpression.ProjectionMember != null
? _queryExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember)
: projectionBindingExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ public class InMemoryComplianceTest : ComplianceTestBase
typeof(ComplexNavigationsQueryTestBase<>),
typeof(GearsOfWarQueryTestBase<>),
typeof(IncludeAsyncTestBase<>),
typeof(IncludeOneToOneTestBase<>),
typeof(IncludeTestBase<>),
typeof(InheritanceRelationshipsQueryTestBase<>),
typeof(InheritanceTestBase<>),
typeof(NullKeysTestBase<>),
Expand Down
Loading

0 comments on commit 06461a1

Please sign in to comment.