Skip to content

Commit

Permalink
Merge pull request #16831 from aspnet/ProjectedCollections
Browse files Browse the repository at this point in the history
[preview8] Handle projected collections.
  • Loading branch information
ajcvickers authored Jul 30, 2019
2 parents 76cacc2 + 14587ca commit 7d01acc
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
_currentEntityIndex++;

var resultType = typeof(IEnumerable<>).MakeGenericType(collectionShaperExpression.ElementType);

var jArrayVariable = Expression.Variable(
typeof(JArray),
"jArray" + _currentEntityIndex);
Expand All @@ -156,12 +154,12 @@ protected override Expression VisitExtension(Expression extensionExpression)

Expression.Condition(
Expression.Equal(jArrayVariable, Expression.Constant(null, jArrayVariable.Type)),
Expression.Constant(null, resultType),
Expression.Convert(collectionShaperExpression, resultType))
Expression.Constant(null, collectionShaperExpression.Type),
collectionShaperExpression)
};

return Expression.Block(
resultType,
collectionShaperExpression.Type,
variables,
expressions);
}
Expand Down Expand Up @@ -367,8 +365,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
var projection = GetProjection(projectionBindingExpression);
objectArrayProjection = (ObjectArrayProjectionExpression)projection.Expression;
break;
case ObjectArrayProjectionExpression arrayProjectionExpression:
objectArrayProjection = arrayProjectionExpression;
case ObjectArrayProjectionExpression objectArrayProjectionExpression:
objectArrayProjection = objectArrayProjectionExpression;
break;
default:
throw new InvalidOperationException();
Expand All @@ -389,21 +387,29 @@ protected override Expression VisitExtension(Expression extensionExpression)
var innerShaper = Visit(collectionShaperExpression.InnerShaper);
_ordinalParameter = previousOrdinalParameter;

return Expression.Call(
var entities = Expression.Call(
_selectMethodInfo.MakeGenericMethod(typeof(JObject), innerShaper.Type),
Expression.Call(
_castMethodInfo.MakeGenericMethod(typeof(JObject)),
jArray),
Expression.Lambda(innerShaper, jObjectParameter, ordinalParameter));

var navigation = collectionShaperExpression.Navigation;
return Expression.Call(
_populateCollectionMethodInfo.MakeGenericMethod(navigation.GetTargetType().ClrType, navigation.ClrType),
Expression.Constant(navigation.GetCollectionAccessor()),
entities);
}

case IncludeExpression includeExpression:
{
var navigation = includeExpression.Navigation;
var fk = navigation.ForeignKey;
if (includeExpression.Navigation.IsDependentToPrincipal()
|| fk.DeclaringEntityType.IsDocumentRoot())
{
throw new InvalidOperationException("Non-embedded IncludeExpression " + new ExpressionPrinter().Print(includeExpression));
throw new InvalidOperationException(
"Non-embedded IncludeExpression " + new ExpressionPrinter().Print(includeExpression));
}

// These are the expressions added by JObjectInjectingExpressionVisitor
Expand All @@ -425,24 +431,25 @@ protected override Expression VisitExtension(Expression extensionExpression)
var concreteEntityTypeVariable = shaperBlock.Variables.Single(v => v.Type == typeof(IEntityType));
var inverseNavigation = navigation.FindInverse();
var fixup = GenerateFixup(
includingClrType, relatedEntityClrType, navigation, inverseNavigation);
includingClrType, relatedEntityClrType, navigation, inverseNavigation);
var initialize = GenerateInitialize(includingClrType, navigation);

var previousOwner = _ownerInfo;
_ownerInfo = (navigation.DeclaringEntityType, jObjectVariable);
var navigationExpression = Visit(includeExpression.NavigationExpression);
_ownerInfo = previousOwner;

shaperExpressions.Add(Expression.Call(
includeMethod.MakeGenericMethod(includingClrType, relatedEntityClrType),
entityEntryVariable,
instanceVariable,
concreteEntityTypeVariable,
navigationExpression,
Expression.Constant(navigation),
Expression.Constant(inverseNavigation, typeof(INavigation)),
Expression.Constant(fixup),
Expression.Constant(initialize, typeof(Action<>).MakeGenericType(includingClrType))));
shaperExpressions.Add(
Expression.Call(
includeMethod.MakeGenericMethod(includingClrType, relatedEntityClrType),
entityEntryVariable,
instanceVariable,
concreteEntityTypeVariable,
navigationExpression,
Expression.Constant(navigation),
Expression.Constant(inverseNavigation, typeof(INavigation)),
Expression.Constant(fixup),
Expression.Constant(initialize, typeof(Action<>).MakeGenericType(includingClrType))));

shaperExpressions.Add(instanceVariable);
shaperBlock = shaperBlock.Update(shaperBlock.Variables, shaperExpressions);
Expand All @@ -454,6 +461,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
jObjectCondition.Update(jObjectCondition.Test, jObjectCondition.IfTrue, shaperBlock));

return jObjectBlock.Update(jObjectBlock.Variables, jObjectExpressions);
}
}

return base.VisitExtension(extensionExpression);
Expand Down Expand Up @@ -634,6 +642,24 @@ private static Expression AddToCollectionNavigation(
relatedEntity,
Expression.Constant(true));

private static readonly MethodInfo _populateCollectionMethodInfo
= typeof(CosmosProjectionBindingRemovingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(PopulateCollection));

private static TCollection PopulateCollection<TEntity, TCollection>(
IClrCollectionAccessor accessor,
IEnumerable<TEntity> entities)
{
// TODO: throw a better exception for non ICollection navigations
var collection = (ICollection<TEntity>)accessor.Create();
foreach (var entity in entities)
{
collection.Add(entity);
}

return (TCollection)collection;
}

private int GetProjectionIndex(ProjectionBindingExpression projectionBindingExpression)
=> projectionBindingExpression.ProjectionMember != null
? (int)((ConstantExpression)_selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember)).Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
using Microsoft.EntityFrameworkCore.Cosmos.Update.Internal;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Update;
using Microsoft.Extensions.DependencyInjection;
Expand Down
24 changes: 18 additions & 6 deletions src/EFCore/Query/ExpressionPrinter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -568,11 +568,6 @@ protected override Expression VisitMemberInit(MemberInitExpression memberInitExp

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (!methodCallExpression.IsEFProperty())
{
_stringBuilder.Append(methodCallExpression.Method.ReturnType.ShortDisplayName() + " ");
}

if (methodCallExpression.Object != null)
{
if (methodCallExpression.Object is BinaryExpression)
Expand All @@ -589,7 +584,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
_stringBuilder.Append(".");
}

_stringBuilder.Append(methodCallExpression.Method.Name + "(");
_stringBuilder.Append(methodCallExpression.Method.Name);
if (methodCallExpression.Method.IsGenericMethod)
{
_stringBuilder.Append("<");
var first = true;
foreach (var genericArgument in methodCallExpression.Method.GetGenericArguments())
{
if (!first)
{
_stringBuilder.Append(", ");
}
_stringBuilder.Append(genericArgument.ShortDisplayName());
first = false;
}

_stringBuilder.Append(">");
}
_stringBuilder.Append("(");

var isSimpleMethodOrProperty = _simpleMethods.Contains(methodCallExpression.Method.Name)
|| methodCallExpression.Arguments.Count < 2
Expand Down
3 changes: 1 addition & 2 deletions test/EFCore.Cosmos.FunctionalTests/EmbeddedDocumentsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ public virtual async Task Can_query_just_nested_reference()
}
}

// #12086
//[ConditionalFact]
[ConditionalFact]
public virtual async Task Can_query_just_nested_collection()
{
await using (var testDatabase = CreateTestStore())
Expand Down

0 comments on commit 7d01acc

Please sign in to comment.