Skip to content

Commit

Permalink
Cosmos: Properly handle nested owned types on owned collections
Browse files Browse the repository at this point in the history
Consolidate Enumerable methods info
Print the declaring type name for static method calls

Fixes #18265
  • Loading branch information
AndriySvyryd committed Oct 21, 2019
1 parent bf62b29 commit f70c3b6
Show file tree
Hide file tree
Showing 29 changed files with 770 additions and 306 deletions.
8 changes: 2 additions & 6 deletions src/EFCore.Cosmos/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal
{
Expand All @@ -16,11 +17,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal
/// </summary>
public class ContainsTranslator : IMethodCallTranslator
{
private static readonly MethodInfo _containsMethod = typeof(Enumerable).GetTypeInfo()
.GetDeclaredMethods(nameof(Enumerable.Contains))
.Single(mi => mi.GetParameters().Length == 2)
.GetGenericMethodDefinition();

private readonly ISqlExpressionFactory _sqlExpressionFactory;

/// <summary>
Expand All @@ -43,7 +39,7 @@ public ContainsTranslator(ISqlExpressionFactory sqlExpressionFactory)
public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList<SqlExpression> arguments)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(_containsMethod))
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains))
{
return _sqlExpressionFactory.In(arguments[1], arguments[0], false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal
Expand All @@ -31,6 +34,12 @@ private readonly IDictionary<ProjectionMember, Expression> _projectionMapping

private readonly Stack<ProjectionMember> _projectionMembers = new Stack<ProjectionMember>();

private readonly IDictionary<ParameterExpression, CollectionShaperExpression> _collectionShaperMapping
= new Dictionary<ParameterExpression, CollectionShaperExpression>();

private readonly Stack<INavigation> _includedNavigations
= new Stack<INavigation>();

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -109,12 +118,17 @@ public override Expression Visit(Expression expression)
return expression;

case ParameterExpression parameterExpression:
if (_collectionShaperMapping.ContainsKey(parameterExpression))
{
return parameterExpression;
}

return Expression.Call(
_getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type),
QueryCompilationContext.QueryContextParameter,
Expression.Constant(parameterExpression.Name));

case MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression:
case MaterializeCollectionNavigationExpression _:
return base.Visit(expression);
//return _selectExpression.AddCollectionProjection(
// _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(
Expand Down Expand Up @@ -269,6 +283,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

break;

case ParameterExpression parameterExpression:
if (!_collectionShaperMapping.TryGetValue(parameterExpression, out var collectionShaper))
{
return null;
}

shaperExpression = (EntityShaperExpression)collectionShaper.InnerShaper;
break;

default:
return null;
}
Expand All @@ -289,13 +312,25 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name));
}

var navigationProjection = innerEntityProjection.BindMember(
memberName, visitedSource.Type, clientEval: true, out var propertyBase);

if (!(propertyBase is INavigation navigation)
|| !navigation.IsEmbedded())
Expression navigationProjection;
var navigation = _includedNavigations.FirstOrDefault(n => n.Name == memberName);
if (navigation == null)
{
return null;
navigationProjection = innerEntityProjection.BindMember(
memberName, visitedSource.Type, clientEval: true, out var propertyBase);

if (!(propertyBase is INavigation projectedNavigation)
|| !projectedNavigation.IsEmbedded())
{
return null;
}

navigation = projectedNavigation;
}
else
{
navigationProjection = innerEntityProjection.BindNavigation(navigation, clientEval: true);
}

switch (navigationProjection)
Expand Down Expand Up @@ -326,6 +361,49 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

if (_clientEval)
{
var method = methodCallExpression.Method;
if (method.DeclaringType == typeof(Queryable)
|| method.DeclaringType == typeof(QueryableExtensions))
{
var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null;
var visitedSource = Visit(methodCallExpression.Arguments[0]);

switch (method.Name)
{
case nameof(Queryable.AsQueryable)
when genericMethod == QueryableMethods.AsQueryable:
// Unwrap AsQueryable
return visitedSource;

case nameof(Queryable.Select)
when genericMethod == QueryableMethods.Select:
var shaper = visitedSource as CollectionShaperExpression;

LambdaExpression lambda = null;
switch (methodCallExpression.Arguments[1]) {
case LambdaExpression lambdaExpression:
lambda = lambdaExpression;
break;
case UnaryExpression unaryExpression:
lambda = unaryExpression.Operand as LambdaExpression;
break;
}

_collectionShaperMapping.Add(lambda.Parameters.Single(), shaper);

var genericArguments = method.GetGenericArguments();
return Expression.Call(
EnumerableMethods.Select.MakeGenericMethod(
genericArguments[0],
genericArguments[1]),
shaper,
Visit(lambda));
}
}
}

return base.VisitMethodCall(methodCallExpression);
}

Expand Down Expand Up @@ -362,26 +440,29 @@ protected override Expression VisitExtension(Expression extensionExpression)
}

case MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression:
if (materializeCollectionNavigationExpression.Navigation.IsEmbedded())
{
var subquery = materializeCollectionNavigationExpression.Subquery;
// Unwrap AsQueryable around the subquery if present
if (subquery is MethodCallExpression innerMethodCall
&& innerMethodCall.Method.IsGenericMethod
&& innerMethodCall.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable)
{
subquery = innerMethodCall.Arguments[0];
}
return materializeCollectionNavigationExpression.Navigation.IsEmbedded()
? base.Visit(materializeCollectionNavigationExpression.Subquery)
: base.VisitExtension(materializeCollectionNavigationExpression);

return base.Visit(subquery);
case IncludeExpression includeExpression:
if (!_clientEval)
{
return null;
}
else

if (!includeExpression.Navigation.IsEmbedded())
{
return base.VisitExtension(materializeCollectionNavigationExpression);
throw new InvalidOperationException(
"Non-embedded IncludeExpression is not supported: " + includeExpression.Print());
}

case IncludeExpression includeExpression:
return _clientEval ? base.VisitExtension(includeExpression) : null;
_includedNavigations.Push(includeExpression.Navigation);

var newIncludeExpression = base.VisitExtension(includeExpression);

_includedNavigations.Pop();

return newIncludeExpression;

default:
throw new InvalidOperationException(CoreStrings.QueryFailed(extensionExpression.Print(), GetType().Name));
Expand Down
Loading

0 comments on commit f70c3b6

Please sign in to comment.