From 0b64456deac9ee8d34d2a96abe77590904679ed7 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 22 Aug 2019 09:39:21 -0700 Subject: [PATCH] InMemory: Add support for collection projection Part of #16963 --- ...emoryProjectionBindingExpressionVisitor.cs | 53 +++++++++-------- ...yableMethodTranslatingExpressionVisitor.cs | 10 ++-- ....CustomShaperCompilingExpressionVisitor.cs | 57 +++++++++++++++---- ...erExpressionProcessingExpressionVisitor.cs | 29 ++++++++-- .../Query/CollectionShaperExpression.cs | 2 +- .../Query/AsyncGearsOfWarQueryInMemoryTest.cs | 2 +- .../Query/AsyncSimpleQueryInMemoryTest.cs | 48 ++-------------- .../Query/CompiledQueryInMemoryTest.cs | 28 +-------- .../Query/FiltersInMemoryTest.cs | 29 ---------- .../Query/IncludeInMemoryFixture.cs | 1 - 10 files changed, 115 insertions(+), 144 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs index e7e1c543e5e..f4725f2a962 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs @@ -7,6 +7,7 @@ using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; @@ -85,20 +86,11 @@ public override Expression Visit(Expression expression) return expression; case MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression: - - var translated = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( - materializeCollectionNavigationExpression.Subquery); - - var index = _queryExpression.AddSubqueryProjection(translated, out var innerShaper); - - return new CollectionShaperExpression( - new ProjectionBindingExpression( - _queryExpression, - index, - typeof(IEnumerable)), - innerShaper, + return AddCollectionProjection( + _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( + materializeCollectionNavigationExpression.Subquery), materializeCollectionNavigationExpression.Navigation, - materializeCollectionNavigationExpression.Navigation.GetTargetType().ClrType); + null); case MethodCallExpression methodCallExpression: { @@ -106,22 +98,21 @@ public override Expression Visit(Expression expression) && methodCallExpression.Method.DeclaringType == typeof(Enumerable) && methodCallExpression.Method.Name == nameof(Enumerable.ToList)) { - //var elementType = methodCallExpression.Method.GetGenericArguments()[0]; - - //var result = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression.Arguments[0]); - - //return _selectExpression.AddCollectionProjection(result, null, elementType); - throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name)); + return AddCollectionProjection( + _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( + methodCallExpression.Arguments[0]), + null, + methodCallExpression.Method.GetGenericArguments()[0]); } var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); - if (subquery != null) { - //if (subquery.ResultType == ResultType.Enumerable) - //{ - // return _selectExpression.AddCollectionProjection(subquery, null, subquery.ShaperExpression.Type); - //} + if (subquery.ResultCardinality == ResultCardinality.Enumerable) + { + return AddCollectionProjection(subquery, null, subquery.ShaperExpression.Type); + } + throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name)); } @@ -129,7 +120,6 @@ public override Expression Visit(Expression expression) } } - var translation = _expressionTranslatingExpressionVisitor.Translate(expression); return translation == null ? base.Visit(expression) @@ -153,6 +143,19 @@ public override Expression Visit(Expression expression) return base.Visit(expression); } + private CollectionShaperExpression AddCollectionProjection( + ShapedQueryExpression subquery, INavigation navigation, Type elementType) + => new CollectionShaperExpression( + new ProjectionBindingExpression( + _queryExpression, + _queryExpression.AddSubqueryProjection( + subquery, + out var innerShaper), + typeof(IEnumerable)), + innerShaper, + navigation, + elementType); + protected override Expression VisitExtension(Expression extensionExpression) { if (extensionExpression is EntityShaperExpression entityShaperExpression) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index d20142dec52..4b0feaf449d 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -635,11 +635,11 @@ private ShapedQueryExpression TranslateScalarAggregate( inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider - .GetAggregateMethod(methodName, selector.ReturnType, parameterCount: 1) - .MakeGenericMethod(typeof(ValueBuffer)), - inMemoryQueryExpression.ServerQueryExpression, - selector); + InMemoryLinqOperatorProvider + .GetAggregateMethod(methodName, selector.ReturnType, parameterCount: 1) + .MakeGenericMethod(typeof(ValueBuffer)), + inMemoryQueryExpression.ServerQueryExpression, + selector); source.ShaperExpression = inMemoryQueryExpression.GetSingleScalarProjection(); diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs index 0d287ebeae0..b9c802d8e44 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs @@ -32,6 +32,18 @@ private static readonly MethodInfo _includeCollectionMethodInfo = typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo() .GetDeclaredMethod(nameof(IncludeCollection)); + private static readonly MethodInfo _materializeCollectionMethodInfo + = typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo() + .GetDeclaredMethod(nameof(MaterializeCollection)); + + private static void SetIsLoadedNoTracking(object entity, INavigation navigation) + => ((ILazyLoader)(navigation + .DeclaringEntityType + .GetServiceProperties() + .FirstOrDefault(p => p.ClrType == typeof(ILazyLoader))) + ?.GetGetter().GetClrValue(entity)) + ?.SetLoaded(entity, navigation.Name); + private static void IncludeReference( QueryContext queryContext, TEntity entity, @@ -114,13 +126,23 @@ private static void IncludeCollection ((ILazyLoader)(navigation - .DeclaringEntityType - .GetServiceProperties() - .FirstOrDefault(p => p.ClrType == typeof(ILazyLoader))) - ?.GetGetter().GetClrValue(entity)) - ?.SetLoaded(entity, navigation.Name); + private static TCollection MaterializeCollection( + QueryContext queryContext, + IEnumerable innerValueBuffers, + Func innerShaper, + IClrCollectionAccessor clrCollectionAccessor) + where TCollection : class, ICollection + { + var collection = (TCollection)(clrCollectionAccessor?.Create() ?? new List()); + + foreach (var valueBuffer in innerValueBuffers) + { + var element = innerShaper(queryContext, valueBuffer); + collection.Add(element); + } + + return collection; + } protected override Expression VisitExtension(Expression extensionExpression) { @@ -138,12 +160,12 @@ protected override Expression VisitExtension(Expression extensionExpression) if (includeExpression.Navigation.IsCollection()) { - var collectionShaperExpression = (CollectionShaperExpression)includeExpression.NavigationExpression; + var collectionShaper = (CollectionShaperExpression)includeExpression.NavigationExpression; return Expression.Call( _includeCollectionMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType), QueryCompilationContext.QueryContextParameter, - collectionShaperExpression.Projection, - Expression.Constant(((LambdaExpression)Visit(collectionShaperExpression.InnerShaper)).Compile()), + collectionShaper.Projection, + Expression.Constant(((LambdaExpression)Visit(collectionShaper.InnerShaper)).Compile()), includeExpression.EntityExpression, Expression.Constant(includeExpression.Navigation), Expression.Constant(inverseNavigation, typeof(INavigation)), @@ -166,6 +188,21 @@ protected override Expression VisitExtension(Expression extensionExpression) Expression.Constant(_tracking)); } + if (extensionExpression is CollectionShaperExpression collectionShaperExpression) + { + var elementType = collectionShaperExpression.ElementType; + var collectionType = collectionShaperExpression.Type; + + return Expression.Call( + _materializeCollectionMethodInfo.MakeGenericMethod(elementType, collectionType), + QueryCompilationContext.QueryContextParameter, + collectionShaperExpression.Projection, + Expression.Constant(((LambdaExpression)Visit(collectionShaperExpression.InnerShaper)).Compile()), + Expression.Constant( + collectionShaperExpression.Navigation?.GetCollectionAccessor(), + typeof(IClrCollectionAccessor))); + } + return base.VisitExtension(extensionExpression); } diff --git a/src/EFCore.InMemory/Query/Internal/ShaperExpressionProcessingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/ShaperExpressionProcessingExpressionVisitor.cs index 619ecb76edf..d0a61f7bc5a 100644 --- a/src/EFCore.InMemory/Query/Internal/ShaperExpressionProcessingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/ShaperExpressionProcessingExpressionVisitor.cs @@ -75,17 +75,17 @@ protected override Expression VisitExtension(Expression extensionExpression) case IncludeExpression includeExpression: { var entity = Visit(includeExpression.EntityExpression); - if (includeExpression.NavigationExpression is CollectionShaperExpression collectionShaperExpression) + if (includeExpression.NavigationExpression is CollectionShaperExpression collectionShaper) { - var innerLambda = (LambdaExpression)collectionShaperExpression.InnerShaper; + var innerLambda = (LambdaExpression)collectionShaper.InnerShaper; var innerShaper = new ShaperExpressionProcessingExpressionVisitor(null, innerLambda.Parameters[0]) .Inject(innerLambda.Body); _expressions.Add( includeExpression.Update( entity, - collectionShaperExpression.Update( - Visit(collectionShaperExpression.Projection), + collectionShaper.Update( + Visit(collectionShaper.Projection), innerShaper))); } else @@ -98,6 +98,27 @@ protected override Expression VisitExtension(Expression extensionExpression) return entity; } + + case CollectionShaperExpression collectionShaperExpression: + { + var key = GenerateKey((ProjectionBindingExpression)collectionShaperExpression.Projection); + if (!_mapping.TryGetValue(key, out var variable)) + { + var projection = Visit(collectionShaperExpression.Projection); + + variable = Expression.Parameter(collectionShaperExpression.Type); + _variables.Add(variable); + + var innerLambda = (LambdaExpression)collectionShaperExpression.InnerShaper; + var innerShaper = new ShaperExpressionProcessingExpressionVisitor(null, innerLambda.Parameters[0]) + .Inject(innerLambda.Body); + + _expressions.Add(Expression.Assign(variable, collectionShaperExpression.Update(projection, innerShaper))); + _mapping[key] = variable; + } + + return variable; + } } return base.VisitExtension(extensionExpression); diff --git a/src/EFCore/Query/CollectionShaperExpression.cs b/src/EFCore/Query/CollectionShaperExpression.cs index 93248a52650..4008ff51d53 100644 --- a/src/EFCore/Query/CollectionShaperExpression.cs +++ b/src/EFCore/Query/CollectionShaperExpression.cs @@ -19,7 +19,7 @@ public CollectionShaperExpression( Projection = projection; InnerShaper = innerShaper; Navigation = navigation; - ElementType = elementType; + ElementType = elementType ?? navigation.ClrType.TryGetSequenceType(); } protected override Expression VisitChildren(ExpressionVisitor visitor) diff --git a/test/EFCore.InMemory.FunctionalTests/Query/AsyncGearsOfWarQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/AsyncGearsOfWarQueryInMemoryTest.cs index a82c81a088a..951c4c7d59e 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/AsyncGearsOfWarQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/AsyncGearsOfWarQueryInMemoryTest.cs @@ -14,7 +14,7 @@ public AsyncGearsOfWarQueryInMemoryTest(GearsOfWarQueryInMemoryFixture fixture, { } - [ConditionalFact(Skip = "Issue#16963")] + [ConditionalFact(Skip = "Issue#16963 Group By")] public override Task GroupBy_Select_sum() { return base.GroupBy_Select_sum(); diff --git a/test/EFCore.InMemory.FunctionalTests/Query/AsyncSimpleQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/AsyncSimpleQueryInMemoryTest.cs index d8e66946572..386bec964c3 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/AsyncSimpleQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/AsyncSimpleQueryInMemoryTest.cs @@ -14,76 +14,40 @@ public AsyncSimpleQueryInMemoryTest(NorthwindQueryInMemoryFixture null; - - [ConditionalFact(Skip = "See issue#13857")] - public override void DbQuery_query() - { - base.DbQuery_query(); - } - - [ConditionalFact(Skip = "See issue#13857")] - public override Task DbQuery_query_async() - { - return base.DbQuery_query_async(); - } - - [ConditionalFact(Skip = "See issue#13857")] - public override void DbQuery_query_first() - { - base.DbQuery_query_first(); - } - - [ConditionalFact(Skip = "See issue#13857")] - public override Task DbQuery_query_first_async() - { - return base.DbQuery_query_first_async(); - } } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/FiltersInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/FiltersInMemoryTest.cs index 1460bf1ecad..2533808c696 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/FiltersInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/FiltersInMemoryTest.cs @@ -1,7 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Xunit; using Xunit.Abstractions; namespace Microsoft.EntityFrameworkCore.Query @@ -13,33 +12,5 @@ public FiltersInMemoryTest(NorthwindQueryInMemoryFixture