From 7bb6d10c1e7ec8bc7c2cece56f5c61be386bcf51 Mon Sep 17 00:00:00 2001 From: Maurycy Markowski Date: Tue, 4 Jun 2019 17:19:05 -0700 Subject: [PATCH] Fix to #15763 - "collection selector was not NavigationExpansionExpression" when joining two FromSql()s Problem was that navigation expansion wasn't recognizing FromSqlOnQueryable as something that needs to be visited, so the expression was not being converted to NavigationExpansionExpression. SelectMany requires collection navigation to be a NavigationExpansionExpression, hence the error. Fix is to add generic handling of simple queryable methods. --- .../Internal/ExpressionExtensions.cs | 71 +++++++++++++++++- .../ExpressionExtensions.cs | 75 ------------------- .../NavigationExpansionHelpers.cs | 2 +- .../NavigationExpansionRootExpression.cs | 1 + .../CollectionNavigationRewritingVisitor.cs | 1 + .../NavigationExpandingVisitor_MethodCall.cs | 57 ++++++++------ .../NavigationExpansionCleanupVisitor.cs | 1 + .../NavigationExpansionReducingVisitor.cs | 1 + .../NavigationPropertyUnbindingVisitor.cs | 1 + ...ntityEqualityRewritingExpressionVisitor.cs | 11 +-- .../Query/AsyncFromSqlQueryTestBase.cs | 6 +- .../Query/FromSqlQueryTestBase.cs | 8 +- .../Query/AsNoTrackingTestBase.cs | 11 +++ 13 files changed, 131 insertions(+), 115 deletions(-) delete mode 100644 src/EFCore/Query/NavigationExpansion/ExpressionExtensions.cs diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs index 84979a6261e..8d9902764af 100644 --- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs +++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs @@ -10,9 +10,7 @@ using System.Runtime.Versioning; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; -using Microsoft.EntityFrameworkCore.Extensions.Internal; using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query.Expressions.Internal; using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Utilities; @@ -363,5 +361,74 @@ private static readonly Type _assignBinaryExpressionType private static readonly MethodInfo _fieldInfoSetValueMethod = typeof(FieldInfo).GetRuntimeMethod(nameof(FieldInfo.SetValue), new[] { typeof(object), typeof(object) }); + + public static LambdaExpression GetLambdaOrNull(this Expression expression) + => expression is LambdaExpression lambda + ? lambda + : expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote + ? (LambdaExpression)unary.Operand + : null; + + public static LambdaExpression UnwrapQuote(this Expression expression) + => expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote + ? (LambdaExpression)unary.Operand + : (LambdaExpression)expression; + + public static bool IsIncludeMethod(this MethodCallExpression methodCallExpression) + => methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + && methodCallExpression.Method.Name == nameof(EntityFrameworkQueryableExtensions.Include); + + public static Expression BuildPropertyAccess(this Expression root, List path) + { + var result = root; + foreach (var pathElement in path) + { + result = Expression.PropertyOrField(result, pathElement); + } + + return result; + } + + public static Expression CombineAndRemap( + Expression source, + ParameterExpression sourceParameter, + Expression replaceWith) + => new ExpressionCombiningVisitor(sourceParameter, replaceWith).Visit(source); + + public class ExpressionCombiningVisitor : ExpressionVisitor + { + private ParameterExpression _sourceParameter; + private Expression _replaceWith; + + public ExpressionCombiningVisitor( + ParameterExpression sourceParameter, + Expression replaceWith) + { + _sourceParameter = sourceParameter; + _replaceWith = replaceWith; + } + + protected override Expression VisitParameter(ParameterExpression parameterExpression) + => parameterExpression == _sourceParameter + ? _replaceWith + : base.VisitParameter(parameterExpression); + + protected override Expression VisitMember(MemberExpression memberExpression) + { + var newSource = Visit(memberExpression.Expression); + if (newSource is NewExpression newExpression) + { + var matchingMemberIndex = newExpression.Members.Select((m, i) => new { index = i, match = m == memberExpression.Member }).Where(r => r.match).SingleOrDefault()?.index; + if (matchingMemberIndex.HasValue) + { + return newExpression.Arguments[matchingMemberIndex.Value]; + } + } + + return newSource != memberExpression.Expression + ? memberExpression.Update(newSource) + : memberExpression; + } + } } } diff --git a/src/EFCore/Query/NavigationExpansion/ExpressionExtensions.cs b/src/EFCore/Query/NavigationExpansion/ExpressionExtensions.cs deleted file mode 100644 index bc4e154a0c2..00000000000 --- a/src/EFCore/Query/NavigationExpansion/ExpressionExtensions.cs +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.Collections.Generic; -using System.Linq; -using System.Linq.Expressions; - -namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion -{ - public static class ExpressionExtensions - { - public static LambdaExpression UnwrapQuote(this Expression expression) - => expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote - ? (LambdaExpression)unary.Operand - : (LambdaExpression)expression; - - public static bool IsIncludeMethod(this MethodCallExpression methodCallExpression) - => methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) - && methodCallExpression.Method.Name == nameof(EntityFrameworkQueryableExtensions.Include); - - public static Expression BuildPropertyAccess(this Expression root, List path) - { - var result = root; - foreach (var pathElement in path) - { - result = Expression.PropertyOrField(result, pathElement); - } - - return result; - } - - public static Expression CombineAndRemap( - Expression source, - ParameterExpression sourceParameter, - Expression replaceWith) - => new ExpressionCombiningVisitor(sourceParameter, replaceWith).Visit(source); - - public class ExpressionCombiningVisitor : ExpressionVisitor - { - private ParameterExpression _sourceParameter; - private Expression _replaceWith; - - public ExpressionCombiningVisitor( - ParameterExpression sourceParameter, - Expression replaceWith) - { - _sourceParameter = sourceParameter; - _replaceWith = replaceWith; - } - - protected override Expression VisitParameter(ParameterExpression parameterExpression) - => parameterExpression == _sourceParameter - ? _replaceWith - : base.VisitParameter(parameterExpression); - - protected override Expression VisitMember(MemberExpression memberExpression) - { - var newSource = Visit(memberExpression.Expression); - if (newSource is NewExpression newExpression) - { - var matchingMemberIndex = newExpression.Members.Select((m, i) => new { index = i, match = m == memberExpression.Member }).Where(r => r.match).SingleOrDefault()?.index; - if (matchingMemberIndex.HasValue) - { - return newExpression.Arguments[matchingMemberIndex.Value]; - } - } - - return newSource != memberExpression.Expression - ? memberExpression.Update(newSource) - : memberExpression; - } - } - } - -} diff --git a/src/EFCore/Query/NavigationExpansion/NavigationExpansionHelpers.cs b/src/EFCore/Query/NavigationExpansion/NavigationExpansionHelpers.cs index 67e21288a7a..ea9d6d99c08 100644 --- a/src/EFCore/Query/NavigationExpansion/NavigationExpansionHelpers.cs +++ b/src/EFCore/Query/NavigationExpansion/NavigationExpansionHelpers.cs @@ -6,8 +6,8 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; -using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Extensions.Internal; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.Expressions.Internal; diff --git a/src/EFCore/Query/NavigationExpansion/NavigationExpansionRootExpression.cs b/src/EFCore/Query/NavigationExpansion/NavigationExpansionRootExpression.cs index 47a3bb12ed4..22311512a96 100644 --- a/src/EFCore/Query/NavigationExpansion/NavigationExpansionRootExpression.cs +++ b/src/EFCore/Query/NavigationExpansion/NavigationExpansionRootExpression.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Query.Expressions.Internal; using Microsoft.EntityFrameworkCore.Query.Internal; diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/CollectionNavigationRewritingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/CollectionNavigationRewritingVisitor.cs index 0810077ebf2..9f3f0d236cb 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/CollectionNavigationRewritingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/CollectionNavigationRewritingVisitor.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Extensions.Internal; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query.Internal; diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs index f6f070fa20b..5920c215cc1 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs @@ -78,10 +78,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.DefaultIfEmpty): return ProcessDefaultIfEmpty(methodCallExpression); - case "AsTracking": - case "AsNoTracking": - return ProcessBasicTerminatingOperation(methodCallExpression); - case nameof(Queryable.First): case nameof(Queryable.FirstOrDefault): case nameof(Queryable.Single): @@ -99,18 +95,48 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case "ThenInclude": return ProcessInclude(methodCallExpression); - //TODO: should we have relational version of this? - probably - case "FromSqlRaw": - return ProcessFromRawSql(methodCallExpression); - case nameof(EntityFrameworkQueryableExtensions.TagWith): return ProcessWithTag(methodCallExpression); default: - return base.VisitMethodCall(methodCallExpression); + return ProcessUnknownMethod(methodCallExpression); + } + } + + private Expression ProcessUnknownMethod(MethodCallExpression methodCallExpression) + { + var resultSequenceType = TryGetNonPrimitiveSequenceType(methodCallExpression.Type); + + // result is a sequence, no lambda arguments, exactly one generic argument corresponding to result sequence type + if (methodCallExpression.Object == null + && resultSequenceType != null + && methodCallExpression.Arguments.All(a => a.GetLambdaOrNull() == null) + && methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericArguments().Length == 1 + && methodCallExpression.Method.GetGenericArguments()[0] == resultSequenceType) + { + var argumentSequenceTypes = methodCallExpression.Arguments.Select(a => TryGetNonPrimitiveSequenceType(a.Type)).ToList(); + if (argumentSequenceTypes.FirstOrDefault() == resultSequenceType + && argumentSequenceTypes.Count(t => t != null) == 1) + { + var source = VisitSourceExpression(methodCallExpression.Arguments[0]); + var preProcessResult = PreProcessTerminatingOperation(source); + var newArguments = methodCallExpression.Arguments.Skip(1).Select(Visit).ToList(); + newArguments.Insert(0, preProcessResult.source); + + var methodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod(preProcessResult.state.CurrentParameter.Type); + var rewritten = Expression.Call(methodInfo, newArguments); + + return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); + } } + + return base.VisitMethodCall(methodCallExpression); } + private Type TryGetNonPrimitiveSequenceType(Type type) + => type == typeof(string) || type.IsArray ? null : type.TryGetSequenceType(); + private NavigationExpansionExpression VisitSourceExpression(Expression sourceExpression) { var result = Visit(sourceExpression); @@ -845,17 +871,6 @@ private Expression ProcessSkipTake(MethodCallExpression methodCallExpression) return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); } - private Expression ProcessBasicTerminatingOperation(MethodCallExpression methodCallExpression) - { - var source = VisitSourceExpression(methodCallExpression.Arguments[0]); - var preProcessResult = PreProcessTerminatingOperation(source); - var newArguments = methodCallExpression.Arguments.Skip(1).ToList(); - newArguments.Insert(0, preProcessResult.source); - var rewritten = methodCallExpression.Update(methodCallExpression.Object, newArguments); - - return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); - } - private (Expression source, NavigationExpansionExpressionState state) PreProcessTerminatingOperation(NavigationExpansionExpression source) { var applyOrderingsResult = ApplyPendingOrderings(source.Operand, source.State); @@ -1172,7 +1187,7 @@ private Expression ProcessCardinalityReducingOperation(MethodCallExpression meth return new NavigationExpansionExpression(applyOrderingsResult.source, applyOrderingsResult.state, methodCallExpression.Type); } - private Expression ProcessFromRawSql(MethodCallExpression methodCallExpression) + private Expression ProcessFromSql(MethodCallExpression methodCallExpression) { var source = VisitSourceExpression(methodCallExpression.Arguments[0]); diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionCleanupVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionCleanupVisitor.cs index 8cb510090af..2baa69c62c1 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionCleanupVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionCleanupVisitor.cs @@ -4,6 +4,7 @@ using System; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Extensions.Internal; +using Microsoft.EntityFrameworkCore.Internal; namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors { diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs index 4bb95e6265c..c0f601d449d 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationPropertyUnbindingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationPropertyUnbindingVisitor.cs index ddce8dd1e3d..e479f3679df 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationPropertyUnbindingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationPropertyUnbindingVisitor.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Internal; namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors { diff --git a/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs index 4620a386d92..2867ba5c283 100644 --- a/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs @@ -183,7 +183,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // Methods with a typed first argument (source), and with no lambda arguments or a single lambda // argument that has one parameter are rewritten automatically (e.g. Where(), FromSql(), Average() var newArguments = new Expression[arguments.Count]; - var lambdaArgs = arguments.Select(GetLambdaOrNull).Where(l => l != null).ToArray(); + var lambdaArgs = arguments.Select(a => a.GetLambdaOrNull()).Where(l => l != null).ToArray(); newSource = Visit(arguments[0]); newArguments[0] = Unwrap(newSource); if (methodCallExpression.Object == null @@ -194,7 +194,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp for (var i = 1; i < arguments.Count; i++) { // Visit all arguments, rewriting the single lambda to replace its parameter expression - newArguments[i] = GetLambdaOrNull(arguments[i]) is LambdaExpression lambda + newArguments[i] = arguments[i].GetLambdaOrNull() is LambdaExpression lambda ? Unwrap(RewriteAndVisitLambda(lambda, newSourceWrapper)) : Unwrap(Visit(arguments[i])); } @@ -593,13 +593,6 @@ protected static Expression UnwrapLastNavigation(Expression expression) ? methodCallExpression.Arguments[0] : null); - protected static LambdaExpression GetLambdaOrNull(Expression expression) - => expression is LambdaExpression lambda - ? lambda - : expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote - ? (LambdaExpression)unary.Operand - : null; - protected static Expression Unwrap(Expression expression) => expression switch { EntityReferenceExpression wrapper => wrapper.Underlying, diff --git a/test/EFCore.Relational.Specification.Tests/Query/AsyncFromSqlQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/AsyncFromSqlQueryTestBase.cs index 8a692e41b48..8d14b175364 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/AsyncFromSqlQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/AsyncFromSqlQueryTestBase.cs @@ -81,7 +81,7 @@ public virtual async Task FromSqlRaw_queryable_composed() } } - [Fact(Skip = "#15763")] + [Fact] public virtual async Task FromSqlRaw_queryable_multiple_composed() { using (var context = CreateContext()) @@ -101,7 +101,7 @@ from o in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT } } - [Fact(Skip = "Issue#15763")] + [Fact] public virtual async Task FromSqlRaw_queryable_multiple_composed_with_closure_parameters() { var startDate = new DateTime(1997, 1, 1); @@ -125,7 +125,7 @@ from o in context.Set().FromSqlRaw( } } - [Fact(Skip = "Issue#15763")] + [Fact] public virtual async Task FromSqlRaw_queryable_multiple_composed_with_parameters_and_closure_parameters() { var city = "London"; diff --git a/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs index 64cf4276663..4517beff761 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs @@ -305,7 +305,7 @@ var actual } } - [Fact(Skip = "#15763")] + [Fact] public virtual void FromSqlRaw_queryable_multiple_composed() { using (var context = CreateContext()) @@ -325,7 +325,7 @@ from o in context.Set().FromSqlRaw(NormalizeDelimetersInRawString("SELECT } } - [Fact(Skip = "Issue#15763")] + [Fact] public virtual void FromSqlRaw_queryable_multiple_composed_with_closure_parameters() { var startDate = new DateTime(1997, 1, 1); @@ -351,7 +351,7 @@ from o in context.Set().FromSqlRaw( } } - [Fact(Skip = "Issue#15763")] + [Fact] public virtual void FromSqlRaw_queryable_multiple_composed_with_parameters_and_closure_parameters() { var city = "London"; @@ -503,7 +503,7 @@ public virtual void FromSqlInterpolated_queryable_with_parameters_inline_interpo } } - [Fact(Skip = "Issue#15763")] + [Fact] public virtual void FromSqlInterpolated_queryable_multiple_composed_with_parameters_and_closure_parameters_interpolated() { var city = "London"; diff --git a/test/EFCore.Specification.Tests/Query/AsNoTrackingTestBase.cs b/test/EFCore.Specification.Tests/Query/AsNoTrackingTestBase.cs index 7649dec827a..82fe071f1d0 100644 --- a/test/EFCore.Specification.Tests/Query/AsNoTrackingTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/AsNoTrackingTestBase.cs @@ -152,6 +152,17 @@ var orders } } + [ConditionalFact] + public virtual void Applied_after_navigation_expansion() + { + using (var context = CreateContext()) + { + var orders = context.Set().Where(o => o.Customer.City != "London").AsNoTracking().ToList(); + + Assert.Equal(784, orders.Count); + } + } + [ConditionalFact] public virtual void Where_simple_shadow() {