From a328263a5e38f91c808ba3e637c3433f9eeddcd5 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 8 Aug 2019 10:52:21 -0700 Subject: [PATCH] Query: Treat FromSql as query root in nav expansion Resolves #16326 Filed #17036 for clean up --- .../NavigationExpandingExpressionVisitor.cs | 220 ++++++++++-------- src/EFCore/Query/QueryableMethodProvider.cs | 4 - 2 files changed, 128 insertions(+), 96 deletions(-) diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index 0f435528c83..3238ac9b61d 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -144,14 +144,14 @@ protected override Expression VisitMember(MemberExpression memberExpression) protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { - if (methodCallExpression.Method.DeclaringType == typeof(Queryable) - || methodCallExpression.Method.DeclaringType == typeof(QueryableExtensions) - || methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)) + var method = methodCallExpression.Method; + if (method.DeclaringType == typeof(Queryable) + || method.DeclaringType == typeof(QueryableExtensions) + || method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)) { var firstArgument = Visit(methodCallExpression.Arguments[0]); if (firstArgument is NavigationExpansionExpression source) { - var method = methodCallExpression.Method; var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; if (source.PendingOrderings.Any() @@ -161,7 +161,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp ApplyPendingOrderings(source); } - switch (methodCallExpression.Method.Name) + switch (method.Name) { case nameof(Queryable.AsQueryable) when genericMethod == QueryableMethodProvider.AsQueryableMethodInfo: @@ -194,36 +194,43 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); case nameof(Queryable.Average) - when QueryableMethodProvider.IsAverageMethodInfo(method): + when QueryableMethodProvider.IsAverageWithoutSelectorMethodInfo(method): case nameof(Queryable.Sum) - when QueryableMethodProvider.IsSumMethodInfo(method): + when QueryableMethodProvider.IsSumWithoutSelectorMethodInfo(method): case nameof(Queryable.Max) - when genericMethod == QueryableMethodProvider.MaxWithoutSelectorMethodInfo - || genericMethod == QueryableMethodProvider.MaxWithSelectorMethodInfo: + when genericMethod == QueryableMethodProvider.MaxWithoutSelectorMethodInfo: case nameof(Queryable.Min) - when genericMethod == QueryableMethodProvider.MinWithoutSelectorMethodInfo - || genericMethod == QueryableMethodProvider.MinWithSelectorMethodInfo: + when genericMethod == QueryableMethodProvider.MinWithoutSelectorMethodInfo: return ProcessAverageMaxMinSum( source, - methodCallExpression.Method.IsGenericMethod - ? methodCallExpression.Method.GetGenericMethodDefinition() - : methodCallExpression.Method, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null); + genericMethod ?? method, + null); + + case nameof(Queryable.Average) + when QueryableMethodProvider.IsAverageWithSelectorMethodInfo(method): + case nameof(Queryable.Sum) + when QueryableMethodProvider.IsSumWithSelectorMethodInfo(method): + case nameof(Queryable.Max) + when genericMethod == QueryableMethodProvider.MaxWithSelectorMethodInfo: + case nameof(Queryable.Min) + when genericMethod == QueryableMethodProvider.MinWithSelectorMethodInfo: + return ProcessAverageMaxMinSum( + source, + genericMethod ?? method, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); case nameof(Queryable.Distinct) when genericMethod == QueryableMethodProvider.DistinctMethodInfo: + return ProcessDistinctSkipTake(source, genericMethod, null); + case nameof(Queryable.Skip) when genericMethod == QueryableMethodProvider.SkipMethodInfo: case nameof(Queryable.Take) when genericMethod == QueryableMethodProvider.TakeMethodInfo: return ProcessDistinctSkipTake( source, - methodCallExpression.Method.GetGenericMethodDefinition(), - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1] - : null); + genericMethod, + methodCallExpression.Arguments[1]); case nameof(Queryable.Contains) when genericMethod == QueryableMethodProvider.ContainsMethodInfo: @@ -232,73 +239,88 @@ when QueryableMethodProvider.IsSumMethodInfo(method): methodCallExpression.Arguments[1]); case nameof(Queryable.First) - when genericMethod == QueryableMethodProvider.FirstWithoutPredicateMethodInfo - || genericMethod == QueryableMethodProvider.FirstWithPredicateMethodInfo: + when genericMethod == QueryableMethodProvider.FirstWithoutPredicateMethodInfo: case nameof(Queryable.FirstOrDefault) - when genericMethod == QueryableMethodProvider.FirstOrDefaultWithoutPredicateMethodInfo - || genericMethod == QueryableMethodProvider.FirstOrDefaultWithPredicateMethodInfo: + when genericMethod == QueryableMethodProvider.FirstOrDefaultWithoutPredicateMethodInfo: case nameof(Queryable.Single) - when genericMethod == QueryableMethodProvider.SingleWithoutPredicateMethodInfo - || genericMethod == QueryableMethodProvider.SingleWithPredicateMethodInfo: + when genericMethod == QueryableMethodProvider.SingleWithoutPredicateMethodInfo: case nameof(Queryable.SingleOrDefault) - when genericMethod == QueryableMethodProvider.SingleOrDefaultWithoutPredicateMethodInfo - || genericMethod == QueryableMethodProvider.SingleOrDefaultWithPredicateMethodInfo: + when genericMethod == QueryableMethodProvider.SingleOrDefaultWithoutPredicateMethodInfo: case nameof(Queryable.Last) - when genericMethod == QueryableMethodProvider.LastWithoutPredicateMethodInfo - || genericMethod == QueryableMethodProvider.LastWithPredicateMethodInfo: + when genericMethod == QueryableMethodProvider.LastWithoutPredicateMethodInfo: case nameof(Queryable.LastOrDefault) - when genericMethod == QueryableMethodProvider.LastOrDefaultWithoutPredicateMethodInfo - || genericMethod == QueryableMethodProvider.LastOrDefaultWithPredicateMethodInfo: + when genericMethod == QueryableMethodProvider.LastOrDefaultWithoutPredicateMethodInfo: return ProcessFirstSingleLastOrDefault( source, - methodCallExpression.Method.GetGenericMethodDefinition(), - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, + genericMethod, + null, + methodCallExpression.Type); + + case nameof(Queryable.First) + when genericMethod == QueryableMethodProvider.FirstWithPredicateMethodInfo: + case nameof(Queryable.FirstOrDefault) + when genericMethod == QueryableMethodProvider.FirstOrDefaultWithPredicateMethodInfo: + case nameof(Queryable.Single) + when genericMethod == QueryableMethodProvider.SingleWithPredicateMethodInfo: + case nameof(Queryable.SingleOrDefault) + when genericMethod == QueryableMethodProvider.SingleOrDefaultWithPredicateMethodInfo: + case nameof(Queryable.Last) + when genericMethod == QueryableMethodProvider.LastWithPredicateMethodInfo: + case nameof(Queryable.LastOrDefault) + when genericMethod == QueryableMethodProvider.LastOrDefaultWithPredicateMethodInfo: + return ProcessFirstSingleLastOrDefault( + source, + genericMethod, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), methodCallExpression.Type); case nameof(Queryable.Join) when genericMethod == QueryableMethodProvider.JoinMethodInfo: - { - var secondArgument = Visit(methodCallExpression.Arguments[1]); - if (secondArgument is NavigationExpansionExpression innerSource) { - return ProcessJoin( - source, - innerSource, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + var secondArgument = Visit(methodCallExpression.Arguments[1]); + if (secondArgument is NavigationExpansionExpression innerSource) + { + return ProcessJoin( + source, + innerSource, + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + } + break; } - break; - } case nameof(QueryableExtensions.LeftJoin) when genericMethod == QueryableExtensions.LeftJoinMethodInfo: - { - var secondArgument = Visit(methodCallExpression.Arguments[1]); - if (secondArgument is NavigationExpansionExpression innerSource) { - return ProcessLeftJoin( - source, - innerSource, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + var secondArgument = Visit(methodCallExpression.Arguments[1]); + if (secondArgument is NavigationExpansionExpression innerSource) + { + return ProcessLeftJoin( + source, + innerSource, + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + } + break; } - break; - } case nameof(Queryable.SelectMany) - when genericMethod == QueryableMethodProvider.SelectManyWithoutCollectionSelectorMethodInfo - || genericMethod == QueryableMethodProvider.SelectManyWithCollectionSelectorMethodInfo: + when genericMethod == QueryableMethodProvider.SelectManyWithoutCollectionSelectorMethodInfo: return ProcessSelectMany( source, - methodCallExpression.Method.GetGenericMethodDefinition(), + genericMethod, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments.Count == 3 - ? methodCallExpression.Arguments[2].UnwrapLambdaFromQuote() - : null); + null); + + case nameof(Queryable.SelectMany) + when genericMethod == QueryableMethodProvider.SelectManyWithCollectionSelectorMethodInfo: + return ProcessSelectMany( + source, + genericMethod, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote()); case nameof(Queryable.Concat) when genericMethod == QueryableMethodProvider.ConcatMethodInfo: @@ -308,17 +330,17 @@ when QueryableMethodProvider.IsSumMethodInfo(method): when genericMethod == QueryableMethodProvider.IntersectMethodInfo: case nameof(Queryable.Union) when genericMethod == QueryableMethodProvider.UnionMethodInfo: - { - var secondArgument = Visit(methodCallExpression.Arguments[1]); - if (secondArgument is NavigationExpansionExpression innerSource) { - return ProcessSetOperation( - source, - methodCallExpression.Method.GetGenericMethodDefinition(), - innerSource); + var secondArgument = Visit(methodCallExpression.Arguments[1]); + if (secondArgument is NavigationExpansionExpression innerSource) + { + return ProcessSetOperation( + source, + genericMethod, + innerSource); + } + break; } - break; - } case nameof(Queryable.Cast) when genericMethod == QueryableMethodProvider.CastMethodInfo: @@ -326,7 +348,7 @@ when QueryableMethodProvider.IsSumMethodInfo(method): when genericMethod == QueryableMethodProvider.OfTypeMethodInfo: return ProcessCastOfType( source, - methodCallExpression.Method.GetGenericMethodDefinition(), + genericMethod, methodCallExpression.Type.TryGetSequenceType()); case nameof(EntityFrameworkQueryableExtensions.Include): @@ -334,7 +356,7 @@ when QueryableMethodProvider.IsSumMethodInfo(method): return ProcessInclude( source, methodCallExpression.Arguments[1], - string.Equals(methodCallExpression.Method.Name, + string.Equals(method.Name, nameof(EntityFrameworkQueryableExtensions.ThenInclude))); case nameof(Queryable.GroupBy) @@ -375,7 +397,7 @@ when QueryableMethodProvider.IsSumMethodInfo(method): when genericMethod == QueryableMethodProvider.OrderByDescendingMethodInfo: return ProcessOrderByThenBy( source, - methodCallExpression.Method.GetGenericMethodDefinition(), + genericMethod, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false); @@ -385,7 +407,7 @@ when QueryableMethodProvider.IsSumMethodInfo(method): when genericMethod == QueryableMethodProvider.ThenByDescendingMethodInfo: return ProcessOrderByThenBy( source, - methodCallExpression.Method.GetGenericMethodDefinition(), + genericMethod, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true); @@ -402,27 +424,23 @@ when QueryableMethodProvider.IsSumMethodInfo(method): methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); default: - throw new NotImplementedException("Unhandled method in navigation expansion:" + - $"{methodCallExpression.Method.Name}"); + throw new NotImplementedException($"Unhandled method in navigation expansion: {method.Name}"); } } else if (firstArgument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression - && methodCallExpression.Method.Name == nameof(Queryable.AsQueryable)) + && method.Name == nameof(Queryable.AsQueryable)) { var subquery = materializeCollectionNavigationExpression.Subquery; - if (subquery is OwnedNavigationReference ownedNavigationReference - && ownedNavigationReference.Navigation.IsCollection()) - { - return Visit(Expression.Call( + return subquery is OwnedNavigationReference ownedNavigationReference + && ownedNavigationReference.Navigation.IsCollection() + ? Visit(Expression.Call( QueryableMethodProvider.AsQueryableMethodInfo.MakeGenericMethod(subquery.Type.TryGetSequenceType()), - subquery)); - } - - return subquery; + subquery)) + : subquery; } else if (firstArgument is OwnedNavigationReference ownedNavigationReference && ownedNavigationReference.Navigation.IsCollection() - && methodCallExpression.Method.Name == nameof(Queryable.AsQueryable)) + && method.Name == nameof(Queryable.AsQueryable)) { var parameterName = GetParameterName("o"); var entityReference = ownedNavigationReference.EntityReference; @@ -434,8 +452,8 @@ when QueryableMethodProvider.IsSumMethodInfo(method): throw new NotImplementedException("NonNavSource"); } - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == _enumerableToListMethodInfo) + if (method.IsGenericMethod + && method.GetGenericMethodDefinition() == _enumerableToListMethodInfo) { var argument = Visit(methodCallExpression.Arguments[0]); if (argument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression) @@ -446,6 +464,24 @@ when QueryableMethodProvider.IsSumMethodInfo(method): return methodCallExpression.Update(null, new[] { argument }); } + if (method.IsGenericMethod + && method.Name == "FromSqlOnQueryable" + && methodCallExpression.Arguments.Count == 3 + && methodCallExpression.Arguments[0] is ConstantExpression constantExpression + && methodCallExpression.Arguments[1] is ConstantExpression + && methodCallExpression.Arguments[2] is ParameterExpression + && constantExpression.IsEntityQueryable()) + { + var source = (NavigationExpansionExpression)Visit(constantExpression); + + source.UpdateSource( + methodCallExpression.Update( + null, + new[] { source.Source, methodCallExpression.Arguments[1], methodCallExpression.Arguments[2] })); + + return source; + } + return ProcessUnknownMethod(methodCallExpression); } diff --git a/src/EFCore/Query/QueryableMethodProvider.cs b/src/EFCore/Query/QueryableMethodProvider.cs index 3fc326ef115..81d2a1990ce 100644 --- a/src/EFCore/Query/QueryableMethodProvider.cs +++ b/src/EFCore/Query/QueryableMethodProvider.cs @@ -90,16 +90,12 @@ public static bool IsSumWithoutSelectorMethodInfo(MethodInfo methodInfo) public static bool IsSumWithSelectorMethodInfo(MethodInfo methodInfo) => methodInfo.IsGenericMethod && SumWithSelectorMethodInfos.Values.Contains(methodInfo.GetGenericMethodDefinition()); - public static bool IsSumMethodInfo(MethodInfo methodInfo) - => IsSumWithoutSelectorMethodInfo(methodInfo) || IsSumWithSelectorMethodInfo(methodInfo); public static bool IsAverageWithoutSelectorMethodInfo(MethodInfo methodInfo) => AverageWithoutSelectorMethodInfos.Values.Contains(methodInfo); public static bool IsAverageWithSelectorMethodInfo(MethodInfo methodInfo) => methodInfo.IsGenericMethod && AverageWithSelectorMethodInfos.Values.Contains(methodInfo.GetGenericMethodDefinition()); - public static bool IsAverageMethodInfo(MethodInfo methodInfo) - => IsAverageWithoutSelectorMethodInfo(methodInfo) || IsAverageWithSelectorMethodInfo(methodInfo); static QueryableMethodProvider() {