From 18ae848b81da9d5ab8db76b98db07b78eb033c0c Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Tue, 11 Feb 2020 16:07:51 -0800 Subject: [PATCH] Query: Combine queryable method processing expression visitors Introduces QueryableMethodNormalizingExpressionVisitor which - Extract query metadata methods - Convert from enumerable to queryable - Convert List.Contains to queryable Contains - Flatten GroupJoin-SelectMany Nav expansion now calls this method on query filter/ defining query Resolves #19708 Part of #18923 --- ...ryableMethodConvertingExpressionVisitor.cs | 206 ------ .../GroupJoinFlatteningExpressionVisitor.cs | 367 ----------- .../NavigationExpandingExpressionVisitor.cs | 10 +- ...ueryMetadataExtractingExpressionVisitor.cs | 60 -- ...yableMethodNormalizingExpressionVisitor.cs | 616 ++++++++++++++++++ .../Query/QueryTranslationPreprocessor.cs | 12 +- .../Query/QueryBugsInMemoryTest.cs | 104 +++ .../Query/QueryBugsTest.cs | 148 +++++ 8 files changed, 881 insertions(+), 642 deletions(-) delete mode 100644 src/EFCore/Query/Internal/EnumerableToQueryableMethodConvertingExpressionVisitor.cs delete mode 100644 src/EFCore/Query/Internal/GroupJoinFlatteningExpressionVisitor.cs delete mode 100644 src/EFCore/Query/Internal/QueryMetadataExtractingExpressionVisitor.cs create mode 100644 src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs diff --git a/src/EFCore/Query/Internal/EnumerableToQueryableMethodConvertingExpressionVisitor.cs b/src/EFCore/Query/Internal/EnumerableToQueryableMethodConvertingExpressionVisitor.cs deleted file mode 100644 index 001c34a676f..00000000000 --- a/src/EFCore/Query/Internal/EnumerableToQueryableMethodConvertingExpressionVisitor.cs +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using Microsoft.EntityFrameworkCore.Utilities; - -namespace Microsoft.EntityFrameworkCore.Query.Internal -{ - public class EnumerableToQueryableMethodConvertingExpressionVisitor : ExpressionVisitor - { - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - if (methodCallExpression.Method.DeclaringType == typeof(Enumerable)) - { - if (methodCallExpression.Method.Name == nameof(Enumerable.SequenceEqual)) - { - // Skip SequenceEqual over enumerable since it could be over byte[] or other array properties - // Ideally we could make check in nav expansion about it (since it can bind to property) - // But since we don't translate SequenceEqual anyway, this is fine for now. - return base.VisitMethodCall(methodCallExpression); - } - - if (methodCallExpression.Arguments.Count > 0 - && ClientSource(methodCallExpression.Arguments[0])) - { - // this is methodCall over closure variable or constant - return base.VisitMethodCall(methodCallExpression); - } - - var arguments = VisitAndConvert(methodCallExpression.Arguments, nameof(VisitMethodCall)).ToArray(); - - var enumerableMethod = methodCallExpression.Method; - var enumerableParameters = enumerableMethod.GetParameters(); - Type[] genericTypeArguments = null; - if (enumerableMethod.Name == nameof(Enumerable.Min) - || enumerableMethod.Name == nameof(Enumerable.Max)) - { - genericTypeArguments = new Type[methodCallExpression.Arguments.Count]; - - if (!enumerableMethod.IsGenericMethod) - { - genericTypeArguments[0] = enumerableMethod.ReturnType; - } - else - { - var argumentTypes = enumerableMethod.GetGenericArguments(); - if (argumentTypes.Length == genericTypeArguments.Length) - { - genericTypeArguments = argumentTypes; - } - else - { - genericTypeArguments[0] = argumentTypes[0]; - genericTypeArguments[1] = enumerableMethod.ReturnType; - } - } - } - else if (enumerableMethod.IsGenericMethod) - { - genericTypeArguments = enumerableMethod.GetGenericArguments(); - } - - foreach (var method in typeof(Queryable).GetTypeInfo().GetDeclaredMethods(methodCallExpression.Method.Name)) - { - var queryableMethod = method; - if (queryableMethod.IsGenericMethod) - { - if (genericTypeArguments != null - && queryableMethod.GetGenericArguments().Length == genericTypeArguments.Length) - { - queryableMethod = queryableMethod.MakeGenericMethod(genericTypeArguments); - } - else - { - continue; - } - } - - var queryableParameters = queryableMethod.GetParameters(); - if (enumerableParameters.Length != queryableParameters.Length) - { - continue; - } - - var validMapping = true; - for (var i = 0; i < enumerableParameters.Length; i++) - { - var enumerableParameterType = enumerableParameters[i].ParameterType; - var queryableParameterType = queryableParameters[i].ParameterType; - - if (enumerableParameterType == queryableParameterType) - { - continue; - } - - if (CanConvertEnumerableToQueryable(enumerableParameterType, queryableParameterType)) - { - var innerArgument = arguments[i]; - var genericType = innerArgument.Type.TryGetSequenceType(); - - // If innerArgument has ToList applied to it then unwrap it. - // Also preserve generic argument of ToList is applied to different type - if (arguments[i].Type.TryGetElementType(typeof(List<>)) != null - && arguments[i] is MethodCallExpression toListMethodCallExpression - && toListMethodCallExpression.Method.IsGenericMethod - && toListMethodCallExpression.Method.GetGenericMethodDefinition() == EnumerableMethods.ToList) - { - genericType = toListMethodCallExpression.Method.GetGenericArguments()[0]; - innerArgument = toListMethodCallExpression.Arguments[0]; - } - - var innerQueryableElementType = innerArgument.Type.TryGetElementType(typeof(IQueryable<>)); - if (innerQueryableElementType == null - || innerQueryableElementType != genericType) - { - arguments[i] = Expression.Call( - QueryableMethods.AsQueryable.MakeGenericMethod(genericType), - innerArgument); - } - - continue; - } - - if (queryableParameterType.IsGenericType - && queryableParameterType.GetGenericTypeDefinition() == typeof(Expression<>) - && queryableParameterType.GetGenericArguments()[0] == enumerableParameterType) - { - continue; - } - - validMapping = false; - break; - } - - if (validMapping) - { - return Expression.Call( - queryableMethod, - arguments.Select( - arg => arg is LambdaExpression lambda ? Expression.Quote(lambda) : arg)); - } - } - - return methodCallExpression.Update(Visit(methodCallExpression.Object), arguments); - } - - if (methodCallExpression.Method.DeclaringType.IsGenericType - && methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(List<>) - && string.Equals(nameof(List.Contains), methodCallExpression.Method.Name)) - { - if (ClientSource(methodCallExpression.Object)) - { - // this is methodCall over closure variable or constant - return base.VisitMethodCall(methodCallExpression); - } - - var sourceType = methodCallExpression.Method.DeclaringType.GetGenericArguments()[0]; - - return Expression.Call( - QueryableMethods.Contains.MakeGenericMethod(sourceType), - Expression.Call( - QueryableMethods.AsQueryable.MakeGenericMethod(sourceType), - methodCallExpression.Object), - methodCallExpression.Arguments[0]); - } - - return base.VisitMethodCall(methodCallExpression); - } - - private static bool ClientSource(Expression expression) - => expression is ConstantExpression - || expression is MemberInitExpression - || expression is NewExpression - || (expression is ParameterExpression parameter - && parameter.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal)); - - private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type queryableType) - { - if (enumerableType == typeof(IEnumerable) - && queryableType == typeof(IQueryable)) - { - return true; - } - - if (!enumerableType.IsGenericType - || !queryableType.IsGenericType - || !enumerableType.GetGenericArguments().SequenceEqual(queryableType.GetGenericArguments())) - { - return false; - } - - enumerableType = enumerableType.GetGenericTypeDefinition(); - queryableType = queryableType.GetGenericTypeDefinition(); - - return enumerableType == typeof(IEnumerable<>) && queryableType == typeof(IQueryable<>) - || enumerableType == typeof(IOrderedEnumerable<>) && queryableType == typeof(IOrderedQueryable<>); - } - } -} diff --git a/src/EFCore/Query/Internal/GroupJoinFlatteningExpressionVisitor.cs b/src/EFCore/Query/Internal/GroupJoinFlatteningExpressionVisitor.cs deleted file mode 100644 index 0f632ab3543..00000000000 --- a/src/EFCore/Query/Internal/GroupJoinFlatteningExpressionVisitor.cs +++ /dev/null @@ -1,367 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Linq.Expressions; -using Microsoft.EntityFrameworkCore.Internal; -using Microsoft.EntityFrameworkCore.Utilities; - -namespace Microsoft.EntityFrameworkCore.Query.Internal -{ - public class GroupJoinFlatteningExpressionVisitor : ExpressionVisitor - { - private static readonly EnumerableToQueryableMethodConvertingExpressionVisitor _enumerableToQueryableReMappingExpressionVisitor - = new EnumerableToQueryableMethodConvertingExpressionVisitor(); - - private readonly SelectManyVerifyingExpressionVisitor _selectManyVerifyingExpressionVisitor - = new SelectManyVerifyingExpressionVisitor(); - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - if (methodCallExpression.Method.DeclaringType == typeof(Queryable) - && methodCallExpression.Method.IsGenericMethod) - { - var genericMethod = methodCallExpression.Method.GetGenericMethodDefinition(); - if (genericMethod == QueryableMethods.SelectManyWithCollectionSelector) - { - // SelectMany - var selectManySource = methodCallExpression.Arguments[0]; - if (selectManySource is MethodCallExpression groupJoinMethod - && groupJoinMethod.Method.IsGenericMethod - && groupJoinMethod.Method.GetGenericMethodDefinition() == QueryableMethods.GroupJoin) - { - // GroupJoin - var outer = Visit(groupJoinMethod.Arguments[0]); - var inner = Visit(groupJoinMethod.Arguments[1]); - var outerKeySelector = groupJoinMethod.Arguments[2].UnwrapLambdaFromQuote(); - var innerKeySelector = groupJoinMethod.Arguments[3].UnwrapLambdaFromQuote(); - var groupJoinResultSelector = groupJoinMethod.Arguments[4].UnwrapLambdaFromQuote(); - - var selectManyCollectionSelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - var selectManyResultSelector = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(); - - var collectionSelectorBody = selectManyCollectionSelector.Body; - var defaultIfEmpty = false; - - if (collectionSelectorBody is MethodCallExpression collectionEndingMethod - && collectionEndingMethod.Method.IsGenericMethod - && collectionEndingMethod.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument) - { - defaultIfEmpty = true; - collectionSelectorBody = collectionEndingMethod.Arguments[0]; - } - - collectionSelectorBody = ReplacingExpressionVisitor.Replace( - selectManyCollectionSelector.Parameters[0], - groupJoinResultSelector.Body, - collectionSelectorBody); - - var correlatedCollectionSelector = _selectManyVerifyingExpressionVisitor - .VerifyCollectionSelector( - collectionSelectorBody, groupJoinResultSelector.Parameters[1]); - - if (correlatedCollectionSelector) - { - var outerParameter = outerKeySelector.Parameters[0]; - var innerParameter = innerKeySelector.Parameters[0]; - var correlationPredicate = Expression.Equal( - outerKeySelector.Body, - innerKeySelector.Body); - - inner = Expression.Call( - QueryableMethods.Where.MakeGenericMethod(inner.Type.TryGetSequenceType()), - inner, - Expression.Quote(Expression.Lambda(correlationPredicate, innerParameter))); - - inner = ReplacingExpressionVisitor.Replace( - groupJoinResultSelector.Parameters[1], - inner, - collectionSelectorBody); - - inner = Expression.Quote(Expression.Lambda(inner, outerParameter)); - } - else - { - inner = _enumerableToQueryableReMappingExpressionVisitor.Visit( - ReplacingExpressionVisitor.Replace( - groupJoinResultSelector.Parameters[1], - inner, - collectionSelectorBody)); - - if (inner is MethodCallExpression innerMethodCall - && innerMethodCall.Method.IsGenericMethod - && innerMethodCall.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable - && innerMethodCall.Type == innerMethodCall.Arguments[0].Type) - { - // Remove redundant AsQueryable. - // It is fine to leave it in the tree since it is no-op - inner = innerMethodCall.Arguments[0]; - } - } - - var resultSelectorBody = ReplacingExpressionVisitor.Replace( - selectManyResultSelector.Parameters[0], - groupJoinResultSelector.Body, - selectManyResultSelector.Body); - - var resultSelector = Expression.Lambda( - resultSelectorBody, - groupJoinResultSelector.Parameters[0], - selectManyResultSelector.Parameters[1]); - - if (correlatedCollectionSelector) - { - // select many case - } - else - { - // join case - if (defaultIfEmpty) - { - // left join - return Expression.Call( - QueryableExtensions.LeftJoinMethodInfo.MakeGenericMethod( - outer.Type.TryGetSequenceType(), - inner.Type.TryGetSequenceType(), - outerKeySelector.ReturnType, - resultSelector.ReturnType), - outer, - inner, - outerKeySelector, - innerKeySelector, - resultSelector); - } - - // inner join - return Expression.Call( - QueryableMethods.Join.MakeGenericMethod( - outer.Type.TryGetSequenceType(), - inner.Type.TryGetSequenceType(), - outerKeySelector.ReturnType, - resultSelector.ReturnType), - outer, - inner, - outerKeySelector, - innerKeySelector, - resultSelector); - } - } - } - else if (genericMethod == QueryableMethods.SelectManyWithoutCollectionSelector) - { - // SelectMany - var selectManySource = methodCallExpression.Arguments[0]; - if (selectManySource is MethodCallExpression groupJoinMethod - && groupJoinMethod.Method.IsGenericMethod - && groupJoinMethod.Method.GetGenericMethodDefinition() == QueryableMethods.GroupJoin) - { - // GroupJoin - var outer = Visit(groupJoinMethod.Arguments[0]); - var inner = Visit(groupJoinMethod.Arguments[1]); - var outerKeySelector = groupJoinMethod.Arguments[2].UnwrapLambdaFromQuote(); - var innerKeySelector = groupJoinMethod.Arguments[3].UnwrapLambdaFromQuote(); - var groupJoinResultSelector = groupJoinMethod.Arguments[4].UnwrapLambdaFromQuote(); - - var selectManyResultSelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - - var groupJoinResultSelectorBody = groupJoinResultSelector.Body; - var defaultIfEmpty = false; - - if (groupJoinResultSelectorBody is MethodCallExpression collectionEndingMethod - && collectionEndingMethod.Method.IsGenericMethod - && collectionEndingMethod.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument) - { - defaultIfEmpty = true; - groupJoinResultSelectorBody = collectionEndingMethod.Arguments[0]; - } - - var correlatedCollectionSelector = _selectManyVerifyingExpressionVisitor - .VerifyCollectionSelector( - groupJoinResultSelectorBody, groupJoinResultSelector.Parameters[1]); - - if (!correlatedCollectionSelector) - { - inner = ReplacingExpressionVisitor.Replace( - groupJoinResultSelector.Parameters[1], - inner, - groupJoinResultSelectorBody); - - inner = ReplacingExpressionVisitor.Replace( - selectManyResultSelector.Parameters[0], - inner, - selectManyResultSelector.Body); - - inner = _enumerableToQueryableReMappingExpressionVisitor.Visit(inner); - - var resultSelector = Expression.Lambda( - innerKeySelector.Parameters[0], - groupJoinResultSelector.Parameters[0], - innerKeySelector.Parameters[0]); - - // join case - if (defaultIfEmpty) - { - // left join - return Expression.Call( - QueryableExtensions.LeftJoinMethodInfo.MakeGenericMethod( - outer.Type.TryGetSequenceType(), - inner.Type.TryGetSequenceType(), - outerKeySelector.ReturnType, - resultSelector.ReturnType), - outer, - inner, - outerKeySelector, - innerKeySelector, - resultSelector); - } - - // inner join - return Expression.Call( - QueryableMethods.Join.MakeGenericMethod( - outer.Type.TryGetSequenceType(), - inner.Type.TryGetSequenceType(), - outerKeySelector.ReturnType, - resultSelector.ReturnType), - outer, - inner, - outerKeySelector, - innerKeySelector, - resultSelector); - } - } - } - } - - return base.VisitMethodCall(methodCallExpression); - } - - private sealed class SelectManyVerifyingExpressionVisitor : ExpressionVisitor - { - private readonly List _allowedParameters = new List(); - private readonly ISet _allowedMethods = new HashSet { nameof(Queryable.Where), nameof(Queryable.AsQueryable) }; - - private ParameterExpression _rootParameter; - private int _rootParameterCount; - private bool _correlated; - - public bool VerifyCollectionSelector(Expression body, ParameterExpression rootParameter) - { - _correlated = false; - _rootParameterCount = 0; - _rootParameter = rootParameter; - - Visit(body); - - if (_rootParameterCount == 1) - { - var expression = body; - while (expression != null) - { - if (expression is MemberExpression memberExpression) - { - expression = memberExpression.Expression; - } - else if (expression is MethodCallExpression methodCallExpression - && methodCallExpression.Method.DeclaringType == typeof(Queryable)) - { - expression = methodCallExpression.Arguments[0]; - } - else if (expression is ParameterExpression) - { - if (expression != _rootParameter) - { - _correlated = true; - } - - break; - } - else - { - _correlated = true; - break; - } - } - } - - _rootParameter = null; - - return _correlated; - } - - protected override Expression VisitLambda(Expression lambdaExpression) - { - Check.NotNull(lambdaExpression, nameof(lambdaExpression)); - - try - { - _allowedParameters.AddRange(lambdaExpression.Parameters); - - return base.VisitLambda(lambdaExpression); - } - finally - { - foreach (var parameter in lambdaExpression.Parameters) - { - _allowedParameters.Remove(parameter); - } - } - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - if (_correlated) - { - return methodCallExpression; - } - - if (methodCallExpression.Method.DeclaringType == typeof(Queryable) - && !_allowedMethods.Contains(methodCallExpression.Method.Name)) - { - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.Select) - { - var selector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - if (selector.Body == selector.Parameters[0]) - { - // identity projection is allowed - return methodCallExpression; - } - } - - _correlated = true; - - return methodCallExpression; - } - - return base.VisitMethodCall(methodCallExpression); - } - - protected override Expression VisitParameter(ParameterExpression parameterExpression) - { - Check.NotNull(parameterExpression, nameof(parameterExpression)); - - if (_allowedParameters.Contains(parameterExpression)) - { - return parameterExpression; - } - - if (parameterExpression == _rootParameter) - { - _rootParameterCount++; - - return parameterExpression; - } - - _correlated = true; - - return base.VisitParameter(parameterExpression); - } - } - } -} diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index 872ad8f35ea..01332e3a0d6 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -32,14 +32,13 @@ private static readonly PropertyInfo _queryContextContextPropertyInfo { QueryableMethods.LastWithPredicate, QueryableMethods.LastWithoutPredicate }, { QueryableMethods.LastOrDefaultWithPredicate, QueryableMethods.LastOrDefaultWithoutPredicate } }; - + private readonly QueryTranslationPreprocessor _queryTranslationPreprocessor; private readonly QueryCompilationContext _queryCompilationContext; private readonly PendingSelectorExpandingExpressionVisitor _pendingSelectorExpandingExpressionVisitor; private readonly SubqueryMemberPushdownExpressionVisitor _subqueryMemberPushdownExpressionVisitor; private readonly ReducingExpressionVisitor _reducingExpressionVisitor; private readonly EntityReferenceOptionalMarkingExpressionVisitor _entityReferenceOptionalMarkingExpressionVisitor; private readonly ISet _parameterNames = new HashSet(); - private readonly EnumerableToQueryableMethodConvertingExpressionVisitor _enumerableToQueryableMethodConvertingExpressionVisitor; private readonly EntityEqualityRewritingExpressionVisitor _entityEqualityRewritingExpressionVisitor; private readonly ParameterExtractingExpressionVisitor _parameterExtractingExpressionVisitor; @@ -49,15 +48,16 @@ private readonly Dictionary _parameterizedQueryFi private readonly Parameters _parameters = new Parameters(); public NavigationExpandingExpressionVisitor( + [NotNull] QueryTranslationPreprocessor queryTranslationPreprocessor, [NotNull] QueryCompilationContext queryCompilationContext, [NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter) { + _queryTranslationPreprocessor = queryTranslationPreprocessor; _queryCompilationContext = queryCompilationContext; _pendingSelectorExpandingExpressionVisitor = new PendingSelectorExpandingExpressionVisitor(this); _subqueryMemberPushdownExpressionVisitor = new SubqueryMemberPushdownExpressionVisitor(queryCompilationContext.Model); _reducingExpressionVisitor = new ReducingExpressionVisitor(); _entityReferenceOptionalMarkingExpressionVisitor = new EntityReferenceOptionalMarkingExpressionVisitor(); - _enumerableToQueryableMethodConvertingExpressionVisitor = new EnumerableToQueryableMethodConvertingExpressionVisitor(); _entityEqualityRewritingExpressionVisitor = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext); _parameterExtractingExpressionVisitor = new ParameterExtractingExpressionVisitor( evaluatableExpressionFilter, @@ -114,7 +114,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio if (definingQuery != null) { var processedDefiningQueryBody = _parameterExtractingExpressionVisitor.ExtractParameters(definingQuery.Body); - processedDefiningQueryBody = _enumerableToQueryableMethodConvertingExpressionVisitor.Visit(processedDefiningQueryBody); + processedDefiningQueryBody = _queryTranslationPreprocessor.NormalizeQueryableMethodCall(processedDefiningQueryBody); processedDefiningQueryBody = new SelfReferenceEntityQueryableRewritingExpressionVisitor(this, entityType).Visit(processedDefiningQueryBody); @@ -1161,7 +1161,7 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa { filterPredicate = queryFilter; filterPredicate = (LambdaExpression)_parameterExtractingExpressionVisitor.ExtractParameters(filterPredicate); - filterPredicate = (LambdaExpression)_enumerableToQueryableMethodConvertingExpressionVisitor.Visit(filterPredicate); + filterPredicate = (LambdaExpression)_queryTranslationPreprocessor.NormalizeQueryableMethodCall(filterPredicate); // We need to do entity equality, but that requires a full method call on a query root to properly flow the // entity information through. Construct a MethodCall wrapper for the predicate with the proper query root. diff --git a/src/EFCore/Query/Internal/QueryMetadataExtractingExpressionVisitor.cs b/src/EFCore/Query/Internal/QueryMetadataExtractingExpressionVisitor.cs deleted file mode 100644 index d9d24ce01b5..00000000000 --- a/src/EFCore/Query/Internal/QueryMetadataExtractingExpressionVisitor.cs +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.Linq.Expressions; -using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.Utilities; - -namespace Microsoft.EntityFrameworkCore.Query.Internal -{ - public class QueryMetadataExtractingExpressionVisitor : ExpressionVisitor - { - private readonly QueryCompilationContext _queryCompilationContext; - - public QueryMetadataExtractingExpressionVisitor([NotNull] QueryCompilationContext queryCompilationContext) - { - _queryCompilationContext = queryCompilationContext; - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - var method = methodCallExpression.Method; - if (method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) - && method.IsGenericMethod) - { - // We visit innerQueryable first so that we can get information in the same order operators are applied. - var genericMethodDefinition = method.GetGenericMethodDefinition(); - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.AsTrackingMethodInfo - || genericMethodDefinition == EntityFrameworkQueryableExtensions.AsNoTrackingMethodInfo) - { - var innerQueryable = Visit(methodCallExpression.Arguments[0]); - _queryCompilationContext.IsTracking - = genericMethodDefinition == EntityFrameworkQueryableExtensions.AsTrackingMethodInfo; - - return innerQueryable; - } - - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.TagWithMethodInfo) - { - var innerQueryable = Visit(methodCallExpression.Arguments[0]); - _queryCompilationContext.AddTag((string)((ConstantExpression)methodCallExpression.Arguments[1]).Value); - - return innerQueryable; - } - - if (genericMethodDefinition == EntityFrameworkQueryableExtensions.IgnoreQueryFiltersMethodInfo) - { - var innerQueryable = Visit(methodCallExpression.Arguments[0]); - - _queryCompilationContext.IgnoreQueryFilters = true; - - return innerQueryable; - } - } - - return base.VisitMethodCall(methodCallExpression); - } - } -} diff --git a/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs b/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs new file mode 100644 index 00000000000..7eb57326d35 --- /dev/null +++ b/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs @@ -0,0 +1,616 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + public class QueryableMethodNormalizingExpressionVisitor : ExpressionVisitor + { + private readonly QueryCompilationContext _queryCompilationContext; + private readonly SelectManyVerifyingExpressionVisitor _selectManyVerifyingExpressionVisitor + = new SelectManyVerifyingExpressionVisitor(); + + public QueryableMethodNormalizingExpressionVisitor([NotNull] QueryCompilationContext queryCompilationContext) + { + Check.NotNull(queryCompilationContext, nameof(Query)); + + _queryCompilationContext = queryCompilationContext; + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + Check.NotNull(methodCallExpression, nameof(methodCallExpression)); + + var method = methodCallExpression.Method; + + // Extract information from query metadata method and prune them + if (method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + && method.IsGenericMethod + && ExtractQueryMetadata(methodCallExpression) is Expression expression) + { + return expression; + } + + Expression visitedExpression = null; + if (method.DeclaringType == typeof(Enumerable)) + { + visitedExpression = TryConvertEnumerableToQueryable(methodCallExpression); + } + + if (methodCallExpression.Method.DeclaringType.IsGenericType + && methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(List<>) + && string.Equals(nameof(List.Contains), methodCallExpression.Method.Name)) + { + visitedExpression = TryConvertListContainsToQueryableContains(methodCallExpression); + } + + visitedExpression ??= base.VisitMethodCall(methodCallExpression); + + if (visitedExpression is MethodCallExpression visitedMethodcall + && visitedMethodcall.Method.DeclaringType == typeof(Queryable) + && visitedMethodcall.Method.IsGenericMethod) + { + return TryFlattenGroupJoinSelectMany(visitedMethodcall); + } + + return visitedExpression; + } + + private Expression ExtractQueryMetadata(MethodCallExpression methodCallExpression) + { + // We visit innerQueryable first so that we can get information in the same order operators are applied. + var genericMethodDefinition = methodCallExpression.Method.GetGenericMethodDefinition(); + if (genericMethodDefinition == EntityFrameworkQueryableExtensions.AsTrackingMethodInfo + || genericMethodDefinition == EntityFrameworkQueryableExtensions.AsNoTrackingMethodInfo) + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.IsTracking + = genericMethodDefinition == EntityFrameworkQueryableExtensions.AsTrackingMethodInfo; + + return visitedExpression; + } + + if (genericMethodDefinition == EntityFrameworkQueryableExtensions.TagWithMethodInfo) + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.AddTag((string)((ConstantExpression)methodCallExpression.Arguments[1]).Value); + + return visitedExpression; + } + + if (genericMethodDefinition == EntityFrameworkQueryableExtensions.IgnoreQueryFiltersMethodInfo) + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.IgnoreQueryFilters = true; + + return visitedExpression; + } + + return null; + } + + private Expression TryConvertEnumerableToQueryable(MethodCallExpression methodCallExpression) + { + // TODO : CHECK if this is still needed + if (methodCallExpression.Method.Name == nameof(Enumerable.SequenceEqual)) + { + // Skip SequenceEqual over enumerable since it could be over byte[] or other array properties + // Ideally we could make check in nav expansion about it (since it can bind to property) + // But since we don't translate SequenceEqual anyway, this is fine for now. + return base.VisitMethodCall(methodCallExpression); + } + + if (methodCallExpression.Arguments.Count > 0 + && ClientSource(methodCallExpression.Arguments[0])) + { + // this is methodCall over closure variable or constant + return base.VisitMethodCall(methodCallExpression); + } + + var arguments = VisitAndConvert(methodCallExpression.Arguments, nameof(VisitMethodCall)).ToArray(); + + var enumerableMethod = methodCallExpression.Method; + var enumerableParameters = enumerableMethod.GetParameters(); + Type[] genericTypeArguments = null; + if (enumerableMethod.Name == nameof(Enumerable.Min) + || enumerableMethod.Name == nameof(Enumerable.Max)) + { + genericTypeArguments = new Type[methodCallExpression.Arguments.Count]; + + if (!enumerableMethod.IsGenericMethod) + { + genericTypeArguments[0] = enumerableMethod.ReturnType; + } + else + { + var argumentTypes = enumerableMethod.GetGenericArguments(); + if (argumentTypes.Length == genericTypeArguments.Length) + { + genericTypeArguments = argumentTypes; + } + else + { + genericTypeArguments[0] = argumentTypes[0]; + genericTypeArguments[1] = enumerableMethod.ReturnType; + } + } + } + else if (enumerableMethod.IsGenericMethod) + { + genericTypeArguments = enumerableMethod.GetGenericArguments(); + } + + foreach (var method in typeof(Queryable).GetTypeInfo().GetDeclaredMethods(methodCallExpression.Method.Name)) + { + var queryableMethod = method; + if (queryableMethod.IsGenericMethod) + { + if (genericTypeArguments != null + && queryableMethod.GetGenericArguments().Length == genericTypeArguments.Length) + { + queryableMethod = queryableMethod.MakeGenericMethod(genericTypeArguments); + } + else + { + continue; + } + } + + var queryableParameters = queryableMethod.GetParameters(); + if (enumerableParameters.Length != queryableParameters.Length) + { + continue; + } + + var validMapping = true; + for (var i = 0; i < enumerableParameters.Length; i++) + { + var enumerableParameterType = enumerableParameters[i].ParameterType; + var queryableParameterType = queryableParameters[i].ParameterType; + + if (enumerableParameterType == queryableParameterType) + { + continue; + } + + if (CanConvertEnumerableToQueryable(enumerableParameterType, queryableParameterType)) + { + var innerArgument = arguments[i]; + var genericType = innerArgument.Type.TryGetSequenceType(); + + // If innerArgument has ToList applied to it then unwrap it. + // Also preserve generic argument of ToList is applied to different type + if (arguments[i].Type.TryGetElementType(typeof(List<>)) != null + && arguments[i] is MethodCallExpression toListMethodCallExpression + && toListMethodCallExpression.Method.IsGenericMethod + && toListMethodCallExpression.Method.GetGenericMethodDefinition() == EnumerableMethods.ToList) + { + genericType = toListMethodCallExpression.Method.GetGenericArguments()[0]; + innerArgument = toListMethodCallExpression.Arguments[0]; + } + + var innerQueryableElementType = innerArgument.Type.TryGetElementType(typeof(IQueryable<>)); + if (innerQueryableElementType == null + || innerQueryableElementType != genericType) + { + arguments[i] = Expression.Call( + QueryableMethods.AsQueryable.MakeGenericMethod(genericType), + innerArgument); + } + + continue; + } + + if (queryableParameterType.IsGenericType + && queryableParameterType.GetGenericTypeDefinition() == typeof(Expression<>) + && queryableParameterType.GetGenericArguments()[0] == enumerableParameterType) + { + continue; + } + + validMapping = false; + break; + } + + if (validMapping) + { + return Expression.Call( + queryableMethod, + arguments.Select( + arg => arg is LambdaExpression lambda ? Expression.Quote(lambda) : arg)); + } + } + + return methodCallExpression.Update(Visit(methodCallExpression.Object), arguments); + } + + private Expression TryConvertListContainsToQueryableContains(MethodCallExpression methodCallExpression) + { + if (ClientSource(methodCallExpression.Object)) + { + // this is methodCall over closure variable or constant + return base.VisitMethodCall(methodCallExpression); + } + + var sourceType = methodCallExpression.Method.DeclaringType.GetGenericArguments()[0]; + + return Expression.Call( + QueryableMethods.Contains.MakeGenericMethod(sourceType), + Expression.Call( + QueryableMethods.AsQueryable.MakeGenericMethod(sourceType), + methodCallExpression.Object), + methodCallExpression.Arguments[0]); + } + + private static bool ClientSource(Expression expression) + => expression is ConstantExpression + || expression is MemberInitExpression + || expression is NewExpression + || expression is ParameterExpression parameter + && parameter.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal); + + private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type queryableType) + { + if (enumerableType == typeof(IEnumerable) + && queryableType == typeof(IQueryable)) + { + return true; + } + + if (!enumerableType.IsGenericType + || !queryableType.IsGenericType + || !enumerableType.GetGenericArguments().SequenceEqual(queryableType.GetGenericArguments())) + { + return false; + } + + enumerableType = enumerableType.GetGenericTypeDefinition(); + queryableType = queryableType.GetGenericTypeDefinition(); + + return enumerableType == typeof(IEnumerable<>) && queryableType == typeof(IQueryable<>) + || enumerableType == typeof(IOrderedEnumerable<>) && queryableType == typeof(IOrderedQueryable<>); + } + + private Expression TryFlattenGroupJoinSelectMany(MethodCallExpression methodCallExpression) + { + var genericMethod = methodCallExpression.Method.GetGenericMethodDefinition(); + if (genericMethod == QueryableMethods.SelectManyWithCollectionSelector) + { + // SelectMany + var selectManySource = methodCallExpression.Arguments[0]; + if (selectManySource is MethodCallExpression groupJoinMethod + && groupJoinMethod.Method.IsGenericMethod + && groupJoinMethod.Method.GetGenericMethodDefinition() == QueryableMethods.GroupJoin) + { + // GroupJoin + var outer = groupJoinMethod.Arguments[0]; + var inner = groupJoinMethod.Arguments[1]; + var outerKeySelector = groupJoinMethod.Arguments[2].UnwrapLambdaFromQuote(); + var innerKeySelector = groupJoinMethod.Arguments[3].UnwrapLambdaFromQuote(); + var groupJoinResultSelector = groupJoinMethod.Arguments[4].UnwrapLambdaFromQuote(); + + var selectManyCollectionSelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + var selectManyResultSelector = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(); + + var collectionSelectorBody = selectManyCollectionSelector.Body; + var defaultIfEmpty = false; + + if (collectionSelectorBody is MethodCallExpression collectionEndingMethod + && collectionEndingMethod.Method.IsGenericMethod + && collectionEndingMethod.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument) + { + defaultIfEmpty = true; + collectionSelectorBody = collectionEndingMethod.Arguments[0]; + } + + collectionSelectorBody = ReplacingExpressionVisitor.Replace( + selectManyCollectionSelector.Parameters[0], + groupJoinResultSelector.Body, + collectionSelectorBody); + + var correlatedCollectionSelector = _selectManyVerifyingExpressionVisitor + .VerifyCollectionSelector( + collectionSelectorBody, groupJoinResultSelector.Parameters[1]); + + if (correlatedCollectionSelector) + { + var outerParameter = outerKeySelector.Parameters[0]; + var innerParameter = innerKeySelector.Parameters[0]; + var correlationPredicate = Expression.Equal( + outerKeySelector.Body, + innerKeySelector.Body); + + inner = Expression.Call( + QueryableMethods.Where.MakeGenericMethod(inner.Type.TryGetSequenceType()), + inner, + Expression.Quote(Expression.Lambda(correlationPredicate, innerParameter))); + + inner = ReplacingExpressionVisitor.Replace( + groupJoinResultSelector.Parameters[1], + inner, + collectionSelectorBody); + + inner = Expression.Quote(Expression.Lambda(inner, outerParameter)); + } + else + { + inner = Visit(ReplacingExpressionVisitor.Replace( + groupJoinResultSelector.Parameters[1], inner, collectionSelectorBody)); + + if (inner is MethodCallExpression innerMethodCall + && innerMethodCall.Method.IsGenericMethod + && innerMethodCall.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable + && innerMethodCall.Type == innerMethodCall.Arguments[0].Type) + { + // Remove redundant AsQueryable. + // It is fine to leave it in the tree since it is no-op + inner = innerMethodCall.Arguments[0]; + } + } + + var resultSelectorBody = ReplacingExpressionVisitor.Replace( + selectManyResultSelector.Parameters[0], + groupJoinResultSelector.Body, + selectManyResultSelector.Body); + + var resultSelector = Expression.Lambda( + resultSelectorBody, + groupJoinResultSelector.Parameters[0], + selectManyResultSelector.Parameters[1]); + + if (!correlatedCollectionSelector) + { + // join case + if (defaultIfEmpty) + { + // left join + return Expression.Call( + QueryableExtensions.LeftJoinMethodInfo.MakeGenericMethod( + outer.Type.TryGetSequenceType(), + inner.Type.TryGetSequenceType(), + outerKeySelector.ReturnType, + resultSelector.ReturnType), + outer, + inner, + outerKeySelector, + innerKeySelector, + resultSelector); + } + + // inner join + return Expression.Call( + QueryableMethods.Join.MakeGenericMethod( + outer.Type.TryGetSequenceType(), + inner.Type.TryGetSequenceType(), + outerKeySelector.ReturnType, + resultSelector.ReturnType), + outer, + inner, + outerKeySelector, + innerKeySelector, + resultSelector); + } + } + } + else if (genericMethod == QueryableMethods.SelectManyWithoutCollectionSelector) + { + // SelectMany + var selectManySource = methodCallExpression.Arguments[0]; + if (selectManySource is MethodCallExpression groupJoinMethod + && groupJoinMethod.Method.IsGenericMethod + && groupJoinMethod.Method.GetGenericMethodDefinition() == QueryableMethods.GroupJoin) + { + // GroupJoin + var outer = groupJoinMethod.Arguments[0]; + var inner = groupJoinMethod.Arguments[1]; + var outerKeySelector = groupJoinMethod.Arguments[2].UnwrapLambdaFromQuote(); + var innerKeySelector = groupJoinMethod.Arguments[3].UnwrapLambdaFromQuote(); + var groupJoinResultSelector = groupJoinMethod.Arguments[4].UnwrapLambdaFromQuote(); + + var selectManyResultSelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + + var groupJoinResultSelectorBody = groupJoinResultSelector.Body; + var defaultIfEmpty = false; + + if (groupJoinResultSelectorBody is MethodCallExpression collectionEndingMethod + && collectionEndingMethod.Method.IsGenericMethod + && collectionEndingMethod.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument) + { + defaultIfEmpty = true; + groupJoinResultSelectorBody = collectionEndingMethod.Arguments[0]; + } + + var correlatedCollectionSelector = _selectManyVerifyingExpressionVisitor + .VerifyCollectionSelector( + groupJoinResultSelectorBody, groupJoinResultSelector.Parameters[1]); + + if (!correlatedCollectionSelector) + { + inner = ReplacingExpressionVisitor.Replace( + groupJoinResultSelector.Parameters[1], + inner, + groupJoinResultSelectorBody); + + inner = ReplacingExpressionVisitor.Replace( + selectManyResultSelector.Parameters[0], + inner, + selectManyResultSelector.Body); + + inner = Visit(inner); + + var resultSelector = Expression.Lambda( + innerKeySelector.Parameters[0], + groupJoinResultSelector.Parameters[0], + innerKeySelector.Parameters[0]); + + // join case + if (defaultIfEmpty) + { + // left join + return Expression.Call( + QueryableExtensions.LeftJoinMethodInfo.MakeGenericMethod( + outer.Type.TryGetSequenceType(), + inner.Type.TryGetSequenceType(), + outerKeySelector.ReturnType, + resultSelector.ReturnType), + outer, + inner, + outerKeySelector, + innerKeySelector, + resultSelector); + } + + // inner join + return Expression.Call( + QueryableMethods.Join.MakeGenericMethod( + outer.Type.TryGetSequenceType(), + inner.Type.TryGetSequenceType(), + outerKeySelector.ReturnType, + resultSelector.ReturnType), + outer, + inner, + outerKeySelector, + innerKeySelector, + resultSelector); + } + } + } + + return methodCallExpression; + } + + private sealed class SelectManyVerifyingExpressionVisitor : ExpressionVisitor + { + private readonly List _allowedParameters = new List(); + private readonly ISet _allowedMethods = new HashSet { nameof(Queryable.Where), nameof(Queryable.AsQueryable) }; + + private ParameterExpression _rootParameter; + private int _rootParameterCount; + private bool _correlated; + + public bool VerifyCollectionSelector(Expression body, ParameterExpression rootParameter) + { + _correlated = false; + _rootParameterCount = 0; + _rootParameter = rootParameter; + + Visit(body); + + if (_rootParameterCount == 1) + { + var expression = body; + while (expression != null) + { + if (expression is MemberExpression memberExpression) + { + expression = memberExpression.Expression; + } + else if (expression is MethodCallExpression methodCallExpression + && methodCallExpression.Method.DeclaringType == typeof(Queryable)) + { + expression = methodCallExpression.Arguments[0]; + } + else if (expression is ParameterExpression) + { + if (expression != _rootParameter) + { + _correlated = true; + } + + break; + } + else + { + _correlated = true; + break; + } + } + } + + _rootParameter = null; + + return _correlated; + } + + protected override Expression VisitLambda(Expression lambdaExpression) + { + Check.NotNull(lambdaExpression, nameof(lambdaExpression)); + + try + { + _allowedParameters.AddRange(lambdaExpression.Parameters); + + return base.VisitLambda(lambdaExpression); + } + finally + { + foreach (var parameter in lambdaExpression.Parameters) + { + _allowedParameters.Remove(parameter); + } + } + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + Check.NotNull(methodCallExpression, nameof(methodCallExpression)); + + if (_correlated) + { + return methodCallExpression; + } + + if (methodCallExpression.Method.DeclaringType == typeof(Queryable) + && !_allowedMethods.Contains(methodCallExpression.Method.Name)) + { + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.Select) + { + var selector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + if (selector.Body == selector.Parameters[0]) + { + // identity projection is allowed + return methodCallExpression; + } + } + + _correlated = true; + + return methodCallExpression; + } + + return base.VisitMethodCall(methodCallExpression); + } + + protected override Expression VisitParameter(ParameterExpression parameterExpression) + { + Check.NotNull(parameterExpression, nameof(parameterExpression)); + + if (_allowedParameters.Contains(parameterExpression)) + { + return parameterExpression; + } + + if (parameterExpression == _rootParameter) + { + _rootParameterCount++; + + return parameterExpression; + } + + _correlated = true; + + return base.VisitParameter(parameterExpression); + } + } + } +} diff --git a/src/EFCore/Query/QueryTranslationPreprocessor.cs b/src/EFCore/Query/QueryTranslationPreprocessor.cs index abd5960983f..ec4d6d230a5 100644 --- a/src/EFCore/Query/QueryTranslationPreprocessor.cs +++ b/src/EFCore/Query/QueryTranslationPreprocessor.cs @@ -29,20 +29,24 @@ public virtual Expression Process([NotNull] Expression query) { Check.NotNull(query, nameof(query)); - query = new EnumerableToQueryableMethodConvertingExpressionVisitor().Visit(query); - query = new QueryMetadataExtractingExpressionVisitor(_queryCompilationContext).Visit(query); query = new InvocationExpressionRemovingExpressionVisitor().Visit(query); + + query = NormalizeQueryableMethodCall(query); + query = new VBToCSharpConvertingExpressionVisitor().Visit(query); query = new AllAnyToContainsRewritingExpressionVisitor().Visit(query); - query = new GroupJoinFlatteningExpressionVisitor().Visit(query); query = new NullCheckRemovingExpressionVisitor().Visit(query); query = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext).Rewrite(query); query = new SubqueryMemberPushdownExpressionVisitor(_queryCompilationContext.Model).Visit(query); - query = new NavigationExpandingExpressionVisitor(_queryCompilationContext, Dependencies.EvaluatableExpressionFilter).Expand( + query = new NavigationExpandingExpressionVisitor(this, _queryCompilationContext, Dependencies.EvaluatableExpressionFilter).Expand( query); query = new FunctionPreprocessingExpressionVisitor().Visit(query); return query; } + + public virtual Expression NormalizeQueryableMethodCall([NotNull] Expression expression) + => new QueryableMethodNormalizingExpressionVisitor(_queryCompilationContext) + .Visit(Check.NotNull(expression, nameof(expression))); } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs index e15f19960d6..7faeb15e4e1 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq; +using System.Linq.Expressions; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.TestUtilities; using Xunit; @@ -666,6 +667,109 @@ public EntityDto8282(Entity8282 entity) #endregion + #region Bug19708 + + [ConditionalFact] + public virtual void GroupJoin_SelectMany_in_defining_query_is_flattened() + { + using (CreateScratch(Seed19708, "19708")) + { + using var context = new MyContext19708(); + + var query = context.Set().ToList(); + + Assert.Collection(query, + t => AssertCustomerView(t, 1, "First", 1, "FirstChild"), + t => AssertCustomerView(t, 2, "Second", 2, "SecondChild1"), + t => AssertCustomerView(t, 2, "Second", 3, "SecondChild2"), + t => AssertCustomerView(t, 3, "Third", null, "")); + } + + static void AssertCustomerView( + CustomerView19708 actual, int id, string name, int? customerMembershipId, string customerMembershipName) + { + Assert.Equal(id, actual.Id); + Assert.Equal(name, actual.Name); + Assert.Equal(customerMembershipId, actual.CustomerMembershipId); + Assert.Equal(customerMembershipName, actual.CustomerMembershipName); + } + } + + private class MyContext19708 : DbContext + { + public DbSet Customers { get; set; } + public DbSet CustomerMemberships { get; set; } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + optionsBuilder + .UseInternalServiceProvider(InMemoryFixture.DefaultServiceProvider) + .UseInMemoryDatabase("19708"); + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity().HasNoKey().ToQuery(Build_Customers_Sql_View_InMemory()); + } + + private Expression>> Build_Customers_Sql_View_InMemory() + { + Expression>> query = () => + from customer in Customers + join customerMembership in CustomerMemberships on customer.Id equals customerMembership.CustomerId into + nullableCustomerMemberships + from customerMembership in nullableCustomerMemberships.DefaultIfEmpty() + select new CustomerView19708 + { + Id = customer.Id, + Name = customer.Name, + CustomerMembershipId = customerMembership != null ? customerMembership.Id : default(int?), + CustomerMembershipName = customerMembership != null ? customerMembership.Name : "" + }; + return query; + } + } + + private static void Seed19708(MyContext19708 context) + { + var customer1 = new Customer19708 { Name = "First" }; + var customer2 = new Customer19708 { Name = "Second" }; + var customer3 = new Customer19708 { Name = "Third" }; + + var customerMembership1 = new CustomerMembership19708 { Name = "FirstChild", Customer = customer1 }; + var customerMembership2 = new CustomerMembership19708 { Name = "SecondChild1", Customer = customer2 }; + var customerMembership3 = new CustomerMembership19708 { Name = "SecondChild2", Customer = customer2 }; + + context.AddRange(customer1, customer2, customer3); + context.AddRange(customerMembership1, customerMembership2, customerMembership3); + + context.SaveChanges(); + } + + private class Customer19708 + { + public int Id { get; set; } + public string Name { get; set; } + } + + private class CustomerMembership19708 + { + public int Id { get; set; } + public string Name { get; set; } + public int CustomerId { get; set; } + public Customer19708 Customer { get; set; } + } + + private class CustomerView19708 + { + public int Id { get; set; } + public string Name { get; set; } + public int? CustomerMembershipId { get; set; } + public string CustomerMembershipName { get; set; } + } + + #endregion + #region SharedHelper private static InMemoryTestStore CreateScratch(Action seed, string databaseName) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index 163cc198cdf..704a38d93c2 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -6968,6 +6968,154 @@ public BugContext19138(DbContextOptions options) #endregion + #region Issue19708 + + [ConditionalFact] + public void GroupJoin_SelectMany_in_query_filter_gets_flattened() + { + using var _ = CreateDatabase19708(); + using var context = new BugContext19708(_options); + + var query = context.CustomerFilters.ToList(); + + AssertSql( + @"SELECT [c].[CustomerId], [c].[CustomerMembershipId] +FROM [CustomerFilters] AS [c] +WHERE ( + SELECT COUNT(*) + FROM [Customers] AS [c0] + LEFT JOIN [CustomerMemberships] AS [c1] ON [c0].[Id] = [c1].[CustomerId] + WHERE [c1].[Id] IS NOT NULL AND ([c0].[Id] = [c].[CustomerId])) > 0"); + } + + [ConditionalFact] + public void GroupJoin_SelectMany_in_defining_query_gets_flattened() + { + using var _ = CreateDatabase19708(); + using var context = new BugContext19708(_options); + + var query = context.Set().ToList(); + + Assert.Collection(query, + t => AssertCustomerView(t, 1, "First", 1, "FirstChild"), + t => AssertCustomerView(t, 2, "Second", 2, "SecondChild1"), + t => AssertCustomerView(t, 2, "Second", 3, "SecondChild2"), + t => AssertCustomerView(t, 3, "Third", null, "")); + + static void AssertCustomerView( + CustomerView19708 actual, int id, string name, int? customerMembershipId, string customerMembershipName) + { + Assert.Equal(id, actual.Id); + Assert.Equal(name, actual.Name); + Assert.Equal(customerMembershipId, actual.CustomerMembershipId); + Assert.Equal(customerMembershipName, actual.CustomerMembershipName); + } + + AssertSql( + @"SELECT [c].[Id], [c].[Name], [c0].[Id], [c0].[CustomerId], [c0].[Name] +FROM [Customers] AS [c] +LEFT JOIN [CustomerMemberships] AS [c0] ON [c].[Id] = [c0].[CustomerId]"); + } + + private SqlServerTestStore CreateDatabase19708() + => CreateTestStore( + () => new BugContext19708(_options), + context => + { + var customer1 = new Customer19708 { Name = "First" }; + var customer2 = new Customer19708 { Name = "Second" }; + var customer3 = new Customer19708 { Name = "Third" }; + + var customerMembership1 = new CustomerMembership19708 { Name = "FirstChild", Customer = customer1 }; + var customerMembership2 = new CustomerMembership19708 { Name = "SecondChild1", Customer = customer2 }; + var customerMembership3 = new CustomerMembership19708 { Name = "SecondChild2", Customer = customer2 }; + + context.AddRange(customer1, customer2, customer3); + context.AddRange(customerMembership1, customerMembership2, customerMembership3); + + context.SaveChanges(); + + ClearLog(); + }); + + private class BugContext19708 : DbContext + { + public BugContext19708(DbContextOptions options) + : base(options) + { + } + + public DbSet Customers { get; set; } + public DbSet CustomerMemberships { get; set; } + public DbSet CustomerFilters { get; set; } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity() + .HasQueryFilter(e => (from a in (from c in Customers + join cm in CustomerMemberships on c.Id equals cm.CustomerId into g + from cm in g.DefaultIfEmpty() + select new + { + c.Id, + CustomerMembershipId = (int?)cm.Id + }) + where a.CustomerMembershipId != null && a.Id == e.CustomerId + select a).Count() > 0) + .HasKey(e => e.CustomerId); + + modelBuilder.Entity().HasNoKey().ToQuery(Build_Customers_Sql_View_InMemory()); + } + + private Expression>> Build_Customers_Sql_View_InMemory() + { + Expression>> query = () => + from customer in Customers + join customerMembership in CustomerMemberships on customer.Id equals customerMembership.CustomerId into + nullableCustomerMemberships + from customerMembership in nullableCustomerMemberships.DefaultIfEmpty() + select new CustomerView19708 + { + Id = customer.Id, + Name = customer.Name, + CustomerMembershipId = customerMembership != null ? customerMembership.Id : default(int?), + CustomerMembershipName = customerMembership != null ? customerMembership.Name : "" + }; + return query; + } + } + + private class Customer19708 + { + public int Id { get; set; } + public string Name { get; set; } + } + + private class CustomerMembership19708 + { + public int Id { get; set; } + public string Name { get; set; } + + public int CustomerId { get; set; } + public Customer19708 Customer { get; set; } + } + + private class CustomerFilter19708 + { + public int CustomerId { get; set; } + public int CustomerMembershipId { get; set; } + } + + private class CustomerView19708 + { + public int Id { get; set; } + public string Name { get; set; } + public int? CustomerMembershipId { get; set; } + public string CustomerMembershipName { get; set; } + } + + #endregion + private DbContextOptions _options; private SqlServerTestStore CreateTestStore(