diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 1fb84dbd69a..be3564af6b2 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -8,6 +8,7 @@ using System.Reflection; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; @@ -161,13 +162,38 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ return false; } - private Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property) + private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property) { return entityProjectionExpression.BindProperty(property); } + private static Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return groupByShaperExpression.ElementSelector; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod) + { + return methodCallExpression; + } + // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { @@ -179,6 +205,52 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp throw new InvalidOperationException("EF.Property called with wrong property name."); } + // GroupBy Aggregate case + if (methodCallExpression.Object == null + && methodCallExpression.Method.DeclaringType == typeof(Enumerable) + && methodCallExpression.Arguments.Count > 0 + && methodCallExpression.Arguments[0] is InMemoryGroupByShaperExpression groupByShaperExpression) + { + switch (methodCallExpression.Method.Name) + { + case nameof(Enumerable.Average): + case nameof(Enumerable.Max): + case nameof(Enumerable.Min): + case nameof(Enumerable.Sum): + var translation = Translate(GetSelector(methodCallExpression, groupByShaperExpression)); + var selector = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter); + MethodInfo getMethod() + => methodCallExpression.Method.Name switch + { + nameof(Enumerable.Average) => InMemoryLinqOperatorProvider.GetAverageWithSelector(selector.ReturnType), + nameof(Enumerable.Max) => InMemoryLinqOperatorProvider.GetMaxWithSelector(selector.ReturnType), + nameof(Enumerable.Min) => InMemoryLinqOperatorProvider.GetMinWithSelector(selector.ReturnType), + nameof(Enumerable.Sum) => InMemoryLinqOperatorProvider.GetSumWithSelector(selector.ReturnType), + _ => throw new InvalidOperationException("Invalid Aggregate Operator encountered."), + }; + var method = getMethod(); + method = method.GetGenericArguments().Length == 2 + ? method.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) + : method.MakeGenericMethod(typeof(ValueBuffer)); + + return Expression.Call(method, + groupByShaperExpression.GroupingParameter, + selector); + + case nameof(Enumerable.Count): + return Expression.Call( + InMemoryLinqOperatorProvider.CountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)), + groupByShaperExpression.GroupingParameter); + case nameof(Enumerable.LongCount): + return Expression.Call( + InMemoryLinqOperatorProvider.LongCountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)), + groupByShaperExpression.GroupingParameter); + + default: + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + } + // Subquery case var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); if (subqueryTranslation != null) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryGroupByShaperExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryGroupByShaperExpression.cs new file mode 100644 index 00000000000..4dfefd23622 --- /dev/null +++ b/src/EFCore.InMemory/Query/Internal/InMemoryGroupByShaperExpression.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Query; + +namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal +{ + public class InMemoryGroupByShaperExpression : GroupByShaperExpression + { + public InMemoryGroupByShaperExpression( + Expression keySelector, + Expression elementSelector, + ParameterExpression groupingParameter, + ParameterExpression valueBufferParameter) + : base(keySelector, elementSelector) + { + GroupingParameter = groupingParameter; + ValueBufferParameter = valueBufferParameter; + } + + public virtual ParameterExpression GroupingParameter { get; } + public virtual ParameterExpression ValueBufferParameter { get; } + } +} diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs b/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs index 0d588543b4a..3d7f53aee47 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs @@ -11,73 +11,358 @@ namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { public static class InMemoryLinqOperatorProvider { - private static MethodInfo GetMethod(string name, int parameterCount = 0) - => GetMethods(name, parameterCount).Single(); - - private static IEnumerable GetMethods(string name, int parameterCount = 0) - => typeof(Enumerable).GetTypeInfo().GetDeclaredMethods(name) - .Where(mi => mi.GetParameters().Length == parameterCount + 1); - - public static MethodInfo Where = GetMethods(nameof(Enumerable.Where), 1) - .Single(mi => mi.GetParameters()[1].ParameterType.GetGenericArguments().Length == 2); - public static MethodInfo Select = GetMethods(nameof(Enumerable.Select), 1) - .Single(mi => mi.GetParameters()[1].ParameterType.GetGenericArguments().Length == 2); - - public static MethodInfo Join = GetMethod(nameof(Enumerable.Join), 4); - public static MethodInfo GroupJoin = GetMethod(nameof(Enumerable.GroupJoin), 4); - public static MethodInfo DefaultIfEmptyWithArg = GetMethod(nameof(Enumerable.DefaultIfEmpty), 1); - public static MethodInfo SelectMany = GetMethods(nameof(Enumerable.SelectMany), 2) - .Single(mi => mi.GetParameters()[1].ParameterType.GetGenericArguments().Length == 2); - public static MethodInfo Contains = GetMethod(nameof(Enumerable.Contains), 1); - - public static MethodInfo OrderBy = GetMethod(nameof(Enumerable.OrderBy), 1); - public static MethodInfo OrderByDescending = GetMethod(nameof(Enumerable.OrderByDescending), 1); - public static MethodInfo ThenBy = GetMethod(nameof(Enumerable.ThenBy), 1); - public static MethodInfo ThenByDescending = GetMethod(nameof(Enumerable.ThenByDescending), 1); - public static MethodInfo All = GetMethod(nameof(Enumerable.All), 1); - public static MethodInfo Any = GetMethod(nameof(Enumerable.Any)); - public static MethodInfo AnyPredicate = GetMethod(nameof(Enumerable.Any), 1); - public static MethodInfo Count = GetMethod(nameof(Enumerable.Count)); - public static MethodInfo LongCount = GetMethod(nameof(Enumerable.LongCount)); - public static MethodInfo CountPredicate = GetMethod(nameof(Enumerable.Count), 1); - public static MethodInfo LongCountPredicate = GetMethod(nameof(Enumerable.LongCount), 1); - public static MethodInfo Distinct = GetMethod(nameof(Enumerable.Distinct)); - public static MethodInfo Take = GetMethod(nameof(Enumerable.Take), 1); - public static MethodInfo Skip = GetMethod(nameof(Enumerable.Skip), 1); - - public static MethodInfo FirstPredicate = GetMethod(nameof(Enumerable.First), 1); - public static MethodInfo FirstOrDefaultPredicate = GetMethod(nameof(Enumerable.FirstOrDefault), 1); - public static MethodInfo LastPredicate = GetMethod(nameof(Enumerable.Last), 1); - public static MethodInfo LastOrDefaultPredicate = GetMethod(nameof(Enumerable.LastOrDefault), 1); - public static MethodInfo SinglePredicate = GetMethod(nameof(Enumerable.Single), 1); - public static MethodInfo SingleOrDefaultPredicate = GetMethod(nameof(Enumerable.SingleOrDefault), 1); - - public static MethodInfo First = GetMethod(nameof(Enumerable.First), 0); - public static MethodInfo FirstOrDefault = GetMethod(nameof(Enumerable.FirstOrDefault), 0); - public static MethodInfo Last = GetMethod(nameof(Enumerable.Last), 0); - public static MethodInfo LastOrDefault = GetMethod(nameof(Enumerable.LastOrDefault), 0); - public static MethodInfo Single = GetMethod(nameof(Enumerable.Single), 0); - public static MethodInfo SingleOrDefault = GetMethod(nameof(Enumerable.SingleOrDefault), 0); - - public static MethodInfo Concat = GetMethod(nameof(Enumerable.Concat), 1); - public static MethodInfo Except = GetMethod(nameof(Enumerable.Except), 1); - public static MethodInfo Intersect = GetMethod(nameof(Enumerable.Intersect), 1); - public static MethodInfo Union = GetMethod(nameof(Enumerable.Union), 1); - - public static MethodInfo GetAggregateMethod(string methodName, Type elementType, int parameterCount = 0) + public static MethodInfo AsEnumerable { get; } + public static MethodInfo Cast { get; } + public static MethodInfo OfType { get; } + + public static MethodInfo All { get; } + public static MethodInfo AnyWithoutPredicate { get; } + public static MethodInfo AnyWithPredicate { get; } + public static MethodInfo Contains { get; } + + public static MethodInfo Concat { get; } + public static MethodInfo Except { get; } + public static MethodInfo Intersect { get; } + public static MethodInfo Union { get; } + + public static MethodInfo CountWithoutPredicate { get; } + public static MethodInfo CountWithPredicate { get; } + public static MethodInfo LongCountWithoutPredicate { get; } + public static MethodInfo LongCountWithPredicate { get; } + public static MethodInfo MinWithSelector { get; } + public static MethodInfo MinWithoutSelector { get; } + public static MethodInfo MaxWithSelector { get; } + public static MethodInfo MaxWithoutSelector { get; } + + public static MethodInfo ElementAt { get; } + public static MethodInfo ElementAtOrDefault { get; } + public static MethodInfo FirstWithoutPredicate { get; } + public static MethodInfo FirstWithPredicate { get; } + public static MethodInfo FirstOrDefaultWithoutPredicate { get; } + public static MethodInfo FirstOrDefaultWithPredicate { get; } + public static MethodInfo SingleWithoutPredicate { get; } + public static MethodInfo SingleWithPredicate { get; } + public static MethodInfo SingleOrDefaultWithoutPredicate { get; } + public static MethodInfo SingleOrDefaultWithPredicate { get; } + public static MethodInfo LastWithoutPredicate { get; } + public static MethodInfo LastWithPredicate { get; } + public static MethodInfo LastOrDefaultWithoutPredicate { get; } + public static MethodInfo LastOrDefaultWithPredicate { get; } + + public static MethodInfo Distinct { get; } + public static MethodInfo Reverse { get; } + public static MethodInfo Where { get; } + public static MethodInfo Select { get; } + public static MethodInfo Skip { get; } + public static MethodInfo Take { get; } + public static MethodInfo SkipWhile { get; } + public static MethodInfo TakeWhile { get; } + public static MethodInfo OrderBy { get; } + public static MethodInfo OrderByDescending { get; } + public static MethodInfo ThenBy { get; } + public static MethodInfo ThenByDescending { get; } + public static MethodInfo DefaultIfEmptyWithoutArgument { get; } + public static MethodInfo DefaultIfEmptyWithArgument { get; } + + public static MethodInfo Join { get; } + public static MethodInfo GroupJoin { get; } + public static MethodInfo SelectManyWithCollectionSelector { get; } + public static MethodInfo SelectManyWithoutCollectionSelector { get; } + + public static MethodInfo GroupByWithKeySelector { get; } + public static MethodInfo GroupByWithKeyElementSelector { get; } + public static MethodInfo GroupByWithKeyElementResultSelector { get; } + public static MethodInfo GroupByWithKeyResultSelector { get; } + + public static MethodInfo GetAverageWithoutSelector(Type type) => AverageWithoutSelectorMethods[type]; + public static MethodInfo GetAverageWithSelector(Type type) => AverageWithSelectorMethods[type]; + public static MethodInfo GetMaxWithoutSelector(Type type) + => MaxWithoutSelectorMethods.TryGetValue(type, out var method) + ? method + : MaxWithoutSelector; + + public static MethodInfo GetMaxWithSelector(Type type) + => MaxWithSelectorMethods.TryGetValue(type, out var method) + ? method + : MaxWithSelector; + + public static MethodInfo GetMinWithoutSelector(Type type) + => MinWithoutSelectorMethods.TryGetValue(type, out var method) + ? method + : MinWithoutSelector; + + public static MethodInfo GetMinWithSelector(Type type) + => MinWithSelectorMethods.TryGetValue(type, out var method) + ? method + : MinWithSelector; + + public static MethodInfo GetSumWithoutSelector(Type type) => SumWithoutSelectorMethods[type]; + public static MethodInfo GetSumWithSelector(Type type) => SumWithSelectorMethods[type]; + + private static Dictionary AverageWithoutSelectorMethods { get; } + private static Dictionary AverageWithSelectorMethods { get; } + private static Dictionary MaxWithoutSelectorMethods { get; } + private static Dictionary MaxWithSelectorMethods { get; } + private static Dictionary MinWithoutSelectorMethods { get; } + private static Dictionary MinWithSelectorMethods { get; } + private static Dictionary SumWithoutSelectorMethods { get; } + private static Dictionary SumWithSelectorMethods { get; } + + private static bool IsFunc(Type type, int funcGenericArgs = 2) + => type.IsGenericType + && type.GetGenericArguments().Length == funcGenericArgs; + + static InMemoryLinqOperatorProvider() { - Check.NotEmpty(methodName, nameof(methodName)); - Check.NotNull(elementType, nameof(elementType)); + var enumerableMethods = typeof(Enumerable).GetTypeInfo() + .GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly).ToList(); + + AsEnumerable = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.AsEnumerable) && mi.IsGenericMethod && mi.GetParameters().Length == 1); + Cast = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Cast) && mi.GetParameters().Length == 1); + OfType = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.OfType) && mi.GetParameters().Length == 1); + + All = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.All) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + AnyWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 1); + AnyWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + Contains = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Contains) && mi.GetParameters().Length == 2); - var aggregateMethods = GetMethods(methodName, parameterCount).ToList(); + Concat = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Concat) && mi.GetParameters().Length == 2); + Except = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Except) && mi.GetParameters().Length == 2); + Intersect = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Intersect) && mi.GetParameters().Length == 2); + Union = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Union) && mi.GetParameters().Length == 2); - return - aggregateMethods - .Single( - mi => mi.GetParameters().Last().ParameterType.GetGenericArguments().Last() == elementType); - //?? aggregateMethods.Single(mi => mi.IsGenericMethod) - // .MakeGenericMethod(elementType); + CountWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Count) && mi.GetParameters().Length == 1); + CountWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Count) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + LongCountWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.LongCount) && mi.GetParameters().Length == 1); + LongCountWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.LongCount) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + MinWithSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Min) && mi.IsGenericMethod && mi.GetGenericArguments().Length == 2 && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + MinWithoutSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Min) && mi.IsGenericMethod && mi.GetGenericArguments().Length == 1 && mi.GetParameters().Length == 1); + MaxWithSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Max) && mi.IsGenericMethod && mi.GetGenericArguments().Length == 2 && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + MaxWithoutSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Max) && mi.IsGenericMethod && mi.GetGenericArguments().Length == 1 && mi.GetParameters().Length == 1); + + ElementAt = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.ElementAt) && mi.GetParameters().Length == 2); + ElementAtOrDefault = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.ElementAtOrDefault) && mi.GetParameters().Length == 2); + FirstWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.First) && mi.GetParameters().Length == 1); + FirstWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.First) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + FirstOrDefaultWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.FirstOrDefault) && mi.GetParameters().Length == 1); + FirstOrDefaultWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.FirstOrDefault) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + SingleWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Single) && mi.GetParameters().Length == 1); + SingleWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Single) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + SingleOrDefaultWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.SingleOrDefault) && mi.GetParameters().Length == 1); + SingleOrDefaultWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.SingleOrDefault) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + LastWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Last) && mi.GetParameters().Length == 1); + LastWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Last) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + LastOrDefaultWithoutPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.LastOrDefault) && mi.GetParameters().Length == 1); + LastOrDefaultWithPredicate = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.LastOrDefault) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + + Distinct = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Distinct) && mi.GetParameters().Length == 1); + Reverse = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Reverse) && mi.GetParameters().Length == 1); + Where = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Where) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + Select = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Select) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + Skip = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Skip) && mi.GetParameters().Length == 2); + Take = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Take) && mi.GetParameters().Length == 2); + SkipWhile = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.SkipWhile) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + TakeWhile = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.TakeWhile) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + OrderBy = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.OrderBy) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + OrderByDescending = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.OrderByDescending) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + ThenBy = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.ThenBy) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + ThenByDescending = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.ThenByDescending) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + DefaultIfEmptyWithoutArgument = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.DefaultIfEmpty) && mi.GetParameters().Length == 1); + DefaultIfEmptyWithArgument = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.DefaultIfEmpty) && mi.GetParameters().Length == 2); + + Join = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.Join) && mi.GetParameters().Length == 5); + GroupJoin = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.GroupJoin) && mi.GetParameters().Length == 5); + SelectManyWithCollectionSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.SelectMany) && mi.GetParameters().Length == 3 && IsFunc(mi.GetParameters()[1].ParameterType)); + SelectManyWithoutCollectionSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.SelectMany) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + + GroupByWithKeySelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.GroupBy) && mi.GetParameters().Length == 2 && IsFunc(mi.GetParameters()[1].ParameterType)); + GroupByWithKeyElementSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.GroupBy) && mi.GetParameters().Length == 3 && IsFunc(mi.GetParameters()[1].ParameterType) && IsFunc(mi.GetParameters()[2].ParameterType)); + GroupByWithKeyElementResultSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.GroupBy) && mi.GetParameters().Length == 4 && IsFunc(mi.GetParameters()[1].ParameterType) && IsFunc(mi.GetParameters()[2].ParameterType) && IsFunc(mi.GetParameters()[3].ParameterType, 3)); + GroupByWithKeyResultSelector = enumerableMethods.Single( + mi => mi.Name == nameof(Enumerable.GroupBy) && mi.GetParameters().Length == 3 && IsFunc(mi.GetParameters()[1].ParameterType) && IsFunc(mi.GetParameters()[2].ParameterType, 3)); + + MethodInfo getSumOrAverageWithoutSelector(string methodName) + => enumerableMethods.Single( + mi => mi.Name == methodName + && mi.GetParameters().Length == 1 + && mi.GetParameters()[0].ParameterType.GetGenericArguments()[0] == typeof(T)); + + static bool hasSelector(Type type) + => type.IsGenericType + && type.GetGenericArguments().Length == 2 + && type.GetGenericArguments()[1] == typeof(T); + + MethodInfo getSumOrAverageWithSelector(string methodName) + => enumerableMethods.Single( + mi => mi.Name == methodName + && mi.GetParameters().Length == 2 + && hasSelector(mi.GetParameters()[1].ParameterType)); + + AverageWithoutSelectorMethods = new Dictionary + { + { typeof(decimal), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(long), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(int), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(double), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(float), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(decimal?), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(long?), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(int?), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(double?), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(float?), getSumOrAverageWithoutSelector(nameof(Queryable.Average)) } + }; + + AverageWithSelectorMethods = new Dictionary + { + { typeof(decimal), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(long), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(int), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(double), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(float), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(decimal?), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(long?), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(int?), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(double?), getSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(float?), getSumOrAverageWithSelector(nameof(Queryable.Average)) } + }; + + MaxWithoutSelectorMethods = new Dictionary + { + { typeof(decimal), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(long), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(int), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(double), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(float), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(decimal?), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(long?), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(int?), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(double?), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) }, + { typeof(float?), getSumOrAverageWithoutSelector(nameof(Queryable.Max)) } + }; + + MaxWithSelectorMethods = new Dictionary + { + { typeof(decimal), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(long), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(int), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(double), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(float), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(decimal?), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(long?), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(int?), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(double?), getSumOrAverageWithSelector(nameof(Queryable.Max)) }, + { typeof(float?), getSumOrAverageWithSelector(nameof(Queryable.Max)) } + }; + + MinWithoutSelectorMethods = new Dictionary + { + { typeof(decimal), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(long), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(int), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(double), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(float), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(decimal?), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(long?), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(int?), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(double?), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) }, + { typeof(float?), getSumOrAverageWithoutSelector(nameof(Queryable.Min)) } + }; + + MinWithSelectorMethods = new Dictionary + { + { typeof(decimal), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(long), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(int), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(double), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(float), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(decimal?), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(long?), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(int?), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(double?), getSumOrAverageWithSelector(nameof(Queryable.Min)) }, + { typeof(float?), getSumOrAverageWithSelector(nameof(Queryable.Min)) } + }; + + SumWithoutSelectorMethods = new Dictionary + { + { typeof(decimal), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(long), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(int), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(double), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(float), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(decimal?), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(long?), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(int?), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(double?), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(float?), getSumOrAverageWithoutSelector(nameof(Queryable.Sum)) } + }; + + SumWithSelectorMethods = new Dictionary + { + { typeof(decimal), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(long), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(int), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(double), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(float), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(decimal?), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(long?), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(int?), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(double?), getSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(float?), getSumOrAverageWithSelector(nameof(Queryable.Sum)) } + }; } } - } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs index 6456db627cd..ded09cef614 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs @@ -250,7 +250,6 @@ protected override Expression VisitMemberInit(MemberInitExpression memberInitExp } newBindings[i] = VisitMemberBinding(memberInitExpression.Bindings[i]); - if (newBindings[i] == null) { return null; diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index da5f81b9c14..542d58dcc93 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -23,18 +23,20 @@ private static readonly PropertyInfo _valueBufferCountMemberInfo private readonly List _valueBufferSlots = new List(); private readonly IDictionary> _entityProjectionCache = new Dictionary>(); + private readonly ParameterExpression _valueBufferParameter; private IDictionary _projectionMapping = new Dictionary(); + private ParameterExpression _groupingParameter; public virtual IReadOnlyList Projection => _valueBufferSlots; public virtual Expression ServerQueryExpression { get; set; } - public virtual ParameterExpression ValueBufferParameter { get; } + public virtual ParameterExpression CurrentParameter => _groupingParameter ?? _valueBufferParameter; public override Type Type => typeof(IEnumerable); public sealed override ExpressionType NodeType => ExpressionType.Extension; public InMemoryQueryExpression(IEntityType entityType) { - ValueBufferParameter = Parameter(typeof(ValueBuffer), "valueBuffer"); + _valueBufferParameter = Parameter(typeof(ValueBuffer), "valueBuffer"); ServerQueryExpression = new InMemoryTableExpression(entityType); var readExpressionMap = new Dictionary(); foreach (var property in entityType.GetAllBaseTypesInclusive().SelectMany(et => et.GetDeclaredProperties())) @@ -47,7 +49,7 @@ public InMemoryQueryExpression(IEntityType entityType) readExpressionMap[property] = Condition( LessThan( Constant(property.GetIndex()), - MakeMemberAccess(ValueBufferParameter, + MakeMemberAccess(_valueBufferParameter, _valueBufferCountMemberInfo)), CreateReadValueExpression(typeof(object), property.GetIndex(), property), Default(typeof(object))); @@ -145,7 +147,7 @@ public virtual int AddSubqueryProjection(ShapedQueryExpression shapedQueryExpres innerShaper = new ShaperRemappingExpressionVisitor(subquery._projectionMapping) .Visit(shapedQueryExpression.ShaperExpression); - innerShaper = Lambda(innerShaper, subquery.ValueBufferParameter); + innerShaper = Lambda(innerShaper, subquery.CurrentParameter); return AddToProjection(serverQueryExpression); } @@ -225,12 +227,13 @@ public virtual void PushdownIntoSubquery() NewArrayInit( typeof(object), _valueBufferSlots - .Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e) - .ToArray())), - ValueBufferParameter); + .Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + CurrentParameter); + + _groupingParameter = null; ServerQueryExpression = Call( - InMemoryLinqOperatorProvider.Select.MakeGenericMethod(typeof(ValueBuffer), typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.Select.MakeGenericMethod(ServerQueryExpression.Type.TryGetSequenceType(), typeof(ValueBuffer)), ServerQueryExpression, selectorLambda); @@ -294,7 +297,7 @@ public virtual void ApplyProjection() _valueBufferSlots .Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e) .ToArray())), - ValueBufferParameter); + CurrentParameter); ServerQueryExpression = Call( InMemoryLinqOperatorProvider.Select.MakeGenericMethod(typeof(ValueBuffer), typeof(ValueBuffer)), @@ -302,16 +305,93 @@ public virtual void ApplyProjection() selectorLambda); } - private Expression CreateReadValueExpression( - Type type, - int index, - IPropertyBase property) + public virtual InMemoryGroupByShaperExpression ApplyGrouping(Expression groupingKey, Expression shaperExpression) + { + PushdownIntoSubquery(); + + var selectMethod = (MethodCallExpression)ServerQueryExpression; + var groupBySource = selectMethod.Arguments[0]; + var elementSelector = selectMethod.Arguments[1]; + _groupingParameter = Parameter(typeof(IGrouping), "grouping"); + var groupingKeyAccessExpression = PropertyOrField(_groupingParameter, nameof(IGrouping.Key)); + var groupingKeyExpressions = new List(); + groupingKey = GetGroupingKey(groupingKey, groupingKeyExpressions, groupingKeyAccessExpression); + var keySelector = Lambda( + New( + _valueBufferConstructor, + NewArrayInit( + typeof(object), + groupingKeyExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + _valueBufferParameter); + + ServerQueryExpression = Call( + InMemoryLinqOperatorProvider.GroupByWithKeyElementSelector.MakeGenericMethod( + typeof(ValueBuffer), typeof(ValueBuffer), typeof(ValueBuffer)), + selectMethod.Arguments[0], + keySelector, + selectMethod.Arguments[1]); + + return new InMemoryGroupByShaperExpression( + groupingKey, + shaperExpression, + _groupingParameter, + _valueBufferParameter); + } + + private Expression GetGroupingKey(Expression key, List groupingExpressions, Expression groupingKeyAccessExpression) + { + switch (key) + { + case NewExpression newExpression: + var arguments = new Expression[newExpression.Arguments.Count]; + for (var i = 0; i < arguments.Length; i++) + { + arguments[i] = GetGroupingKey(newExpression.Arguments[i], groupingExpressions, groupingKeyAccessExpression); + } + return newExpression.Update(arguments); + + case MemberInitExpression memberInitExpression: + if (memberInitExpression.Bindings.Any(mb => !(mb is MemberAssignment))) + { + goto default; + } + + var updatedNewExpression = (NewExpression)GetGroupingKey( + memberInitExpression.NewExpression, groupingExpressions, groupingKeyAccessExpression); + var memberBindings = new MemberAssignment[memberInitExpression.Bindings.Count]; + for (var i = 0; i < memberBindings.Length; i++) + { + var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i]; + memberBindings[i] = memberAssignment.Update( + GetGroupingKey( + memberAssignment.Expression, + groupingExpressions, + groupingKeyAccessExpression)); + } + return memberInitExpression.Update(updatedNewExpression, memberBindings); + + default: + var index = groupingExpressions.Count; + groupingExpressions.Add(key); + return CreateReadValueExpression( + groupingKeyAccessExpression, + key.Type, + index, + InferPropertyFromInner(key)); + } + } + + private static Expression CreateReadValueExpression( + Expression valueBufferParameter, Type type, int index, IPropertyBase property) => Call( EntityMaterializerSource.TryReadValueMethod.MakeGenericMethod(type), - ValueBufferParameter, + valueBufferParameter, Constant(index), Constant(property, typeof(IPropertyBase))); + private Expression CreateReadValueExpression(Type type, int index, IPropertyBase property) + => CreateReadValueExpression(_valueBufferParameter, type, index, property); + public virtual void AddInnerJoin( InMemoryQueryExpression innerQueryExpression, LambdaExpression outerKeySelector, @@ -325,8 +405,8 @@ public virtual void AddInnerJoin( var replacingVisitor = new ReplacingExpressionVisitor( new Dictionary { - { ValueBufferParameter, outerParameter }, - { innerQueryExpression.ValueBufferParameter, innerParameter } + { CurrentParameter, outerParameter }, + { innerQueryExpression.CurrentParameter, innerParameter } }); var index = 0; @@ -438,8 +518,8 @@ public virtual void AddLeftJoin( var replacingVisitor = new ReplacingExpressionVisitor( new Dictionary { - { ValueBufferParameter, MakeMemberAccess(outerParameter, outerMemberInfo) }, - { innerQueryExpression.ValueBufferParameter, innerParameter } + { CurrentParameter, MakeMemberAccess(outerParameter, outerMemberInfo) }, + { innerQueryExpression.CurrentParameter, innerParameter } }); var index = 0; @@ -497,7 +577,7 @@ public virtual void AddLeftJoin( var collectionSelector = Lambda( Call( - InMemoryLinqOperatorProvider.DefaultIfEmptyWithArg.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.DefaultIfEmptyWithArgument.MakeGenericMethod(typeof(ValueBuffer)), collection, New( _valueBufferConstructor, @@ -518,7 +598,7 @@ public virtual void AddLeftJoin( innerParameter); ServerQueryExpression = Call( - InMemoryLinqOperatorProvider.SelectMany.MakeGenericMethod( + InMemoryLinqOperatorProvider.SelectManyWithCollectionSelector.MakeGenericMethod( groupTransparentIdentifierType, typeof(ValueBuffer), typeof(ValueBuffer)), groupJoinExpression, collectionSelector, @@ -536,8 +616,8 @@ public virtual void AddSelectMany(InMemoryQueryExpression innerQueryExpression, var replacingVisitor = new ReplacingExpressionVisitor( new Dictionary { - { ValueBufferParameter, outerParameter }, - { innerQueryExpression.ValueBufferParameter, innerParameter } + { CurrentParameter, outerParameter }, + { innerQueryExpression.CurrentParameter, innerParameter } }); var index = 0; @@ -608,10 +688,10 @@ public virtual void AddSelectMany(InMemoryQueryExpression innerQueryExpression, innerParameter); ServerQueryExpression = Call( - InMemoryLinqOperatorProvider.SelectMany.MakeGenericMethod( + InMemoryLinqOperatorProvider.SelectManyWithCollectionSelector.MakeGenericMethod( typeof(ValueBuffer), typeof(ValueBuffer), typeof(ValueBuffer)), ServerQueryExpression, - Lambda(innerQueryExpression.ServerQueryExpression, ValueBufferParameter), + Lambda(innerQueryExpression.ServerQueryExpression, CurrentParameter), resultSelector); _projectionMapping = projectionMapping; diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 66543828fb1..7329a3e2b16 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -72,7 +72,7 @@ protected override ShapedQueryExpression TranslateAll(ShapedQueryExpression sour inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.All.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.All.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression, predicate); @@ -88,7 +88,7 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour if (predicate == null) { inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Any.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.AnyWithoutPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression); } else @@ -100,7 +100,7 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour } inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.AnyPredicate.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.AnyWithPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression, predicate); } @@ -141,10 +141,10 @@ protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression Expression.Call( InMemoryLinqOperatorProvider.Contains.MakeGenericMethod(item.Type), Expression.Call( - InMemoryLinqOperatorProvider.Select.MakeGenericMethod(typeof(ValueBuffer), item.Type), + InMemoryLinqOperatorProvider.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, item.Type), inMemoryQueryExpression.ServerQueryExpression, Expression.Lambda( - inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), inMemoryQueryExpression.ValueBufferParameter)), + inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), inMemoryQueryExpression.CurrentParameter)), item); source.ShaperExpression = inMemoryQueryExpression.GetSingleScalarProjection(); @@ -160,7 +160,7 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so { inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Count.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.CountWithoutPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression); } else @@ -173,7 +173,7 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.CountPredicate.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.CountWithPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression, predicate); } @@ -193,7 +193,7 @@ protected override ShapedQueryExpression TranslateDistinct(ShapedQueryExpression inMemoryQueryExpression.PushdownIntoSubquery(); inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Distinct.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.Distinct.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression); return source; @@ -212,12 +212,111 @@ protected override ShapedQueryExpression TranslateFirstOrDefault(ShapedQueryExpr predicate, returnType, returnDefault - ? InMemoryLinqOperatorProvider.FirstOrDefault - : InMemoryLinqOperatorProvider.First); + ? InMemoryLinqOperatorProvider.FirstOrDefaultWithoutPredicate + : InMemoryLinqOperatorProvider.FirstWithoutPredicate); } protected override ShapedQueryExpression TranslateGroupBy(ShapedQueryExpression source, LambdaExpression keySelector, LambdaExpression elementSelector, LambdaExpression resultSelector) - => null; + { + var remappedKeySelector = RemapLambdaBody(source, keySelector); + + var translatedKey = TranslateGroupingKey(remappedKeySelector); + if (translatedKey != null) + { + if (elementSelector != null) + { + source = TranslateSelect(source, elementSelector); + } + + var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; + source.ShaperExpression = inMemoryQueryExpression.ApplyGrouping(translatedKey, source.ShaperExpression); + + if (resultSelector == null) + { + return source; + } + + var keyAccessExpression = Expression.MakeMemberAccess( + source.ShaperExpression, + source.ShaperExpression.Type.GetTypeInfo().GetMember(nameof(IGrouping.Key))[0]); + + var original1 = resultSelector.Parameters[0]; + var original2 = resultSelector.Parameters[1]; + + var newResultSelectorBody = new ReplacingExpressionVisitor( + new Dictionary { + { original1, keyAccessExpression }, + { original2, source.ShaperExpression } + }).Visit(resultSelector.Body); + + //newResultSelectorBody = ExpandWeakEntities(selectExpression, newResultSelectorBody); + + source.ShaperExpression = _projectionBindingExpressionVisitor.Translate(inMemoryQueryExpression, newResultSelectorBody); + + inMemoryQueryExpression.PushdownIntoSubquery(); + + return source; + } + + return null; + } + + private Expression TranslateGroupingKey(Expression expression) + { + switch (expression) + { + case NewExpression newExpression: + if (newExpression.Arguments.Count == 0) + { + return newExpression; + } + + var newArguments = new Expression[newExpression.Arguments.Count]; + for (var i = 0; i < newArguments.Length; i++) + { + newArguments[i] = TranslateGroupingKey(newExpression.Arguments[i]); + if (newArguments[i] == null) + { + return null; + } + } + + return newExpression.Update(newArguments); + + case MemberInitExpression memberInitExpression: + var updatedNewExpression = (NewExpression)TranslateGroupingKey(memberInitExpression.NewExpression); + if (updatedNewExpression == null) + { + return null; + } + + var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count]; + for (var i = 0; i < newBindings.Length; i++) + { + var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i]; + var visitedExpression = TranslateGroupingKey(memberAssignment.Expression); + if (visitedExpression == null) + { + return null; + } + + newBindings[i] = memberAssignment.Update(visitedExpression); + } + + return memberInitExpression.Update(updatedNewExpression, newBindings); + + default: + var translation = _expressionTranslator.Translate(expression); + if (translation == null) + { + return null; + } + + return translation.Type == expression.Type + ? (Expression)translation + : Expression.Convert(translation, expression.Type); + } + } protected override ShapedQueryExpression TranslateGroupJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector) => null; @@ -258,8 +357,8 @@ protected override ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpre predicate, returnType, returnDefault - ? InMemoryLinqOperatorProvider.LastOrDefault - : InMemoryLinqOperatorProvider.Last); + ? InMemoryLinqOperatorProvider.LastOrDefaultWithoutPredicate + : InMemoryLinqOperatorProvider.LastWithoutPredicate); } protected override ShapedQueryExpression TranslateLeftJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector) @@ -296,7 +395,7 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio { inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.LongCount.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.LongCountWithoutPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression); } else @@ -309,7 +408,7 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.LongCountPredicate.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.LongCountWithPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression, predicate); } @@ -412,7 +511,7 @@ protected override ShapedQueryExpression TranslateOrderBy(ShapedQueryExpression inMemoryQueryExpression.ServerQueryExpression = Expression.Call( (ascending ? InMemoryLinqOperatorProvider.OrderBy : InMemoryLinqOperatorProvider.OrderByDescending) - .MakeGenericMethod(typeof(ValueBuffer), keySelector.ReturnType), + .MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, keySelector.ReturnType), inMemoryQueryExpression.ServerQueryExpression, keySelector); @@ -432,8 +531,15 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s var newSelectorBody = ReplacingExpressionVisitor.Replace( selector.Parameters.Single(), source.ShaperExpression, selector.Body); - source.ShaperExpression = _projectionBindingExpressionVisitor - .Translate((InMemoryQueryExpression)source.QueryExpression, newSelectorBody); + var groupByQuery = source.ShaperExpression is GroupByShaperExpression; + var queryExpression = (InMemoryQueryExpression)source.QueryExpression; + + source.ShaperExpression = _projectionBindingExpressionVisitor.Translate(queryExpression, newSelectorBody); + + if (groupByQuery) + { + queryExpression.PushdownIntoSubquery(); + } return source; } @@ -513,8 +619,8 @@ protected override ShapedQueryExpression TranslateSingleOrDefault(ShapedQueryExp predicate, returnType, returnDefault - ? InMemoryLinqOperatorProvider.SingleOrDefault - : InMemoryLinqOperatorProvider.Single); + ? InMemoryLinqOperatorProvider.SingleOrDefaultWithoutPredicate + : InMemoryLinqOperatorProvider.SingleWithoutPredicate); } protected override ShapedQueryExpression TranslateSkip(ShapedQueryExpression source, Expression count) @@ -528,7 +634,7 @@ protected override ShapedQueryExpression TranslateSkip(ShapedQueryExpression sou inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Skip.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.Skip.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression, count); @@ -552,7 +658,7 @@ protected override ShapedQueryExpression TranslateTake(ShapedQueryExpression sou inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Take.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.Take.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression, count); @@ -574,7 +680,7 @@ protected override ShapedQueryExpression TranslateThenBy(ShapedQueryExpression s inMemoryQueryExpression.ServerQueryExpression = Expression.Call( (ascending ? InMemoryLinqOperatorProvider.ThenBy : InMemoryLinqOperatorProvider.ThenByDescending) - .MakeGenericMethod(typeof(ValueBuffer), keySelector.ReturnType), + .MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, keySelector.ReturnType), inMemoryQueryExpression.ServerQueryExpression, keySelector); @@ -594,7 +700,7 @@ protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression so } inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Where.MakeGenericMethod(typeof(ValueBuffer)), + InMemoryLinqOperatorProvider.Where.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression, predicate); @@ -613,7 +719,7 @@ private LambdaExpression TranslateLambdaExpression( return lambdaBody != null ? Expression.Lambda(lambdaBody, - ((InMemoryQueryExpression)shapedQueryExpression.QueryExpression).ValueBufferParameter) + ((InMemoryQueryExpression)shapedQueryExpression.QueryExpression).CurrentParameter) : null; } @@ -632,7 +738,7 @@ private ShapedQueryExpression TranslateScalarAggregate( || selector.Body == selector.Parameters[0] ? Expression.Lambda( inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), - inMemoryQueryExpression.ValueBufferParameter) + inMemoryQueryExpression.CurrentParameter) : TranslateLambdaExpression(source, selector); if (selector == null) @@ -640,11 +746,22 @@ private ShapedQueryExpression TranslateScalarAggregate( return null; } + MethodInfo getMethod() + => methodName switch + { + nameof(Enumerable.Average) => InMemoryLinqOperatorProvider.GetAverageWithSelector(selector.ReturnType), + nameof(Enumerable.Max) => InMemoryLinqOperatorProvider.GetMaxWithSelector(selector.ReturnType), + nameof(Enumerable.Min) => InMemoryLinqOperatorProvider.GetMinWithSelector(selector.ReturnType), + nameof(Enumerable.Sum) => InMemoryLinqOperatorProvider.GetSumWithSelector(selector.ReturnType), + _ => throw new InvalidOperationException("Invalid Aggregate Operator encountered."), + }; + var method = getMethod(); + method = method.GetGenericArguments().Length == 2 + ? method.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) + : method.MakeGenericMethod(typeof(ValueBuffer)); + inMemoryQueryExpression.ServerQueryExpression - = Expression.Call( - InMemoryLinqOperatorProvider - .GetAggregateMethod(methodName, selector.ReturnType, parameterCount: 1) - .MakeGenericMethod(typeof(ValueBuffer)), + = Expression.Call(method, inMemoryQueryExpression.ServerQueryExpression, selector); @@ -669,7 +786,7 @@ private ShapedQueryExpression TranslateSingleResultOperator( inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - method.MakeGenericMethod(typeof(ValueBuffer)), + method.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), inMemoryQueryExpression.ServerQueryExpression); inMemoryQueryExpression.ConvertToEnumerable(); diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs index b8d80ac7b1f..c2f8d1962f0 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs @@ -32,7 +32,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) _materializationContextBindings[parameterExpression] = ((IDictionary)GetProjectionIndex(queryExpression, projectionBindingExpression), - ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression).ValueBufferParameter); + ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression).CurrentParameter); var updatedExpression = Expression.New( newExpression.Constructor, @@ -79,7 +79,7 @@ protected override Expression VisitExtension(Expression extensionExpression) { var queryExpression = (InMemoryQueryExpression)projectionBindingExpression.QueryExpression; var projectionIndex = (int)GetProjectionIndex(queryExpression, projectionBindingExpression); - var valueBuffer = queryExpression.ValueBufferParameter; + var valueBuffer = queryExpression.CurrentParameter; return Expression.Call( EntityMaterializerSource.TryReadValueMethod.MakeGenericMethod(projectionBindingExpression.Type), diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.cs index 1cdaa15668c..a1efacff407 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.cs @@ -50,7 +50,7 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s var inMemoryQueryExpression = (InMemoryQueryExpression)shapedQueryExpression.QueryExpression; var shaper = new ShaperExpressionProcessingExpressionVisitor( - inMemoryQueryExpression, inMemoryQueryExpression.ValueBufferParameter) + inMemoryQueryExpression, inMemoryQueryExpression.CurrentParameter) .Inject(shapedQueryExpression.ShaperExpression); shaper = InjectEntityMaterializers(shaper); diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryTableExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryTableExpression.cs index 07f37fb029c..5318823b6e5 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryTableExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryTableExpression.cs @@ -27,5 +27,4 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) return this; } } - } diff --git a/src/EFCore/Query/GroupByShaperExpression.cs b/src/EFCore/Query/GroupByShaperExpression.cs index 14629ebf5dd..40a27f7a4f9 100644 --- a/src/EFCore/Query/GroupByShaperExpression.cs +++ b/src/EFCore/Query/GroupByShaperExpression.cs @@ -23,13 +23,13 @@ public GroupByShaperExpression(Expression keySelector, Expression elementSelecto public virtual void Print(ExpressionPrinter expressionPrinter) { - expressionPrinter.AppendLine("GroupBy("); + expressionPrinter.AppendLine($"{nameof(GroupByShaperExpression)}:"); expressionPrinter.Append("KeySelector: "); expressionPrinter.Visit(KeySelector); expressionPrinter.AppendLine(", "); expressionPrinter.Append("ElementSelector:"); expressionPrinter.Visit(ElementSelector); - expressionPrinter.AppendLine(")"); + expressionPrinter.AppendLine(); } protected override Expression VisitChildren(ExpressionVisitor visitor) diff --git a/src/EFCore/Query/QueryableMethods.cs b/src/EFCore/Query/QueryableMethods.cs index 75e9e2bdefd..fb4be67a623 100644 --- a/src/EFCore/Query/QueryableMethods.cs +++ b/src/EFCore/Query/QueryableMethods.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; diff --git a/test/EFCore.InMemory.FunctionalTests/InMemoryComplianceTest.cs b/test/EFCore.InMemory.FunctionalTests/InMemoryComplianceTest.cs index c47cbf14784..a9de08cd218 100644 --- a/test/EFCore.InMemory.FunctionalTests/InMemoryComplianceTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/InMemoryComplianceTest.cs @@ -20,7 +20,6 @@ public class InMemoryComplianceTest : ComplianceTestBase // Remaining Issue #16963 3.0 query tests: typeof(ComplexNavigationsWeakQueryTestBase<>), typeof(OwnedQueryTestBase<>), - typeof(GroupByQueryTestBase<>), typeof(ComplexNavigationsQueryTestBase<>), typeof(GearsOfWarQueryTestBase<>), typeof(SpatialQueryTestBase<>) diff --git a/test/EFCore.InMemory.FunctionalTests/Query/AsyncGearsOfWarQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/AsyncGearsOfWarQueryInMemoryTest.cs index 951c4c7d59e..564449e3c3f 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/AsyncGearsOfWarQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/AsyncGearsOfWarQueryInMemoryTest.cs @@ -1,8 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Threading.Tasks; -using Xunit; using Xunit.Abstractions; namespace Microsoft.EntityFrameworkCore.Query @@ -13,11 +11,5 @@ public AsyncGearsOfWarQueryInMemoryTest(GearsOfWarQueryInMemoryFixture fixture, : base(fixture) { } - - [ConditionalFact(Skip = "Issue#16963 Group By")] - public override Task GroupBy_Select_sum() - { - return base.GroupBy_Select_sum(); - } } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/GroupByQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/GroupByQueryInMemoryTest.cs index 666d6df2635..6224eb64770 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/GroupByQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/GroupByQueryInMemoryTest.cs @@ -8,7 +8,7 @@ namespace Microsoft.EntityFrameworkCore.Query { - internal class GroupByQueryInMemoryTest : GroupByQueryTestBase> + public class GroupByQueryInMemoryTest : GroupByQueryTestBase> { public GroupByQueryInMemoryTest( NorthwindQueryInMemoryFixture fixture, @@ -18,11 +18,5 @@ public GroupByQueryInMemoryTest( { //TestLoggerFactory.TestOutputHelper = testOutputHelper; } - - [ConditionalTheory(Skip = "See issue #9591")] - public override Task Select_Distinct_GroupBy(bool isAsync) - { - return base.Select_Distinct_GroupBy(isAsync); - } } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs index 2a8ff48b6cc..b4e6554fc34 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs @@ -150,7 +150,7 @@ public class Motor #region Bug3595 - [ConditionalFact(Skip = "Issue#16963 groupBy")] + [ConditionalFact] public void GroupBy_with_uninitialized_datetime_projection_3595() { using (CreateScratch(Seed3595, "3595")) diff --git a/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs index a0d76376ddc..3f97e8fda86 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs @@ -87,12 +87,6 @@ public override void Client_code_using_instance_method_throws() base.Client_code_using_instance_method_throws(); } - [ConditionalTheory(Skip = "Issue#16963 (GroupBy)")] - public override Task GroupBy_Select_Union(bool isAsync) - { - return Task.CompletedTask; - } - [ConditionalTheory(Skip = "Issue#17386")] public override Task Contains_with_local_tuple_array_closure(bool isAsync) { diff --git a/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs index 72b1dff0283..7410678ecb3 100644 --- a/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs @@ -982,9 +982,8 @@ public virtual void Cant_query_Min_of_converted_types() Assert.Equal( CoreStrings.TranslationFailed( - "Min(\n source: GroupBy(\n KeySelector: 1, \n ElementSelector:EntityShaperExpression: \n EntityType: BuiltInNullableDataTypes\n ValueBufferExpression: \n ProjectionBindingExpression: EmptyProjectionMember\n IsNullable: False\n )\n , \n selector: (e) => e.TestNullableDecimal)"), - ex.Message, - ignoreLineEndingDifferences: true); + "Min( source: GroupByShaperExpression: KeySelector: 1, ElementSelector:EntityShaperExpression: EntityType: BuiltInNullableDataTypes ValueBufferExpression: ProjectionBindingExpression: EmptyProjectionMember IsNullable: False , selector: (e) => e.TestNullableDecimal)"), + RemoveNewLines(ex.Message)); ex = Assert.Throws( () => query @@ -993,9 +992,8 @@ public virtual void Cant_query_Min_of_converted_types() Assert.Equal( CoreStrings.TranslationFailed( - "Min>(\n source: GroupBy(\n KeySelector: 1, \n ElementSelector:EntityShaperExpression: \n EntityType: BuiltInNullableDataTypes\n ValueBufferExpression: \n ProjectionBindingExpression: EmptyProjectionMember\n IsNullable: False\n )\n , \n selector: (e) => e.TestNullableDateTimeOffset)"), - ex.Message, - ignoreLineEndingDifferences: true); + "Min>( source: GroupByShaperExpression: KeySelector: 1, ElementSelector:EntityShaperExpression: EntityType: BuiltInNullableDataTypes ValueBufferExpression: ProjectionBindingExpression: EmptyProjectionMember IsNullable: False , selector: (e) => e.TestNullableDateTimeOffset)"), + RemoveNewLines(ex.Message)); ex = Assert.Throws( () => query @@ -1004,9 +1002,8 @@ public virtual void Cant_query_Min_of_converted_types() Assert.Equal( CoreStrings.TranslationFailed( - "Min>(\n source: GroupBy(\n KeySelector: 1, \n ElementSelector:EntityShaperExpression: \n EntityType: BuiltInNullableDataTypes\n ValueBufferExpression: \n ProjectionBindingExpression: EmptyProjectionMember\n IsNullable: False\n )\n , \n selector: (e) => e.TestNullableTimeSpan)"), - ex.Message, - ignoreLineEndingDifferences: true); + "Min>( source: GroupByShaperExpression: KeySelector: 1, ElementSelector:EntityShaperExpression: EntityType: BuiltInNullableDataTypes ValueBufferExpression: ProjectionBindingExpression: EmptyProjectionMember IsNullable: False , selector: (e) => e.TestNullableTimeSpan)"), + RemoveNewLines(ex.Message)); ex = Assert.Throws( () => query @@ -1015,9 +1012,8 @@ public virtual void Cant_query_Min_of_converted_types() Assert.Equal( CoreStrings.TranslationFailed( - "Min>(\n source: GroupBy(\n KeySelector: 1, \n ElementSelector:EntityShaperExpression: \n EntityType: BuiltInNullableDataTypes\n ValueBufferExpression: \n ProjectionBindingExpression: EmptyProjectionMember\n IsNullable: False\n )\n , \n selector: (e) => e.TestNullableUnsignedInt64)"), - ex.Message, - ignoreLineEndingDifferences: true); + "Min>( source: GroupByShaperExpression: KeySelector: 1, ElementSelector:EntityShaperExpression: EntityType: BuiltInNullableDataTypes ValueBufferExpression: ProjectionBindingExpression: EmptyProjectionMember IsNullable: False , selector: (e) => e.TestNullableUnsignedInt64)"), + RemoveNewLines(ex.Message)); } } @@ -1061,9 +1057,8 @@ public virtual void Cant_query_Max_of_converted_types() Assert.Equal( CoreStrings.TranslationFailed( - "Max(\n source: GroupBy(\n KeySelector: 1, \n ElementSelector:EntityShaperExpression: \n EntityType: BuiltInNullableDataTypes\n ValueBufferExpression: \n ProjectionBindingExpression: EmptyProjectionMember\n IsNullable: False\n )\n , \n selector: (e) => e.TestNullableDecimal)"), - ex.Message, - ignoreLineEndingDifferences: true); + "Max( source: GroupByShaperExpression: KeySelector: 1, ElementSelector:EntityShaperExpression: EntityType: BuiltInNullableDataTypes ValueBufferExpression: ProjectionBindingExpression: EmptyProjectionMember IsNullable: False , selector: (e) => e.TestNullableDecimal)"), + RemoveNewLines(ex.Message)); ex = Assert.Throws( () => query @@ -1072,9 +1067,8 @@ public virtual void Cant_query_Max_of_converted_types() Assert.Equal( CoreStrings.TranslationFailed( - "Max>(\n source: GroupBy(\n KeySelector: 1, \n ElementSelector:EntityShaperExpression: \n EntityType: BuiltInNullableDataTypes\n ValueBufferExpression: \n ProjectionBindingExpression: EmptyProjectionMember\n IsNullable: False\n )\n , \n selector: (e) => e.TestNullableDateTimeOffset)"), - ex.Message, - ignoreLineEndingDifferences: true); + "Max>( source: GroupByShaperExpression: KeySelector: 1, ElementSelector:EntityShaperExpression: EntityType: BuiltInNullableDataTypes ValueBufferExpression: ProjectionBindingExpression: EmptyProjectionMember IsNullable: False , selector: (e) => e.TestNullableDateTimeOffset)"), + RemoveNewLines(ex.Message)); ex = Assert.Throws( () => query @@ -1083,9 +1077,8 @@ public virtual void Cant_query_Max_of_converted_types() Assert.Equal( CoreStrings.TranslationFailed( - "Max>(\n source: GroupBy(\n KeySelector: 1, \n ElementSelector:EntityShaperExpression: \n EntityType: BuiltInNullableDataTypes\n ValueBufferExpression: \n ProjectionBindingExpression: EmptyProjectionMember\n IsNullable: False\n )\n , \n selector: (e) => e.TestNullableTimeSpan)"), - ex.Message, - ignoreLineEndingDifferences: true); + "Max>( source: GroupByShaperExpression: KeySelector: 1, ElementSelector:EntityShaperExpression: EntityType: BuiltInNullableDataTypes ValueBufferExpression: ProjectionBindingExpression: EmptyProjectionMember IsNullable: False , selector: (e) => e.TestNullableTimeSpan)"), + RemoveNewLines(ex.Message)); ex = Assert.Throws( () => query @@ -1094,9 +1087,8 @@ public virtual void Cant_query_Max_of_converted_types() Assert.Equal( CoreStrings.TranslationFailed( - "Max>(\n source: GroupBy(\n KeySelector: 1, \n ElementSelector:EntityShaperExpression: \n EntityType: BuiltInNullableDataTypes\n ValueBufferExpression: \n ProjectionBindingExpression: EmptyProjectionMember\n IsNullable: False\n )\n , \n selector: (e) => e.TestNullableUnsignedInt64)"), - ex.Message, - ignoreLineEndingDifferences: true); + "Max>( source: GroupByShaperExpression: KeySelector: 1, ElementSelector:EntityShaperExpression: EntityType: BuiltInNullableDataTypes ValueBufferExpression: ProjectionBindingExpression: EmptyProjectionMember IsNullable: False , selector: (e) => e.TestNullableUnsignedInt64)"), + RemoveNewLines(ex.Message)); } } diff --git a/test/EFCore.Tests/ChangeTracking/Internal/FixupTest.cs b/test/EFCore.Tests/ChangeTracking/Internal/FixupTest.cs index 28b9e8fed9c..ce5b4da7668 100644 --- a/test/EFCore.Tests/ChangeTracking/Internal/FixupTest.cs +++ b/test/EFCore.Tests/ChangeTracking/Internal/FixupTest.cs @@ -2418,7 +2418,7 @@ public void Replace_dependent_one_to_one_no_navs_FK_set(EntityState oldEntitySta } } - [ConditionalFact(Skip = "issue #16963 - using InMemory Include query")] // Issue #6067 + [ConditionalFact] // Issue #6067 public void Collection_nav_props_remain_fixed_up_after_manual_fixup_and_DetectChanges() { using (var context = new FixupContext())