diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index eb21110fa42..8d5c3e31279 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -45,7 +45,7 @@ public CosmosQueryableMethodTranslatingExpressionVisitor( _model = queryCompilationContext.Model; _sqlExpressionFactory = sqlExpressionFactory; _sqlTranslator = new CosmosSqlTranslatingExpressionVisitor( - _model, + queryCompilationContext, sqlExpressionFactory, memberTranslatorProvider, methodCallTranslatorProvider); diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index 6f5b0cc7615..e759edb57cd 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -2,9 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; +using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; @@ -23,6 +26,14 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal /// public class CosmosSqlTranslatingExpressionVisitor : ExpressionVisitor { + private const string _runtimeParameterPrefix = QueryCompilationContext.QueryParameterPrefix + "entity_equality_"; + + private static readonly MethodInfo _parameterValueExtractor = + typeof(CosmosSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterValueExtractor)); + private static readonly MethodInfo _parameterListValueExtractor = + typeof(CosmosSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterListValueExtractor)); + + private readonly QueryCompilationContext _queryCompilationContext; private readonly IModel _model; private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly IMemberTranslatorProvider _memberTranslatorProvider; @@ -36,12 +47,13 @@ public class CosmosSqlTranslatingExpressionVisitor : ExpressionVisitor /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public CosmosSqlTranslatingExpressionVisitor( - [NotNull] IModel model, + [NotNull] QueryCompilationContext queryCompilationContext, [NotNull] ISqlExpressionFactory sqlExpressionFactory, [NotNull] IMemberTranslatorProvider memberTranslatorProvider, [NotNull] IMethodCallTranslatorProvider methodCallTranslatorProvider) { - _model = model; + _queryCompilationContext = queryCompilationContext; + _model = queryCompilationContext.Model; _sqlExpressionFactory = sqlExpressionFactory; _memberTranslatorProvider = memberTranslatorProvider; _methodCallTranslatorProvider = methodCallTranslatorProvider; @@ -102,6 +114,20 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) ifFalse)); } + var left = TryRemoveImplicitConvert(binaryExpression.Left); + var right = TryRemoveImplicitConvert(binaryExpression.Right); + + var visitedLeft = Visit(left); + var visitedRight = Visit(right); + + if ((binaryExpression.NodeType == ExpressionType.Equal + || binaryExpression.NodeType == ExpressionType.NotEqual) + // Visited expression could be null, We need to pass MemberInitExpression + && TryRewriteEntityEquality(binaryExpression.NodeType, visitedLeft ?? left, visitedRight ?? right, out var result)) + { + return result; + } + var uncheckedNodeTypeVariant = binaryExpression.NodeType switch { ExpressionType.AddChecked => ExpressionType.Add, @@ -110,14 +136,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) _ => binaryExpression.NodeType }; - var left = TryRemoveImplicitConvert(binaryExpression.Left); - var right = TryRemoveImplicitConvert(binaryExpression.Right); - - left = Visit(left); - right = Visit(right); - - return TranslationFailed(binaryExpression.Left, left, out var sqlLeft) - || TranslationFailed(binaryExpression.Right, right, out var sqlRight) + return TranslationFailed(binaryExpression.Left, visitedLeft, out var sqlLeft) + || TranslationFailed(binaryExpression.Right, visitedRight, out var sqlRight) ? null : _sqlExpressionFactory.MakeBinary( uncheckedNodeTypeVariant, @@ -169,6 +189,7 @@ protected override Expression VisitExtension(Expression extensionExpression) switch (extensionExpression) { case EntityProjectionExpression _: + case EntityReferenceExpression _: case SqlExpression _: return extensionExpression; @@ -189,7 +210,7 @@ protected override Expression VisitExtension(Expression extensionExpression) return new EntityReferenceExpression(entityProjectionExpression); } - throw new InvalidOperationException("Randomization"); + return null; case ProjectionBindingExpression projectionBindingExpression: return projectionBindingExpression.ProjectionMember != null @@ -269,21 +290,126 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return TryBindMember(Visit(source), MemberIdentity.Create(propertyName)); } - if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject)) + SqlExpression sqlObject = null; + SqlExpression[] arguments; + var method = methodCallExpression.Method; + + if (method.Name == nameof(object.Equals) + && methodCallExpression.Object != null + && methodCallExpression.Arguments.Count == 1) { - return null; + var left = Visit(methodCallExpression.Object); + var right = Visit(methodCallExpression.Arguments[0]); + + if (TryRewriteEntityEquality(ExpressionType.Equal, + left ?? methodCallExpression.Object, + right ?? methodCallExpression.Arguments[0], + out var result)) + { + return result; + } + + if (left is SqlExpression leftSql + && right is SqlExpression rightSql) + { + sqlObject = leftSql; + arguments = new SqlExpression[1] { rightSql }; + } + else + { + return null; + } + } + else if (method.Name == nameof(object.Equals) + && methodCallExpression.Object == null + && methodCallExpression.Arguments.Count == 2) + { + var left = Visit(methodCallExpression.Arguments[0]); + var right = Visit(methodCallExpression.Arguments[1]); + + if (TryRewriteEntityEquality(ExpressionType.Equal, + left ?? methodCallExpression.Arguments[0], + right ?? methodCallExpression.Arguments[1], + out var result)) + { + return result; + } + + if (left is SqlExpression leftSql + && right is SqlExpression rightSql) + { + arguments = new SqlExpression[2] { leftSql, rightSql }; + } + else + { + return null; + } } + else if (method.IsGenericMethod + && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)) + { + var enumerable = Visit(methodCallExpression.Arguments[0]); + var item = Visit(methodCallExpression.Arguments[1]); - var arguments = new SqlExpression[methodCallExpression.Arguments.Count]; - for (var i = 0; i < arguments.Length; i++) + if (TryRewriteContainsEntity(enumerable, item ?? methodCallExpression.Arguments[1], out var result)) + { + return result; + } + + if (enumerable is SqlExpression sqlEnumerable + && item is SqlExpression sqlItem) + { + arguments = new SqlExpression[2] { sqlEnumerable, sqlItem }; + } + else + { + return null; + } + } + else if (method.Name == nameof(IList.Contains) + && methodCallExpression.Arguments.Count == 1 + && method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any( + t => t == typeof(IList) + || (t.IsGenericType + && t.GetGenericTypeDefinition() == typeof(ICollection<>)))) { - var argument = methodCallExpression.Arguments[i]; - if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) + var enumerable = Visit(methodCallExpression.Object); + var item = Visit(methodCallExpression.Arguments[0]); + + if (TryRewriteContainsEntity(enumerable, item ?? methodCallExpression.Arguments[0], out var result)) + { + return result; + } + + if (enumerable is SqlExpression sqlEnumerable + && item is SqlExpression sqlItem) + { + sqlObject = sqlEnumerable; + arguments = new SqlExpression[1] { sqlItem }; + } + else + { + return null; + } + } + else + { + if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out sqlObject)) { return null; } - arguments[i] = sqlArgument; + arguments = new SqlExpression[methodCallExpression.Arguments.Count]; + for (var i = 0; i < arguments.Length; i++) + { + var argument = methodCallExpression.Arguments[i]; + if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) + { + return null; + } + + arguments[i] = sqlArgument; + } } return _methodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); @@ -414,17 +540,199 @@ private static Expression TryRemoveImplicitConvert(Expression expression) return expression; } - private SqlConstantExpression GetConstantOrNull(Expression expression) + private bool TryRewriteContainsEntity(Expression source, Expression item, out Expression result) { - if (CanEvaluate(expression)) + result = null; + + if (!(item is EntityReferenceExpression itemEntityReference)) { - var value = Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(); - return new SqlConstantExpression(Expression.Constant(value, expression.Type), null); + return false; } - return null; + var entityType = itemEntityReference.EntityType; + var primaryKeyProperties = entityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } + + if (primaryKeyProperties.Count > 1) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName())); + } + + var property = primaryKeyProperties[0]; + Expression rewrittenSource; + switch (source) + { + case SqlConstantExpression sqlConstantExpression: + var values = (IEnumerable)sqlConstantExpression.Value; + var propertyValueList = (IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(property.ClrType.MakeNullable())); + var propertyGetter = property.GetGetter(); + foreach (var value in values) + { + propertyValueList.Add(propertyGetter.GetClrValue(value)); + } + + rewrittenSource = Expression.Constant(propertyValueList); + break; + + case SqlParameterExpression sqlParameterExpression + when sqlParameterExpression.Name.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal): + var lambda = Expression.Lambda( + Expression.Call( + _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(sqlParameterExpression.Name, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter + ); + + var newParameterName = + $"{_runtimeParameterPrefix}" + + $"{sqlParameterExpression.Name.Substring(QueryCompilationContext.QueryParameterPrefix.Length)}_{property.Name}"; + + rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + break; + + default: + return false; + } + + result = Visit(Expression.Call( + EnumerableMethods.Contains.MakeGenericMethod(property.ClrType.MakeNullable()), + rewrittenSource, + CreatePropertyAccessExpression(item, property))); + + return true; + } + + private bool TryRewriteEntityEquality(ExpressionType nodeType, Expression left, Expression right, out Expression result) + { + var leftEntityReference = left as EntityReferenceExpression; + var rightEntityReference = right as EntityReferenceExpression; + + if (leftEntityReference == null + && rightEntityReference == null) + { + result = null; + return false; + } + + if (IsNullSqlConstantExpression(left) + || IsNullSqlConstantExpression(right)) + { + var nonNullEntityReference = IsNullSqlConstantExpression(left) ? rightEntityReference : leftEntityReference; + var entityType1 = nonNullEntityReference.EntityType; + var primaryKeyProperties1 = entityType1.FindPrimaryKey()?.Properties; + if (primaryKeyProperties1 == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType1.DisplayName())); + } + + result = Visit(primaryKeyProperties1.Select(p => + Expression.MakeBinary( + nodeType, CreatePropertyAccessExpression(nonNullEntityReference, p), Expression.Constant(null, p.ClrType.MakeNullable()))) + .Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r))); + + return true; + } + + var leftEntityType = leftEntityReference?.EntityType; + var rightEntityType = rightEntityReference?.EntityType; + var entityType = leftEntityType ?? rightEntityType; + + Debug.Assert(entityType != null, "At least either side should be entityReference so entityType should be non-null."); + + if (leftEntityType != null + && rightEntityType != null + && leftEntityType.GetRootType() != rightEntityType.GetRootType()) + { + result = _sqlExpressionFactory.Constant(false); + return true; + } + + var primaryKeyProperties = entityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } + + result = Visit(primaryKeyProperties.Select(p => + Expression.MakeBinary( + nodeType, + CreatePropertyAccessExpression(left, p), + CreatePropertyAccessExpression(right, p))) + .Aggregate((l, r) => Expression.AndAlso(l, r))); + + return true; + } + + private Expression CreatePropertyAccessExpression(Expression target, IProperty property) + { + switch (target) + { + case SqlConstantExpression sqlConstantExpression: + return Expression.Constant( + property.GetGetter().GetClrValue(sqlConstantExpression.Value), property.ClrType.MakeNullable()); + + case SqlParameterExpression sqlParameterExpression + when sqlParameterExpression.Name.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal): + var lambda = Expression.Lambda( + Expression.Call( + _parameterValueExtractor.MakeGenericMethod(property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(sqlParameterExpression.Name, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter); + + var newParameterName = + $"{_runtimeParameterPrefix}" + + $"{sqlParameterExpression.Name.Substring(QueryCompilationContext.QueryParameterPrefix.Length)}_{property.Name}"; + + return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + + case MemberInitExpression memberInitExpression + when memberInitExpression.Bindings.SingleOrDefault( + mb => mb.Member.Name == property.Name) is MemberAssignment memberAssignment: + return memberAssignment.Expression; + + default: + return target.CreateEFPropertyExpression(property); + } } + private static T ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property) + { + var baseParameter = context.ParameterValues[baseParameterName]; + return baseParameter == null ? (T)(object)null : (T)property.GetGetter().GetClrValue(baseParameter); + } + + private static List ParameterListValueExtractor( + QueryContext context, string baseParameterName, IProperty property) + { + if (!(context.ParameterValues[baseParameterName] is IEnumerable baseListParameter)) + { + return null; + } + + var getter = property.GetGetter(); + return baseListParameter.Select(e => e != null ? (TProperty)getter.GetClrValue(e) : (TProperty)(object)null).ToList(); + } + + private static bool IsNullSqlConstantExpression(Expression expression) + => expression is SqlConstantExpression sqlConstant && sqlConstant.Value == null; + + private SqlConstantExpression GetConstantOrNull(Expression expression) + => CanEvaluate(expression) + ? new SqlConstantExpression( + Expression.Constant( + Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(), + expression.Type), + null) + : null; + private static bool CanEvaluate(Expression expression) { #pragma warning disable IDE0066 // Convert switch statement to expression @@ -485,13 +793,10 @@ private EntityReferenceExpression(EntityProjectionExpression parameter, Type typ public Expression Convert(Type type) { - if (type == typeof(object) // Ignore object conversion - || type.IsAssignableFrom(Type)) // Ignore casting to base type/interface - { - return this; - } - - return new EntityReferenceExpression(ParameterEntity, type); + return type == typeof(object) // Ignore object conversion + || type.IsAssignableFrom(Type) // Ignore conversion to base/interface + ? this + : new EntityReferenceExpression(ParameterEntity, type); } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index f461d69e13d..44cf9fd9b40 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; using System.Collections.Generic; using System.Diagnostics; using System.Linq; @@ -23,51 +24,47 @@ namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { public class InMemoryExpressionTranslatingExpressionVisitor : ExpressionVisitor { - private const string _compiledQueryParameterPrefix = "__"; + private const string _runtimeParameterPrefix = QueryCompilationContext.QueryParameterPrefix + "entity_equality_"; private static readonly MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0]; - private static readonly MethodInfo _getParameterValueMethodInfo - = typeof(InMemoryExpressionTranslatingExpressionVisitor) - .GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue)); - - private static readonly MethodInfo _likeMethodInfo - = typeof(DbFunctionsExtensions).GetRuntimeMethod( - nameof(DbFunctionsExtensions.Like), - new[] { typeof(DbFunctions), typeof(string), typeof(string) }); - - private static readonly MethodInfo _likeMethodInfoWithEscape - = typeof(DbFunctionsExtensions).GetRuntimeMethod( - nameof(DbFunctionsExtensions.Like), - new[] { typeof(DbFunctions), typeof(string), typeof(string), typeof(string) }); - - private static readonly MethodInfo _inMemoryLikeMethodInfo - = typeof(InMemoryExpressionTranslatingExpressionVisitor) - .GetTypeInfo().GetDeclaredMethod(nameof(InMemoryLike)); + private static readonly MethodInfo _parameterValueExtractor = + typeof(InMemoryExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterValueExtractor)); + private static readonly MethodInfo _parameterListValueExtractor = + typeof(InMemoryExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterListValueExtractor)); + private static readonly MethodInfo _getParameterValueMethodInfo = + typeof(InMemoryExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue)); + private static readonly MethodInfo _likeMethodInfo = typeof(DbFunctionsExtensions).GetRuntimeMethod( + nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string) }); + private static readonly MethodInfo _likeMethodInfoWithEscape = typeof(DbFunctionsExtensions).GetRuntimeMethod( + nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string), typeof(string) }); + private static readonly MethodInfo _inMemoryLikeMethodInfo = + typeof(InMemoryExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(InMemoryLike)); // Regex special chars defined here: // https://msdn.microsoft.com/en-us/library/4edbef7e(v=vs.110).aspx private static readonly char[] _regexSpecialChars = { '.', '$', '^', '{', '[', '(', '|', ')', '*', '+', '?', '\\' }; - private static readonly string _defaultEscapeRegexCharsPattern - = BuildEscapeRegexCharsPattern(_regexSpecialChars); + private static readonly string _defaultEscapeRegexCharsPattern = BuildEscapeRegexCharsPattern(_regexSpecialChars); private static readonly TimeSpan _regexTimeout = TimeSpan.FromMilliseconds(value: 1000.0); private static string BuildEscapeRegexCharsPattern(IEnumerable regexSpecialChars) => string.Join("|", regexSpecialChars.Select(c => @"\" + c)); + private readonly QueryCompilationContext _queryCompilationContext; private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; private readonly EntityReferenceFindingExpressionVisitor _entityReferenceFindingExpressionVisitor; private readonly IModel _model; public InMemoryExpressionTranslatingExpressionVisitor( - [NotNull] QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor, - [NotNull] IModel model) + [NotNull] QueryCompilationContext queryCompilationContext, + [NotNull] QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor) { + _queryCompilationContext = queryCompilationContext; _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; _entityReferenceFindingExpressionVisitor = new EntityReferenceFindingExpressionVisitor(); - _model = model; + _model = queryCompilationContext.Model; } public virtual Expression Translate([NotNull] Expression expression) @@ -92,6 +89,15 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return null; } + if ((binaryExpression.NodeType == ExpressionType.Equal + || binaryExpression.NodeType == ExpressionType.NotEqual) + // Visited expression could be null, We need to pass MemberInitExpression + && TryRewriteEntityEquality( + binaryExpression.NodeType, newLeft ?? binaryExpression.Left, newRight ?? binaryExpression.Right, out var result)) + { + return result; + } + if (IsConvertedToNullable(newLeft, binaryExpression.Left) || IsConvertedToNullable(newRight, binaryExpression.Right)) { @@ -170,6 +176,7 @@ protected override Expression VisitExtension(Expression extensionExpression) switch (extensionExpression) { case EntityProjectionExpression _: + case EntityReferenceExpression _: return extensionExpression; case EntityShaperExpression entityShaperExpression: @@ -287,13 +294,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } var selector = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter); - var method = GetMethod(); - method = method.GetGenericArguments().Length == 2 - ? method.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) - : method.MakeGenericMethod(typeof(ValueBuffer)); + var method2 = GetMethod(); + method2 = method2.GetGenericArguments().Length == 2 + ? method2.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) + : method2.MakeGenericMethod(typeof(ValueBuffer)); return Expression.Call( - method, + method2, groupByShaperExpression.GroupingParameter, selector); @@ -394,25 +401,127 @@ MethodInfo GetMethod() return Expression.Call(_inMemoryLikeMethodInfo, visitedArguments); } - // MethodCall translators - var @object = Visit(methodCallExpression.Object); - if (TranslationFailed(methodCallExpression.Object, @object)) + Expression @object = null; + Expression[] arguments; + var method = methodCallExpression.Method; + + if (method.Name == nameof(object.Equals) + && methodCallExpression.Object != null + && methodCallExpression.Arguments.Count == 1) { - return null; + var left = Visit(methodCallExpression.Object); + var right = Visit(methodCallExpression.Arguments[0]); + + if (TryRewriteEntityEquality(ExpressionType.Equal, + left ?? methodCallExpression.Object, + right ?? methodCallExpression.Arguments[0], + out var result)) + { + return result; + } + + if (TranslationFailed(left) + || TranslationFailed(right)) + { + return null; + } + + @object = left; + arguments = new Expression[1] { right }; } + else if (method.Name == nameof(object.Equals) + && methodCallExpression.Object == null + && methodCallExpression.Arguments.Count == 2) + { + var left = Visit(methodCallExpression.Arguments[0]); + var right = Visit(methodCallExpression.Arguments[1]); - var arguments = new Expression[methodCallExpression.Arguments.Count]; - var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray(); - for (var i = 0; i < arguments.Length; i++) + if (TryRewriteEntityEquality(ExpressionType.Equal, + left ?? methodCallExpression.Arguments[0], + right ?? methodCallExpression.Arguments[1], + out var result)) + { + return result; + } + + if (TranslationFailed(left) + || TranslationFailed(right)) + { + return null; + } + + arguments = new Expression[2] { left, right }; + } + else if (method.IsGenericMethod + && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)) + { + var enumerable = Visit(methodCallExpression.Arguments[0]); + var item = Visit(methodCallExpression.Arguments[1]); + + if (TryRewriteContainsEntity(enumerable, item ?? methodCallExpression.Arguments[1], out var result)) + { + return result; + } + + if (TranslationFailed(enumerable) + || TranslationFailed(item)) + { + return null; + } + + arguments = new Expression[2] { enumerable, item }; + } + else if (method.Name == nameof(IList.Contains) + && methodCallExpression.Arguments.Count == 1 + && method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any( + t => t == typeof(IList) + || (t.IsGenericType + && t.GetGenericTypeDefinition() == typeof(ICollection<>)))) + { + var enumerable = Visit(methodCallExpression.Object); + var item = Visit(methodCallExpression.Arguments[0]); + + if (TryRewriteContainsEntity(enumerable, item ?? methodCallExpression.Arguments[0], out var result)) + { + return result; + } + + if (TranslationFailed(enumerable) + || TranslationFailed(item)) + { + return null; + } + + @object = enumerable; + arguments = new Expression[1] { item }; + } + else { - var argument = Visit(methodCallExpression.Arguments[i]); - if (TranslationFailed(methodCallExpression.Arguments[i], argument)) + @object = Visit(methodCallExpression.Object); + if (TranslationFailed(methodCallExpression.Object, @object)) { return null; } - // if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type, - // so we are forced to cast back to the original type + arguments = new Expression[methodCallExpression.Arguments.Count]; + for (var i = 0; i < arguments.Length; i++) + { + var argument = Visit(methodCallExpression.Arguments[i]); + if (TranslationFailed(methodCallExpression.Arguments[i], argument)) + { + return null; + } + + arguments[i] = argument; + } + } + + // if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type, + // so we are forced to cast back to the original type + var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray(); + for (var i = 0; i < arguments.Length; i++) + { + var argument = arguments[i]; if (IsConvertedToNullable(argument, methodCallExpression.Arguments[i]) && !parameterTypes[i].IsAssignableFrom(argument.Type)) { @@ -496,7 +605,7 @@ protected override Expression VisitParameter(ParameterExpression parameterExpres { Check.NotNull(parameterExpression, nameof(parameterExpression)); - if (parameterExpression.Name.StartsWith(_compiledQueryParameterPrefix, StringComparison.Ordinal)) + if (parameterExpression.Name.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal)) { return Expression.Call( _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), @@ -763,10 +872,246 @@ private IProperty FindProperty(Expression expression) return null; } + private bool TryRewriteContainsEntity(Expression source, Expression item, out Expression result) + { + result = null; + + if (!(item is EntityReferenceExpression itemEntityReference)) + { + return false; + } + + var entityType = itemEntityReference.EntityType; + var primaryKeyProperties = entityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } + + if (primaryKeyProperties.Count > 1) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName())); + } + + var property = primaryKeyProperties[0]; + Expression rewrittenSource; + switch (source) + { + case ConstantExpression constantExpression: + var values = (IEnumerable)constantExpression.Value; + var propertyValueList = (IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(property.ClrType.MakeNullable())); + var propertyGetter = property.GetGetter(); + foreach (var value in values) + { + propertyValueList.Add(propertyGetter.GetClrValue(value)); + } + + rewrittenSource = Expression.Constant(propertyValueList); + break; + + case MethodCallExpression methodCallExpression + when methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == _getParameterValueMethodInfo: + var parameterName = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value; + var lambda = Expression.Lambda( + Expression.Call( + _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterName, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter + ); + + var newParameterName = + $"{_runtimeParameterPrefix}" + + $"{parameterName.Substring(QueryCompilationContext.QueryParameterPrefix.Length)}_{property.Name}"; + + rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + break; + + default: + return false; + } + + result = Visit(Expression.Call( + EnumerableMethods.Contains.MakeGenericMethod(property.ClrType.MakeNullable()), + rewrittenSource, + CreatePropertyAccessExpression(item, property))); + + return true; + } + + private bool TryRewriteEntityEquality(ExpressionType nodeType, Expression left, Expression right, out Expression result) + { + var leftEntityReference = left as EntityReferenceExpression; + var rightEntityReference = right as EntityReferenceExpression; + + if (leftEntityReference == null + && rightEntityReference == null) + { + result = null; + return false; + } + + if (IsNullConstantExpression(left) + || IsNullConstantExpression(right)) + { + var nonNullEntityReference = IsNullConstantExpression(left) ? rightEntityReference : leftEntityReference; + var entityType1 = nonNullEntityReference.EntityType; + var primaryKeyProperties1 = entityType1.FindPrimaryKey()?.Properties; + if (primaryKeyProperties1 == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType1.DisplayName())); + } + + result = Visit(primaryKeyProperties1.Select(p => + Expression.MakeBinary( + nodeType, CreatePropertyAccessExpression(nonNullEntityReference, p), Expression.Constant(null, p.ClrType.MakeNullable()))) + .Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r))); + + return true; + } + + var leftEntityType = leftEntityReference?.EntityType; + var rightEntityType = rightEntityReference?.EntityType; + var entityType = leftEntityType ?? rightEntityType; + + Debug.Assert(entityType != null, "At least either side should be entityReference so entityType should be non-null."); + + if (leftEntityType != null + && rightEntityType != null + && leftEntityType.GetRootType() != rightEntityType.GetRootType()) + { + result = Expression.Constant(false); + return true; + } + + var primaryKeyProperties = entityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } + + if (primaryKeyProperties.Count > 1 + && (leftEntityReference?.SubqueryEntity != null + || rightEntityReference?.SubqueryEntity != null)) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); + } + + result = Visit(primaryKeyProperties.Select(p => + Expression.MakeBinary( + nodeType, + CreatePropertyAccessExpression(left, p), + CreatePropertyAccessExpression(right, p))) + .Aggregate((l, r) => Expression.AndAlso(l, r))); + + return true; + } + + private Expression CreatePropertyAccessExpression(Expression target, IProperty property) + { + switch (target) + { + case ConstantExpression constantExpression: + return Expression.Constant( + property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType.MakeNullable()); + + case MethodCallExpression methodCallExpression + when methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == _getParameterValueMethodInfo: + var parameterName = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value; + var lambda = Expression.Lambda( + Expression.Call( + _parameterValueExtractor.MakeGenericMethod(property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterName, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter); + + var newParameterName = + $"{_runtimeParameterPrefix}" + + $"{parameterName.Substring(QueryCompilationContext.QueryParameterPrefix.Length)}_{property.Name}"; + + return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + + case MemberInitExpression memberInitExpression + when memberInitExpression.Bindings.SingleOrDefault( + mb => mb.Member.Name == property.Name) is MemberAssignment memberAssignment: + return memberAssignment.Expression.Type.IsNullableType() + ? memberAssignment.Expression + : Expression.Convert(memberAssignment.Expression, property.ClrType.MakeNullable()); + + case NewExpression newExpression + when CanEvaluate(newExpression): + return CreatePropertyAccessExpression(GetValue(newExpression), property); + + case MemberInitExpression memberInitExpression + when CanEvaluate(memberInitExpression): + return CreatePropertyAccessExpression(GetValue(memberInitExpression), property); + + default: + return target.CreateEFPropertyExpression(property); + } + } + + private static T ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property) + { + var baseParameter = context.ParameterValues[baseParameterName]; + return baseParameter == null ? (T)(object)null : (T)property.GetGetter().GetClrValue(baseParameter); + } + + private static List ParameterListValueExtractor( + QueryContext context, string baseParameterName, IProperty property) + { + if (!(context.ParameterValues[baseParameterName] is IEnumerable baseListParameter)) + { + return null; + } + + var getter = property.GetGetter(); + return baseListParameter.Select(e => e != null ? (TProperty)getter.GetClrValue(e) : (TProperty)(object)null).ToList(); + } + + private static ConstantExpression GetValue(Expression expression) + => Expression.Constant( + Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(), + expression.Type); + + private static bool CanEvaluate(Expression expression) + { +#pragma warning disable IDE0066 // Convert switch statement to expression + switch (expression) +#pragma warning restore IDE0066 // Convert switch statement to expression + { + case ConstantExpression constantExpression: + return true; + + case NewExpression newExpression: + return newExpression.Arguments.All(e => CanEvaluate(e)); + + case MemberInitExpression memberInitExpression: + return CanEvaluate(memberInitExpression.NewExpression) + && memberInitExpression.Bindings.All( + mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)); + + default: + return false; + } + } + + private static bool IsNullConstantExpression(Expression expression) + => expression is ConstantExpression constantExpression && constantExpression.Value == null; + [DebuggerStepThrough] private static bool TranslationFailed(Expression original, Expression translation) => original != null && (translation == null || translation is EntityReferenceExpression); + private static bool TranslationFailed(Expression translation) + => translation == null || translation is EntityReferenceExpression; + private static bool InMemoryLike(string matchExpression, string pattern, string escapeCharacter) { //TODO: this fixes https://github.com/aspnet/EntityFramework/issues/8656 by insisting that diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 2d156674ce4..6a467f5de42 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -7,12 +7,11 @@ using System.Linq.Expressions; using System.Reflection; using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.InMemory.Internal; -using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; @@ -20,8 +19,6 @@ namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { public class InMemoryQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor { - private static readonly MethodInfo _efPropertyMethod = typeof(EF).GetTypeInfo().GetDeclaredMethod(nameof(EF.Property)); - private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslator; private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor; private readonly InMemoryProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor; @@ -32,7 +29,7 @@ public InMemoryQueryableMethodTranslatingExpressionVisitor( [NotNull] QueryCompilationContext queryCompilationContext) : base(dependencies, subquery: false) { - _expressionTranslator = new InMemoryExpressionTranslatingExpressionVisitor(this, queryCompilationContext.Model); + _expressionTranslator = new InMemoryExpressionTranslatingExpressionVisitor(queryCompilationContext, this); _weakEntityExpandingExpressionVisitor = new WeakEntityExpandingExpressionVisitor(_expressionTranslator); _projectionBindingExpressionVisitor = new InMemoryProjectionBindingExpressionVisitor(this, _expressionTranslator); _model = queryCompilationContext.Model; @@ -402,16 +399,14 @@ protected override ShapedQueryExpression TranslateJoin( Check.NotNull(inner, nameof(inner)); Check.NotNull(resultSelector, nameof(resultSelector)); - outerKeySelector = TranslateLambdaExpression(outer, outerKeySelector); - innerKeySelector = TranslateLambdaExpression(inner, innerKeySelector); + (outerKeySelector, innerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector); + if (outerKeySelector == null || innerKeySelector == null) { return null; } - (outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector); - var transparentIdentifierType = TransparentIdentifierFactory.Create( resultSelector.Parameters[0].Type, resultSelector.Parameters[1].Type); @@ -429,6 +424,71 @@ protected override ShapedQueryExpression TranslateJoin( transparentIdentifierType); } + private (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector) ProcessJoinKeySelector( + ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector) + { + var left = RemapLambdaBody(outer, outerKeySelector); + var right = RemapLambdaBody(inner, innerKeySelector); + + var joinCondition = TranslateExpression(Expression.Equal(left, right)); + + var (outerKeyBody, innerKeyBody) = DecomposeJoinCondition(joinCondition); + + if (outerKeyBody == null + || innerKeyBody == null) + { + return (null, null); + } + + outerKeySelector = Expression.Lambda(outerKeyBody, ((InMemoryQueryExpression)outer.QueryExpression).CurrentParameter); + innerKeySelector = Expression.Lambda(innerKeyBody, ((InMemoryQueryExpression)inner.QueryExpression).CurrentParameter); + + return AlignKeySelectorTypes(outerKeySelector, innerKeySelector); + } + + private static (Expression, Expression) DecomposeJoinCondition(Expression joinCondition) + { + var leftExpressions = new List(); + var rightExpressions = new List(); + + return ProcessJoinCondition(joinCondition, leftExpressions, rightExpressions) + ? leftExpressions.Count == 1 + ? (leftExpressions[0], rightExpressions[0]) + : (CreateAnonymousObject(leftExpressions), CreateAnonymousObject(rightExpressions)) + : (null, null); + + static Expression CreateAnonymousObject(List expressions) + => Expression.New( + AnonymousObject.AnonymousObjectCtor, + Expression.NewArrayInit( + typeof(object), + expressions.Select(e => Expression.Convert(e, typeof(object))))); + } + + + private static bool ProcessJoinCondition( + Expression joinCondition, List leftExpressions, List rightExpressions) + { + if (joinCondition is BinaryExpression binaryExpression) + { + if (binaryExpression.NodeType == ExpressionType.Equal) + { + leftExpressions.Add(binaryExpression.Left); + rightExpressions.Add(binaryExpression.Right); + + return true; + } + + if (binaryExpression.NodeType == ExpressionType.AndAlso) + { + return ProcessJoinCondition(binaryExpression.Left, leftExpressions, rightExpressions) + && ProcessJoinCondition(binaryExpression.Right, leftExpressions, rightExpressions); + } + } + + return false; + } + private static (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector) AlignKeySelectorTypes(LambdaExpression outerKeySelector, LambdaExpression innerKeySelector) { @@ -477,15 +537,14 @@ protected override ShapedQueryExpression TranslateLeftJoin( Check.NotNull(inner, nameof(inner)); Check.NotNull(resultSelector, nameof(resultSelector)); - outerKeySelector = TranslateLambdaExpression(outer, outerKeySelector); - innerKeySelector = TranslateLambdaExpression(inner, innerKeySelector); + (outerKeySelector, innerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector); + if (outerKeySelector == null || innerKeySelector == null) { return null; } - (outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector); var transparentIdentifierType = TransparentIdentifierFactory.Create( resultSelector.Parameters[0].Type, @@ -579,14 +638,8 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s var discriminatorProperty = entityType.GetDiscriminatorProperty(); var parameter = Expression.Parameter(entityType.ClrType); - var callEFProperty = Expression.Call( - _efPropertyMethod.MakeGenericMethod( - discriminatorProperty.ClrType), - parameter, - Expression.Constant(discriminatorProperty.Name)); - var equals = Expression.Equal( - callEFProperty, + parameter.CreateEFPropertyExpression(discriminatorProperty), Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)); foreach (var derivedDerivedType in derivedType.GetDerivedTypes()) @@ -594,7 +647,7 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s equals = Expression.OrElse( equals, Expression.Equal( - callEFProperty, + parameter.CreateEFPropertyExpression(discriminatorProperty), Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType))); } diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 1146f9c6701..116047dea4e 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -22,6 +22,7 @@ public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMe private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator; private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor; private readonly RelationalProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor; + private readonly QueryCompilationContext _queryCompilationContext; private readonly IModel _model; private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly bool _subquery; @@ -54,7 +55,7 @@ protected RelationalQueryableMethodTranslatingExpressionVisitor( : base(parentVisitor.Dependencies, subquery: true) { RelationalDependencies = parentVisitor.RelationalDependencies; - _model = parentVisitor._model; + _queryCompilationContext = parentVisitor._queryCompilationContext; _sqlTranslator = parentVisitor._sqlTranslator; _weakEntityExpandingExpressionVisitor = parentVisitor._weakEntityExpandingExpressionVisitor; _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator); @@ -116,7 +117,7 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(Type elemen { Check.NotNull(elementType, nameof(elementType)); - var entityType = _model.FindEntityType(elementType); + var entityType = _queryCompilationContext.Model.FindEntityType(elementType); var queryExpression = _sqlExpressionFactory.Select(entityType); return CreateShapedQueryExpression(entityType, queryExpression); diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index c9057094e8b..d576ed7e5a7 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; +using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Linq.Expressions; @@ -18,6 +20,15 @@ namespace Microsoft.EntityFrameworkCore.Query { public class RelationalSqlTranslatingExpressionVisitor : ExpressionVisitor { + + private const string _runtimeParameterPrefix = QueryCompilationContext.QueryParameterPrefix + "entity_equality_"; + + private static readonly MethodInfo _parameterValueExtractor = + typeof(RelationalSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterValueExtractor)); + private static readonly MethodInfo _parameterListValueExtractor = + typeof(RelationalSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterListValueExtractor)); + + private readonly QueryCompilationContext _queryCompilationContext; private readonly IModel _model; private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; @@ -34,7 +45,7 @@ public RelationalSqlTranslatingExpressionVisitor( Dependencies = dependencies; _sqlExpressionFactory = dependencies.SqlExpressionFactory; - + _queryCompilationContext = queryCompilationContext; _model = queryCompilationContext.Model; _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; _sqlTypeMappingVerifyingExpressionVisitor = new SqlTypeMappingVerifyingExpressionVisitor(); @@ -233,6 +244,20 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return Visit(ConvertAnonymousObjectEqualityComparison(binaryExpression)); } + var left = TryRemoveImplicitConvert(binaryExpression.Left); + var right = TryRemoveImplicitConvert(binaryExpression.Right); + + var visitedLeft = Visit(left); + var visitedRight = Visit(right); + + if ((binaryExpression.NodeType == ExpressionType.Equal + || binaryExpression.NodeType == ExpressionType.NotEqual) + // Visited expression could be null, We need to pass MemberInitExpression + && TryRewriteEntityEquality(binaryExpression.NodeType, visitedLeft ?? left, visitedRight ?? right, out var result)) + { + return result; + } + var uncheckedNodeTypeVariant = binaryExpression.NodeType switch { ExpressionType.AddChecked => ExpressionType.Add, @@ -241,11 +266,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) _ => binaryExpression.NodeType }; - var left = TryRemoveImplicitConvert(binaryExpression.Left); - var right = TryRemoveImplicitConvert(binaryExpression.Right); - - return TranslationFailed(binaryExpression.Left, Visit(left), out var sqlLeft) - || TranslationFailed(binaryExpression.Right, Visit(right), out var sqlRight) + return TranslationFailed(binaryExpression.Left, visitedLeft, out var sqlLeft) + || TranslationFailed(binaryExpression.Right, visitedRight, out var sqlRight) ? null : uncheckedNodeTypeVariant == ExpressionType.Coalesce ? _sqlExpressionFactory.Coalesce(sqlLeft, sqlRight) @@ -281,6 +303,7 @@ protected override Expression VisitExtension(Expression extensionExpression) switch (extensionExpression) { case EntityProjectionExpression _: + case EntityReferenceExpression _: case SqlExpression _: return extensionExpression; @@ -309,7 +332,7 @@ protected override Expression VisitMember(MemberExpression memberExpression) var innerExpression = Visit(memberExpression.Expression); return TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member)) - ?? (TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression) + ?? (TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression) ? null : Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type)); } @@ -415,21 +438,126 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) return new ScalarSubqueryExpression(subquery); } - if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject)) + SqlExpression sqlObject = null; + SqlExpression[] arguments; + var method = methodCallExpression.Method; + + if (method.Name == nameof(object.Equals) + && methodCallExpression.Object != null + && methodCallExpression.Arguments.Count == 1) { - return null; + var left = Visit(methodCallExpression.Object); + var right = Visit(methodCallExpression.Arguments[0]); + + if (TryRewriteEntityEquality(ExpressionType.Equal, + left ?? methodCallExpression.Object, + right ?? methodCallExpression.Arguments[0], + out var result)) + { + return result; + } + + if (left is SqlExpression leftSql + && right is SqlExpression rightSql) + { + sqlObject = leftSql; + arguments = new SqlExpression[1] { rightSql }; + } + else + { + return null; + } } + else if (method.Name == nameof(object.Equals) + && methodCallExpression.Object == null + && methodCallExpression.Arguments.Count == 2) + { + var left = Visit(methodCallExpression.Arguments[0]); + var right = Visit(methodCallExpression.Arguments[1]); - var arguments = new SqlExpression[methodCallExpression.Arguments.Count]; - for (var i = 0; i < arguments.Length; i++) + if (TryRewriteEntityEquality(ExpressionType.Equal, + left ?? methodCallExpression.Arguments[0], + right ?? methodCallExpression.Arguments[1], + out var result)) + { + return result; + } + + if (left is SqlExpression leftSql + && right is SqlExpression rightSql) + { + arguments = new SqlExpression[2] { leftSql, rightSql }; + } + else + { + return null; + } + } + else if (method.IsGenericMethod + && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)) { - var argument = methodCallExpression.Arguments[i]; - if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) + var enumerable = Visit(methodCallExpression.Arguments[0]); + var item = Visit(methodCallExpression.Arguments[1]); + + if (TryRewriteContainsEntity(enumerable, item ?? methodCallExpression.Arguments[1], out var result)) + { + return result; + } + + if (enumerable is SqlExpression sqlEnumerable + && item is SqlExpression sqlItem) + { + arguments = new SqlExpression[2] { sqlEnumerable, sqlItem }; + } + else { return null; } + } + else if (method.Name == nameof(IList.Contains) + && methodCallExpression.Arguments.Count == 1 + && method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any( + t => t == typeof(IList) + || (t.IsGenericType + && t.GetGenericTypeDefinition() == typeof(ICollection<>)))) + { + var enumerable = Visit(methodCallExpression.Object); + var item = Visit(methodCallExpression.Arguments[0]); - arguments[i] = sqlArgument; + if (TryRewriteContainsEntity(enumerable, item ?? methodCallExpression.Arguments[0], out var result)) + { + return result; + } + + if (enumerable is SqlExpression sqlEnumerable + && item is SqlExpression sqlItem) + { + sqlObject = sqlEnumerable; + arguments = new SqlExpression[1] { sqlItem }; + } + else + { + return null; + } + } + else + { + if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out sqlObject)) + { + return null; + } + + arguments = new SqlExpression[methodCallExpression.Arguments.Count]; + for (var i = 0; i < arguments.Length; i++) + { + var argument = methodCallExpression.Arguments[i]; + if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) + { + return null; + } + + arguments[i] = sqlArgument; + } } return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); @@ -684,6 +812,197 @@ private static SqlConstantExpression GetConstantOrNull(Expression expression) null) : null; + private bool TryRewriteContainsEntity(Expression source, Expression item, out Expression result) + { + result = null; + + if (!(item is EntityReferenceExpression itemEntityReference)) + { + return false; + } + + var entityType = itemEntityReference.EntityType; + var primaryKeyProperties = entityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } + + if (primaryKeyProperties.Count > 1) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName())); + } + + var property = primaryKeyProperties[0]; + Expression rewrittenSource; + switch (source) + { + case SqlConstantExpression sqlConstantExpression: + var values = (IEnumerable)sqlConstantExpression.Value; + var propertyValueList = (IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(property.ClrType.MakeNullable())); + var propertyGetter = property.GetGetter(); + foreach (var value in values) + { + propertyValueList.Add(propertyGetter.GetClrValue(value)); + } + + rewrittenSource = Expression.Constant(propertyValueList); + break; + + case SqlParameterExpression sqlParameterExpression + when sqlParameterExpression.Name.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal): + var lambda = Expression.Lambda( + Expression.Call( + _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(sqlParameterExpression.Name, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter + ); + + var newParameterName = + $"{_runtimeParameterPrefix}" + + $"{sqlParameterExpression.Name.Substring(QueryCompilationContext.QueryParameterPrefix.Length)}_{property.Name}"; + + rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + break; + + default: + return false; + } + + result = Visit(Expression.Call( + EnumerableMethods.Contains.MakeGenericMethod(property.ClrType.MakeNullable()), + rewrittenSource, + CreatePropertyAccessExpression(item, property))); + + return true; + } + + private bool TryRewriteEntityEquality(ExpressionType nodeType, Expression left, Expression right, out Expression result) + { + var leftEntityReference = left as EntityReferenceExpression; + var rightEntityReference = right as EntityReferenceExpression; + + if (leftEntityReference == null + && rightEntityReference == null) + { + result = null; + return false; + } + + if (IsNullSqlConstantExpression(left) + || IsNullSqlConstantExpression(right)) + { + var nonNullEntityReference = IsNullSqlConstantExpression(left) ? rightEntityReference : leftEntityReference; + var entityType1 = nonNullEntityReference.EntityType; + var primaryKeyProperties1 = entityType1.FindPrimaryKey()?.Properties; + if (primaryKeyProperties1 == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType1.DisplayName())); + } + + result = Visit(primaryKeyProperties1.Select(p => + Expression.MakeBinary( + nodeType, + CreatePropertyAccessExpression(nonNullEntityReference, p), + Expression.Constant(null, p.ClrType.MakeNullable()))) + .Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r))); + + return true; + } + + var leftEntityType = leftEntityReference?.EntityType; + var rightEntityType = rightEntityReference?.EntityType; + var entityType = leftEntityType ?? rightEntityType; + + Debug.Assert(entityType != null, "At least either side should be entityReference so entityType should be non-null."); + + if (leftEntityType != null + && rightEntityType != null + && leftEntityType.GetRootType() != rightEntityType.GetRootType()) + { + result = _sqlExpressionFactory.Constant(false); + return true; + } + + var primaryKeyProperties = entityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } + + if (primaryKeyProperties.Count > 1 + && (leftEntityReference?.SubqueryEntity != null + || rightEntityReference?.SubqueryEntity != null)) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); + } + + result = Visit(primaryKeyProperties.Select(p => + Expression.MakeBinary( + nodeType, + CreatePropertyAccessExpression(left, p), + CreatePropertyAccessExpression(right, p))) + .Aggregate((l, r) => Expression.AndAlso(l, r))); + + return true; + } + + private Expression CreatePropertyAccessExpression(Expression target, IProperty property) + { + switch (target) + { + case SqlConstantExpression sqlConstantExpression: + return Expression.Constant( + property.GetGetter().GetClrValue(sqlConstantExpression.Value), property.ClrType.MakeNullable()); + + case SqlParameterExpression sqlParameterExpression + when sqlParameterExpression.Name.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal): + var lambda = Expression.Lambda( + Expression.Call( + _parameterValueExtractor.MakeGenericMethod(property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(sqlParameterExpression.Name, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter); + + var newParameterName = + $"{_runtimeParameterPrefix}" + + $"{sqlParameterExpression.Name.Substring(QueryCompilationContext.QueryParameterPrefix.Length)}_{property.Name}"; + + return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + + case MemberInitExpression memberInitExpression + when memberInitExpression.Bindings.SingleOrDefault( + mb => mb.Member.Name == property.Name) is MemberAssignment memberAssignment: + return memberAssignment.Expression; + + default: + return target.CreateEFPropertyExpression(property); + } + } + + private static T ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property) + { + var baseParameter = context.ParameterValues[baseParameterName]; + return baseParameter == null ? (T)(object)null : (T)property.GetGetter().GetClrValue(baseParameter); + } + + private static List ParameterListValueExtractor( + QueryContext context, string baseParameterName, IProperty property) + { + if (!(context.ParameterValues[baseParameterName] is IEnumerable baseListParameter)) + { + return null; + } + + var getter = property.GetGetter(); + return baseListParameter.Select(e => e != null ? (TProperty)getter.GetClrValue(e) : (TProperty)(object)null).ToList(); + } + private static bool CanEvaluate(Expression expression) { #pragma warning disable IDE0066 // Convert switch statement to expression @@ -706,6 +1025,9 @@ private static bool CanEvaluate(Expression expression) } } + private static bool IsNullSqlConstantExpression(Expression expression) + => expression is SqlConstantExpression sqlConstant && sqlConstant.Value == null; + [DebuggerStepThrough] private static bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation) { diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index 06a24c1c003..f891f02bfa9 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -93,7 +93,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) if (unaryExpression.NodeType == ExpressionType.ArrayLength && unaryExpression.Operand.Type == typeof(byte[])) { - return base.Visit(unaryExpression.Operand) is SqlExpression sqlExpression + return Visit(unaryExpression.Operand) is SqlExpression sqlExpression ? Dependencies.SqlExpressionFactory.Function( "length", new[] { sqlExpression }, diff --git a/src/EFCore/Extensions/Internal/EFPropertyExtensions.cs b/src/EFCore/Extensions/Internal/EFPropertyExtensions.cs deleted file mode 100644 index ea0d5900d9e..00000000000 --- a/src/EFCore/Extensions/Internal/EFPropertyExtensions.cs +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Query.Internal; - -// ReSharper disable once CheckNamespace -namespace Microsoft.EntityFrameworkCore.Internal -{ - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - [DebuggerStepThrough] - // ReSharper disable once InconsistentNaming - public static class EFPropertyExtensions - { - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public static Expression CreateEFPropertyExpression( - [NotNull] this Expression target, - [NotNull] IPropertyBase property, - bool makeNullable = true) - => CreateEFPropertyExpression(target, property.DeclaringType.ClrType, property.ClrType, property.Name, makeNullable); - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public static Expression CreateEFPropertyExpression( - [NotNull] this Expression target, - [NotNull] MemberInfo memberInfo) - => CreateEFPropertyExpression( - target, memberInfo.DeclaringType, memberInfo.GetMemberType(), memberInfo.GetSimpleMemberName(), makeNullable: false); - - private static Expression CreateEFPropertyExpression( - Expression target, - Type propertyDeclaringType, - Type propertyType, - string propertyName, - bool makeNullable) - { - if (propertyDeclaringType != target.Type - && target.Type.IsAssignableFrom(propertyDeclaringType)) - { - target = Expression.Convert(target, propertyDeclaringType); - } - - if (makeNullable) - { - propertyType = propertyType.MakeNullable(); - } - - return Expression.Call( - EF.PropertyMethod.MakeGenericMethod(propertyType), - target, - Expression.Constant(propertyName)); - } - } -} diff --git a/src/EFCore/Infrastructure/ExpressionExtensions.cs b/src/EFCore/Infrastructure/ExpressionExtensions.cs index ab20345221e..e44991206ea 100644 --- a/src/EFCore/Infrastructure/ExpressionExtensions.cs +++ b/src/EFCore/Infrastructure/ExpressionExtensions.cs @@ -301,5 +301,48 @@ public static Expression CreateKeyValueReadExpression( .Select(p => Expression.Convert(target.CreateEFPropertyExpression(p, makeNullable), typeof(object))) .Cast() .ToArray())); + + /// + /// + /// Creates an tree representing EF property access on given expression. + /// + /// + /// This method is typically used by database providers (and other extensions). It is generally + /// not used in application code. + /// + /// + /// The expression that will be root for generated read operation. + /// The property to access. + /// A value indicating if the value can be nullable. + /// An expression to access EF property on given expression. + public static Expression CreateEFPropertyExpression( + [NotNull] this Expression target, + [NotNull] IPropertyBase property, + bool makeNullable = true) + => CreateEFPropertyExpression(target, property.DeclaringType.ClrType, property.ClrType, property.Name, makeNullable); + + private static Expression CreateEFPropertyExpression( + Expression target, + Type propertyDeclaringType, + Type propertyType, + string propertyName, + bool makeNullable) + { + if (propertyDeclaringType != target.Type + && target.Type.IsAssignableFrom(propertyDeclaringType)) + { + target = Expression.Convert(target, propertyDeclaringType); + } + + if (makeNullable) + { + propertyType = propertyType.MakeNullable(); + } + + return Expression.Call( + EF.PropertyMethod.MakeGenericMethod(propertyType), + target, + Expression.Constant(propertyName)); + } } } diff --git a/src/EFCore/Query/Internal/CompiledQueryBase.cs b/src/EFCore/Query/Internal/CompiledQueryBase.cs index d19315b3aab..8258e825d00 100644 --- a/src/EFCore/Query/Internal/CompiledQueryBase.cs +++ b/src/EFCore/Query/Internal/CompiledQueryBase.cs @@ -68,7 +68,7 @@ protected virtual TResult ExecuteCore( for (var i = 0; i < parameters.Length; i++) { queryContext.AddParameter( - CompiledQueryCache.CompiledQueryParameterPrefix + _queryExpression.Parameters[i + 1].Name, + QueryCompilationContext.QueryParameterPrefix + _queryExpression.Parameters[i + 1].Name, parameters[i]); } @@ -121,7 +121,7 @@ protected override Expression VisitParameter(ParameterExpression parameterExpres return _parameters.Contains(parameterExpression) ? Expression.Parameter( parameterExpression.Type, - CompiledQueryCache.CompiledQueryParameterPrefix + parameterExpression.Name) + QueryCompilationContext.QueryParameterPrefix + parameterExpression.Name) : parameterExpression; } } diff --git a/src/EFCore/Query/Internal/CompiledQueryCache.cs b/src/EFCore/Query/Internal/CompiledQueryCache.cs index 38f9769798a..15313e0195e 100644 --- a/src/EFCore/Query/Internal/CompiledQueryCache.cs +++ b/src/EFCore/Query/Internal/CompiledQueryCache.cs @@ -24,14 +24,6 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal /// public class CompiledQueryCache : ICompiledQueryCache { - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public const string CompiledQueryParameterPrefix = "__"; - private static readonly ConcurrentDictionary _querySyncObjects = new ConcurrentDictionary(); diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs deleted file mode 100644 index 5b5f77ff7fd..00000000000 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ /dev/null @@ -1,1239 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.Diagnostics; -using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Internal; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Metadata.Internal; -using Microsoft.EntityFrameworkCore.Utilities; - -namespace Microsoft.EntityFrameworkCore.Query.Internal -{ - /// - /// Rewrites comparisons of entities (as opposed to comparisons of their properties) into comparison of their keys. - /// - /// - /// For example, an expression such as cs.Where(c => c == something) would be rewritten to cs.Where(c => c.Id == something.Id). - /// - public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor - { - /// - /// If the entity equality visitors introduces new runtime parameters (because it adds key access over existing parameters), - /// those parameters will have this prefix. - /// - private const string RuntimeParameterPrefix = CompiledQueryCache.CompiledQueryParameterPrefix + "entity_equality_"; - - private readonly QueryCompilationContext _queryCompilationContext; - private readonly IDiagnosticsLogger _logger; - - private static readonly MethodInfo _objectEqualsMethodInfo - = typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) }); - - private static readonly MethodInfo _enumerableContainsMethodInfo = typeof(Enumerable).GetTypeInfo() - .GetDeclaredMethods(nameof(Enumerable.Contains)) - .Single(mi => mi.GetParameters().Length == 2); - - public EntityEqualityRewritingExpressionVisitor([NotNull] QueryCompilationContext queryCompilationContext) - { - _queryCompilationContext = queryCompilationContext; - _logger = queryCompilationContext.Logger; - } - - public virtual Expression Rewrite([NotNull] Expression expression) - { - var result = Visit(expression); - // Work-around for issue#20164 - return new ReducingExpressionVisitor().Visit(result); - } - - private sealed class ReducingExpressionVisitor : ExpressionVisitor - { - protected override Expression VisitExtension(Expression extensionExpression) - { - return extensionExpression is EntityReferenceExpression entityReferenceExpression - ? Visit(entityReferenceExpression.Underlying) - : base.VisitExtension(extensionExpression); - } - } - - protected override Expression VisitNew(NewExpression newExpression) - { - Check.NotNull(newExpression, nameof(newExpression)); - - var visitedArgs = Visit(newExpression.Arguments); - var visitedExpression = newExpression.Update(visitedArgs.Select(Unwrap)); - - // NewExpression.Members is populated for anonymous types, mapping constructor arguments to the properties - // which receive their values. If not populated, a non-anonymous type is being constructed, and we have no idea where - // its constructor arguments will end up. - if (newExpression.Members == null) - { - return visitedExpression; - } - - var entityReferenceInfo = visitedArgs - .Select((a, i) => (Arg: a, Index: i)) - .Where(ai => ai.Arg is EntityReferenceExpression) - .ToDictionary( - ai => visitedExpression.Members[ai.Index].Name, - ai => EntityOrDtoType.FromEntityReferenceExpression((EntityReferenceExpression)ai.Arg)); - - return entityReferenceInfo.Count == 0 - ? (Expression)visitedExpression - : new EntityReferenceExpression(visitedExpression, entityReferenceInfo); - } - - protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) - { - Check.NotNull(memberInitExpression, nameof(memberInitExpression)); - - var visitedNew = Visit(memberInitExpression.NewExpression); - var (visitedBindings, entityReferenceInfo) = VisitMemberBindings(memberInitExpression.Bindings); - var visitedMemberInit = memberInitExpression.Update((NewExpression)Unwrap(visitedNew), visitedBindings); - - return entityReferenceInfo == null - ? (Expression)visitedMemberInit - : new EntityReferenceExpression(visitedMemberInit, entityReferenceInfo); - - // Visits member bindings, unwrapping expressions and surfacing entity reference information via the dictionary - (IEnumerable, Dictionary) VisitMemberBindings( - ReadOnlyCollection bindings) - { - var newBindings = new MemberBinding[bindings.Count]; - Dictionary bindingEntityReferenceInfo = null; - - for (var i = 0; i < bindings.Count; i++) - { - switch (bindings[i]) - { - case MemberAssignment assignment: - var visitedAssignment = VisitMemberAssignment(assignment); - if (visitedAssignment.Expression is EntityReferenceExpression ere) - { - if (bindingEntityReferenceInfo == null) - { - bindingEntityReferenceInfo = new Dictionary(); - } - - bindingEntityReferenceInfo[assignment.Member.Name] = EntityOrDtoType.FromEntityReferenceExpression(ere); - } - - newBindings[i] = assignment.Update(Unwrap(visitedAssignment.Expression)); - continue; - - default: - newBindings[i] = VisitMemberBinding(bindings[i]); - continue; - } - } - - return (newBindings, bindingEntityReferenceInfo); - } - } - - // Note that we could bubble up entity type information from the expressions initializing the array. However, EF Core doesn't - // actually support doing much further with this array, so it's not worth the complexity (right now). So we simply unwrap. - protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) - { - Check.NotNull(newArrayExpression, nameof(newArrayExpression)); - - return newArrayExpression.Update(Visit(newArrayExpression.Expressions).Select(Unwrap)); - } - - // Note that we could bubble up entity type information from the expressions initializing the list. However, EF Core doesn't - // actually support doing much further with this list, so it's not worth the complexity (right now). So we simply unwrap. - protected override Expression VisitListInit(ListInitExpression listInitExpression) - { - Check.NotNull(listInitExpression, nameof(listInitExpression)); - - return listInitExpression.Update( - (NewExpression)Unwrap(listInitExpression.NewExpression), - listInitExpression.Initializers.Select(VisitElementInit)); - } - - protected override ElementInit VisitElementInit(ElementInit elementInit) - { - Check.NotNull(elementInit, nameof(elementInit)); - - return Expression.ElementInit(elementInit.AddMethod, Visit(elementInit.Arguments).Select(Unwrap)); - } - - protected override Expression VisitMember(MemberExpression memberExpression) - { - Check.NotNull(memberExpression, nameof(memberExpression)); - - var visitedExpression = base.Visit(memberExpression.Expression); - var visitedMemberExpression = memberExpression.Update(Unwrap(visitedExpression)); - return visitedExpression is EntityReferenceExpression entityWrapper - ? entityWrapper.TraverseProperty(memberExpression.Member.Name, visitedMemberExpression) - : visitedMemberExpression; - } - - protected override Expression VisitBinary(BinaryExpression binaryExpression) - { - Check.NotNull(binaryExpression, nameof(binaryExpression)); - - var (newLeft, newRight) = (Visit(binaryExpression.Left), Visit(binaryExpression.Right)); - if (binaryExpression.NodeType == ExpressionType.Equal - || binaryExpression.NodeType == ExpressionType.NotEqual) - { - if (RewriteEquality(binaryExpression.NodeType == ExpressionType.Equal, newLeft, newRight) is Expression result) - { - return result; - } - } - - return binaryExpression.Update(Unwrap(newLeft), binaryExpression.Conversion, Unwrap(newRight)); - } - - protected override Expression VisitUnary(UnaryExpression unaryExpression) - { - Check.NotNull(unaryExpression, nameof(unaryExpression)); - - var newOperand = Visit(unaryExpression.Operand); - var newUnary = unaryExpression.Update(Unwrap(newOperand)); - - if (unaryExpression.NodeType == ExpressionType.Convert) - { - if (!(newOperand is EntityReferenceExpression sourceWrapper) - || sourceWrapper.EntityType == null) - { - return newUnary; - } - - var castType = unaryExpression.Type; - var castEntityType = sourceWrapper.EntityType.GetTypesInHierarchy().FirstOrDefault(et => et.ClrType == castType); - if (castEntityType == null) - { - return newUnary; - } - - return new EntityReferenceExpression(newUnary, castEntityType); - } - - return newOperand is EntityReferenceExpression entityWrapper - ? entityWrapper.Update(newUnary) - : (Expression)newUnary; - } - - protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) - { - Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); - - // This is for "x is y" - var visitedExpression = Visit(typeBinaryExpression.Expression); - var visitedTypeBinary = typeBinaryExpression.Update(Unwrap(visitedExpression)); - return visitedExpression is EntityReferenceExpression entityWrapper - ? entityWrapper.Update(visitedTypeBinary) - : (Expression)visitedTypeBinary; - } - - protected override Expression VisitConditional(ConditionalExpression conditionalExpression) - { - Check.NotNull(conditionalExpression, nameof(conditionalExpression)); - - var newTest = Visit(conditionalExpression.Test); - var newIfTrue = Visit(conditionalExpression.IfTrue); - var newIfFalse = Visit(conditionalExpression.IfFalse); - - var newConditional = conditionalExpression.Update(Unwrap(newTest), Unwrap(newIfTrue), Unwrap(newIfFalse)); - - // TODO: the true and false sides may refer different entity types which happen to have the same - // CLR type (e.g. shared entities) - var wrapper = newIfTrue as EntityReferenceExpression ?? newIfFalse as EntityReferenceExpression; - - return wrapper == null ? (Expression)newConditional : wrapper.Update(newConditional); - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - var method = methodCallExpression.Method; - var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; - var arguments = methodCallExpression.Arguments; - Expression newSource; - - // Check if this is this Equals() - if (method.Name == nameof(object.Equals) - && methodCallExpression.Object != null - && methodCallExpression.Arguments.Count == 1) - { - var (newLeft, newRight) = (Visit(methodCallExpression.Object), Visit(arguments[0])); - return RewriteEquality(true, newLeft, newRight) - ?? methodCallExpression.Update(Unwrap(newLeft), new[] { Unwrap(newRight) }); - } - - if (method.Equals(_objectEqualsMethodInfo)) - { - var (newLeft, newRight) = (Visit(arguments[0]), Visit(arguments[1])); - return RewriteEquality(true, newLeft, newRight) - ?? methodCallExpression.Update(null, new[] { Unwrap(newLeft), Unwrap(newRight) }); - } - - // Navigation via EF.Property() - if (methodCallExpression.TryGetEFPropertyArguments(out _, out var propertyName)) - { - newSource = Visit(arguments[0]); - var newMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource), arguments[1] }); - return newSource is EntityReferenceExpression entityWrapper - ? entityWrapper.TraverseProperty(propertyName, newMethodCall) - : newMethodCall; - } - - // Navigation via an indexer property - if (methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model, out _, out propertyName)) - { - newSource = Visit(methodCallExpression.Object); - var newMethodCall = methodCallExpression.Update(Unwrap(newSource), new[] { arguments[0] }); - return newSource is EntityReferenceExpression entityWrapper - ? entityWrapper.TraverseProperty(propertyName, newMethodCall) - : newMethodCall; - } - - switch (method.Name) - { - // These are methods that require special handling - case nameof(Queryable.Contains) - when genericMethod == QueryableMethods.Contains: - return VisitContainsMethodCall(methodCallExpression); - - case nameof(Queryable.OrderBy) - when genericMethod == QueryableMethods.OrderBy: - case nameof(Queryable.OrderByDescending) - when genericMethod == QueryableMethods.OrderByDescending: - case nameof(Queryable.ThenBy) - when genericMethod == QueryableMethods.ThenBy: - case nameof(Queryable.ThenByDescending) - when genericMethod == QueryableMethods.ThenByDescending: - return VisitOrderingMethodCall(methodCallExpression); - - // The following are projecting methods, which flow the entity type from *within* the lambda outside. - case nameof(Queryable.Select) - when genericMethod == QueryableMethods.Select: - case nameof(Queryable.SelectMany) - when genericMethod == QueryableMethods.SelectManyWithoutCollectionSelector - || genericMethod == QueryableMethods.SelectManyWithCollectionSelector: - return VisitSelectMethodCall(methodCallExpression); - - case nameof(Queryable.GroupJoin) - when genericMethod == QueryableMethods.GroupJoin: - case nameof(Queryable.Join) - when genericMethod == QueryableMethods.Join: - case nameof(QueryableExtensions.LeftJoin) - when genericMethod == QueryableExtensions.LeftJoinMethodInfo: - return VisitJoinMethodCall(methodCallExpression); - - case nameof(Queryable.OfType) - when genericMethod == QueryableMethods.OfType: - return VisitOfType(methodCallExpression); - - case nameof(Queryable.GroupBy) - when genericMethod == QueryableMethods.GroupByWithKeySelector - || genericMethod == QueryableMethods.GroupByWithKeyElementSelector - || genericMethod == QueryableMethods.GroupByWithKeyResultSelector - || genericMethod == QueryableMethods.GroupByWithKeyElementResultSelector: - break; // TODO: Implement - } - - // We handled the Contains Queryable extension method above, but there's also IList.Contains - if (genericMethod == _enumerableContainsMethodInfo - || method.DeclaringType.GetInterfaces().Contains(typeof(IList)) && string.Equals(method.Name, nameof(IList.Contains))) - { - return VisitContainsMethodCall(methodCallExpression); - } - - // TODO: Can add an extension point that can be overridden by subclassing visitors to recognize additional methods and flow through the entity type. - // Do this here, since below we visit the arguments (avoid double visitation) - - if (arguments.Count == 0) - { - return methodCallExpression.Update( - Unwrap(Visit(methodCallExpression.Object)), Array.Empty()); - } - - // Methods with a typed first argument (source), and with no lambda arguments or a single lambda - // argument that has one parameter are rewritten automatically (e.g. Where(), Average() - var newArguments = new Expression[arguments.Count]; - var lambdaArgs = arguments.Select(a => a.GetLambdaOrNull()).Where(l => l != null).ToArray(); - newSource = Visit(arguments[0]); - newArguments[0] = Unwrap(newSource); - if (methodCallExpression.Object == null - && newSource is EntityReferenceExpression newSourceWrapper - && (lambdaArgs.Length == 0 - || lambdaArgs.Length == 1 && lambdaArgs[0].Parameters.Count == 1)) - { - for (var i = 1; i < arguments.Count; i++) - { - // Visit all arguments, rewriting the single lambda to replace its parameter expression - newArguments[i] = arguments[i].GetLambdaOrNull() is LambdaExpression lambda - ? Unwrap(RewriteAndVisitLambda(lambda, newSourceWrapper)) - : Unwrap(Visit(arguments[i])); - } - - var sourceParamType = methodCallExpression.Method.GetParameters()[0].ParameterType; - var sourceElementType = sourceParamType.TryGetSequenceType(); - if (sourceElementType != null) - { - // If the method returns the element same type as the source, flow the type information - // (e.g. Where) - if (methodCallExpression.Method.ReturnType.TryGetSequenceType() is Type returnElementType - && returnElementType == sourceElementType) - { - return newSourceWrapper.Update(methodCallExpression.Update(null, newArguments)); - } - - // If the source type is an IQueryable over the return type, this is a cardinality-reducing method (e.g. First). - // These don't flow the last navigation. In addition, these will be translated into a subquery, and we should not - // perform entity equality rewriting if the entity type has a composite key. - if (methodCallExpression.Method.ReturnType == sourceElementType) - { - return new EntityReferenceExpression( - methodCallExpression.Update(null, newArguments), - newSourceWrapper.EntityType, - lastNavigation: null, - newSourceWrapper.DtoType, - subqueryTraversed: true); - } - } - - // Method does not flow entity type (e.g. Average) - return methodCallExpression.Update(null, newArguments); - } - - // Unknown method - still need to visit all arguments - for (var i = 1; i < arguments.Count; i++) - { - newArguments[i] = Unwrap(Visit(arguments[i])); - } - - return methodCallExpression.Update(Unwrap(Visit(methodCallExpression.Object)), newArguments); - } - - private Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression) - { - // We handle both Contains the extension method and the instance method - var (newSource, newItem) = methodCallExpression.Arguments.Count == 2 - ? (methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]) - : (methodCallExpression.Object, methodCallExpression.Arguments[0]); - (newSource, newItem) = (Visit(newSource), Visit(newItem)); - - var sourceEntityType = (newSource as EntityReferenceExpression)?.EntityType; - var itemEntityType = (newItem as EntityReferenceExpression)?.EntityType; - - if (sourceEntityType == null - && itemEntityType == null) - { - return NoTranslation(); - } - - if (sourceEntityType != null - && itemEntityType != null - && sourceEntityType.GetRootType() != itemEntityType.GetRootType()) - { - return Expression.Constant(false); - } - - // One side of the comparison may have an unknown entity type (closure parameter, inline instantiation) - var entityType = sourceEntityType ?? itemEntityType; - - var keyProperties = entityType.FindPrimaryKey()?.Properties; - var keyProperty = keyProperties == null - ? throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())) - : keyProperties.Count == 1 - ? keyProperties[0] - : throw new InvalidOperationException( - CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName())); - - Expression rewrittenSource, rewrittenItem; - - if (newSource is ConstantExpression listConstant) - { - // The source list is a constant, evaluate and replace with a list of the keys - var listValue = (IEnumerable)listConstant.Value; - var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType.MakeNullable()); - var keyList = (IList)Activator.CreateInstance(keyListType); - var getter = keyProperty.GetGetter(); - foreach (var listItem in listValue) - { - keyList.Add(getter.GetClrValue(listItem)); - } - - rewrittenSource = Expression.Constant(keyList, keyListType); - } - else if (newSource is ParameterExpression listParam - && listParam.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal)) - { - // The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution. - var lambda = Expression.Lambda( - Expression.Call( - _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()), - QueryCompilationContext.QueryContextParameter, - Expression.Constant(listParam.Name, typeof(string)), - Expression.Constant(keyProperty, typeof(IProperty))), - QueryCompilationContext.QueryContextParameter - ); - - var newParameterName = - $"{RuntimeParameterPrefix}{listParam.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{keyProperty.Name}"; - rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); - } - else - { - // The source list is neither a constant nor a parameter. Wrap it with a projection to its primary key. - var param = Expression.Parameter(entityType.ClrType, "v"); - var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param); - rewrittenSource = Expression.Call( - QueryableMethods.Select.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()), - Unwrap(newSource), - Expression.Quote(keySelector)); - } - - // Rewrite the item with a key expression as needed (constant, parameter and other are handled within) - rewrittenItem = newItem.IsNullConstantExpression() - ? Expression.Constant(null, entityType.ClrType) - : CreatePropertyAccessExpression(Unwrap(newItem), keyProperty); - - return Expression.Call( - (Unwrap(newSource).Type.IsQueryableType() - ? QueryableMethods.Contains - : _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType.MakeNullable()), - rewrittenSource, - rewrittenItem - ); - - Expression NoTranslation() => methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newItem) }) - : methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) }); - } - - private Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression) - { - var arguments = methodCallExpression.Arguments; - var newSource = Visit(arguments[0]); - - if (!(newSource is EntityReferenceExpression sourceWrapper)) - { - return methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) }); - } - - var newKeySelector = RewriteAndVisitLambda(arguments[1].UnwrapLambdaFromQuote(), sourceWrapper); - - if (!(newKeySelector.Body is EntityReferenceExpression keySelectorWrapper) - || !(keySelectorWrapper.EntityType is IEntityType entityType)) - { - return sourceWrapper.Update( - methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newKeySelector) })); - } - - var genericMethodDefinition = methodCallExpression.Method.GetGenericMethodDefinition(); - var firstOrdering = - genericMethodDefinition == QueryableMethods.OrderBy - || genericMethodDefinition == QueryableMethods.OrderByDescending; - var isAscending = - genericMethodDefinition == QueryableMethods.OrderBy - || genericMethodDefinition == QueryableMethods.ThenBy; - - var keyProperties = entityType.FindPrimaryKey()?.Properties; - if (keyProperties == null) - { - throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); - } - - var expression = Unwrap(newSource); - var body = Unwrap(newKeySelector.Body); - var oldParam = newKeySelector.Parameters.Single(); - - foreach (var keyProperty in keyProperties) - { - var param = Expression.Parameter(oldParam.Type, oldParam.Name); - - var rewrittenKeySelector = Expression.Lambda( - ReplacingExpressionVisitor.Replace( - oldParam, param, - CreatePropertyAccessExpression(body, keyProperty)), - param); - - var orderingMethodInfo = GetOrderingMethodInfo(firstOrdering, isAscending); - - expression = Expression.Call( - orderingMethodInfo.MakeGenericMethod(oldParam.Type, keyProperty.ClrType.MakeNullable()), - expression, - Expression.Quote(rewrittenKeySelector) - ); - - firstOrdering = false; - } - - return sourceWrapper.Update(expression); - - static MethodInfo GetOrderingMethodInfo(bool firstOrdering, bool ascending) - { - if (firstOrdering) - { - return ascending - ? QueryableMethods.OrderBy - : QueryableMethods.OrderByDescending; - } - - return ascending - ? QueryableMethods.ThenBy - : QueryableMethods.ThenByDescending; - } - } - - private Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression) - { - var arguments = methodCallExpression.Arguments; - var newSource = Visit(arguments[0]); - - if (!(newSource is EntityReferenceExpression sourceWrapper)) - { - return arguments.Count == 2 - ? methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) }) - : arguments.Count == 3 - ? methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])), Unwrap(Visit(arguments[2])) }) - : throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name)); - } - - MethodCallExpression newMethodCall; - - if (arguments.Count == 2) - { - var selector = arguments[1].UnwrapLambdaFromQuote(); - var newSelector = RewriteAndVisitLambda(selector, sourceWrapper); - - newMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newSelector) }); - return newSelector.Body is EntityReferenceExpression entityWrapper - ? entityWrapper.Update(newMethodCall) - : (Expression)newMethodCall; - } - - if (arguments.Count == 3) - { - var collectionSelector = arguments[1].UnwrapLambdaFromQuote(); - var newCollectionSelector = RewriteAndVisitLambda(collectionSelector, sourceWrapper); - - var resultSelector = arguments[2].UnwrapLambdaFromQuote(); - var newResultSelector = newCollectionSelector.Body is EntityReferenceExpression newCollectionSelectorWrapper - ? RewriteAndVisitLambda(resultSelector, sourceWrapper, newCollectionSelectorWrapper) - : (LambdaExpression)Visit(resultSelector); - - newMethodCall = methodCallExpression.Update( - null, new[] { Unwrap(newSource), Unwrap(newCollectionSelector), Unwrap(newResultSelector) }); - return newResultSelector.Body is EntityReferenceExpression entityWrapper - ? entityWrapper.Update(newMethodCall) - : (Expression)newMethodCall; - } - - throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name)); - } - - private Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression) - { - var arguments = methodCallExpression.Arguments; - - if (arguments.Count != 5) - { - return base.VisitMethodCall(methodCallExpression); - } - - var newOuter = Visit(arguments[0]); - var newInner = Visit(arguments[1]); - var outerKeySelector = arguments[2].UnwrapLambdaFromQuote(); - var innerKeySelector = arguments[3].UnwrapLambdaFromQuote(); - var resultSelector = arguments[4].UnwrapLambdaFromQuote(); - - if (!(newOuter is EntityReferenceExpression outerWrapper && newInner is EntityReferenceExpression innerWrapper)) - { - return methodCallExpression.Update( - null, - new[] - { - Unwrap(newOuter), - Unwrap(newInner), - Unwrap(Visit(outerKeySelector)), - Unwrap(Visit(innerKeySelector)), - Unwrap(Visit(resultSelector)) - }); - } - - var newOuterKeySelector = RewriteAndVisitLambda(outerKeySelector, outerWrapper); - var newInnerKeySelector = RewriteAndVisitLambda(innerKeySelector, innerWrapper); - var newResultSelector = RewriteAndVisitLambda(resultSelector, outerWrapper, innerWrapper); - - MethodCallExpression newMethodCall; - - // If both outer and inner key selectors project to the same entity type, that's an entity equality - // we need to rewrite. - if (newOuterKeySelector.Body is EntityReferenceExpression outerKeySelectorWrapper - && newInnerKeySelector.Body is EntityReferenceExpression innerKeySelectorWrapper - && outerKeySelectorWrapper.IsEntityType - && innerKeySelectorWrapper.IsEntityType - && outerKeySelectorWrapper.EntityType.GetRootType() == innerKeySelectorWrapper.EntityType.GetRootType()) - { - var entityType = outerKeySelectorWrapper.EntityType; - var keyProperties = entityType.FindPrimaryKey()?.Properties; - if (keyProperties == null) - { - throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); - } - - if (keyProperties.Count > 1 - && (outerKeySelectorWrapper.SubqueryTraversed || innerKeySelectorWrapper.SubqueryTraversed)) - { - // One side of the comparison is the result of a subquery, and we have a composite key. - // Rewriting this would mean evaluating the subquery more than once, so we don't do it. - throw new InvalidOperationException( - CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); - } - - // Rewrite the lambda bodies, adding the key access on top of whatever is there, and then - // produce a new MethodInfo and MethodCallExpression - var origGenericArguments = methodCallExpression.Method.GetGenericArguments(); - - var outerKeyAccessExpression = CreateKeyAccessExpression(Unwrap(outerKeySelectorWrapper), keyProperties); - var outerKeySelectorType = typeof(Func<,>).MakeGenericType(origGenericArguments[0], outerKeyAccessExpression.Type); - newOuterKeySelector = Expression.Lambda( - outerKeySelectorType, - outerKeyAccessExpression, - newOuterKeySelector.TailCall, - newOuterKeySelector.Parameters); - - var innerKeyAccessExpression = CreateKeyAccessExpression(Unwrap(innerKeySelectorWrapper), keyProperties); - var innerKeySelectorType = typeof(Func<,>).MakeGenericType(origGenericArguments[1], innerKeyAccessExpression.Type); - newInnerKeySelector = Expression.Lambda( - innerKeySelectorType, - innerKeyAccessExpression, - newInnerKeySelector.TailCall, - newInnerKeySelector.Parameters); - - var newMethod = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod( - origGenericArguments[0], origGenericArguments[1], outerKeyAccessExpression.Type, origGenericArguments[3]); - - newMethodCall = Expression.Call( - newMethod, - Unwrap(newOuter), Unwrap(newInner), - Expression.Quote(newOuterKeySelector), Expression.Quote(newInnerKeySelector), - Expression.Quote(Unwrap(newResultSelector))); - } - else - { - newMethodCall = methodCallExpression.Update( - null, - new[] - { - Unwrap(newOuter), - Unwrap(newInner), - Expression.Quote(Unwrap(newOuterKeySelector)), - Expression.Quote(Unwrap(newInnerKeySelector)), - Expression.Quote(Unwrap(newResultSelector)) - }); - } - - return newResultSelector.Body is EntityReferenceExpression wrapper - ? wrapper.Update(newMethodCall) - : (Expression)newMethodCall; - } - - private Expression VisitOfType(MethodCallExpression methodCallExpression) - { - var newSource = Visit(methodCallExpression.Arguments[0]); - var updatedMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource) }); - - if (!(newSource is EntityReferenceExpression sourceWrapper) - || sourceWrapper.EntityType == null) - { - return updatedMethodCall; - } - - var castType = methodCallExpression.Type.TryGetSequenceType(); - var castEntityType = sourceWrapper.EntityType.GetTypesInHierarchy().FirstOrDefault(et => et.ClrType == castType); - if (castEntityType == null) - { - return updatedMethodCall; - } - - return new EntityReferenceExpression(updatedMethodCall, castEntityType); - } - - /// - /// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits - /// the lambda's body. - /// - private LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source) - => Expression.Lambda( - lambda.Type, - Visit( - ReplacingExpressionVisitor.Replace( - lambda.Parameters.Single(), - source.Update(lambda.Parameters.Single()), - lambda.Body)), - lambda.TailCall, - lambda.Parameters); - - /// - /// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits - /// the lambda's body. - /// - private LambdaExpression RewriteAndVisitLambda( - LambdaExpression lambda, - EntityReferenceExpression source1, - EntityReferenceExpression source2) - { - Expression original1 = lambda.Parameters[0]; - Expression replacement1 = source1.Update(lambda.Parameters[0]); - Expression original2 = lambda.Parameters[1]; - Expression replacement2 = source2.Update(lambda.Parameters[1]); - return Expression.Lambda( - lambda.Type, - Visit( - new ReplacingExpressionVisitor( - new[] { original1, original2 }, new[] { replacement1, replacement2 }) - .Visit(lambda.Body)), - lambda.TailCall, - lambda.Parameters); - } - - /// - /// Receives already-visited left and right operands of an equality expression and applies entity equality rewriting to them, - /// if possible. - /// - /// The rewritten entity equality expression, or null if rewriting could not occur for some reason. - private Expression RewriteEquality(bool equality, Expression left, Expression right) - { - // TODO: Consider throwing if a child has no flowed entity type, but has a Type that corresponds to an entity type on the model. - // TODO: This would indicate an issue in our flowing logic, and would help the user (and us) understand what's going on. - - var leftTypeWrapper = left as EntityReferenceExpression; - var rightTypeWrapper = right as EntityReferenceExpression; - - // If one of the sides is a DTO, or both sides are unknown, abort - if (leftTypeWrapper == null && rightTypeWrapper == null - || leftTypeWrapper?.IsDtoType == true - || rightTypeWrapper?.IsDtoType == true) - { - return null; - } - - // Handle null constants - if (left.IsNullConstantExpression()) - { - if (right.IsNullConstantExpression()) - { - return equality ? Expression.Constant(true) : Expression.Constant(false); - } - - return rightTypeWrapper?.IsEntityType == true - ? RewriteNullEquality( - equality, rightTypeWrapper.EntityType, rightTypeWrapper.Underlying, rightTypeWrapper.LastNavigation) - : null; - } - - if (right.IsNullConstantExpression()) - { - return leftTypeWrapper?.IsEntityType == true - ? RewriteNullEquality(equality, leftTypeWrapper.EntityType, leftTypeWrapper.Underlying, leftTypeWrapper.LastNavigation) - : null; - } - - if (leftTypeWrapper != null - && rightTypeWrapper != null - && leftTypeWrapper.EntityType.GetRootType() != rightTypeWrapper.EntityType.GetRootType()) - { - return Expression.Constant(!equality); - } - - // One side of the comparison may have an unknown entity type (closure parameter, inline instantiation) - var entityType = (leftTypeWrapper ?? rightTypeWrapper).EntityType; - - return RewriteEntityEquality( - equality, entityType, - Unwrap(left), leftTypeWrapper?.LastNavigation, - Unwrap(right), rightTypeWrapper?.LastNavigation, - leftTypeWrapper?.SubqueryTraversed == true || rightTypeWrapper?.SubqueryTraversed == true); - } - - private Expression RewriteNullEquality( - bool equality, - [NotNull] IEntityType entityType, - [NotNull] Expression nonNullExpression, - [CanBeNull] INavigation lastNavigation) - { - if (lastNavigation?.IsCollection == true) - { - // collection navigation is only null if its parent entity is null (null propagation thru navigation) - // it is probable that user wanted to see if the collection is (not) empty - // log warning suggesting to use Any() instead. - _logger.PossibleUnintendedCollectionNavigationNullComparisonWarning(lastNavigation); - return RewriteNullEquality(equality, lastNavigation.DeclaringEntityType, UnwrapLastNavigation(nonNullExpression), null); - } - - var keyProperties = entityType.FindPrimaryKey()?.Properties; - if (keyProperties == null) - { - throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); - } - - // TODO: bring back foreign key comparison optimization (#15826) - - // When comparing an entity to null, it's sufficient to simply compare its first primary key column to null. - // (this is also why we can do it even over a subquery with a composite key) - return Expression.MakeBinary( - equality ? ExpressionType.Equal : ExpressionType.NotEqual, - CreatePropertyAccessExpression(nonNullExpression, keyProperties[0]), - Expression.Constant(null)); - } - - private Expression RewriteEntityEquality( - bool equality, - [NotNull] IEntityType entityType, - [NotNull] Expression left, [CanBeNull] INavigation leftNavigation, - [NotNull] Expression right, [CanBeNull] INavigation rightNavigation, - bool subqueryTraversed) - { - if (leftNavigation?.IsCollection == true - || rightNavigation?.IsCollection == true) - { - if (leftNavigation?.Equals(rightNavigation) == true) - { - // Log a warning that comparing 2 collections causes reference comparison - _logger.PossibleUnintendedReferenceComparisonWarning(left, right); - return RewriteEntityEquality( - equality, leftNavigation.DeclaringEntityType, - UnwrapLastNavigation(left), null, - UnwrapLastNavigation(right), null, - subqueryTraversed); - } - - return Expression.Constant(!equality); - } - - var keyProperties = entityType.FindPrimaryKey()?.Properties; - if (keyProperties == null) - { - throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); - } - - if (subqueryTraversed && keyProperties.Count > 1) - { - // One side of the comparison is the result of a subquery, and we have a composite key. - // Rewriting this would mean evaluating the subquery more than once, so we don't do it. - throw new InvalidOperationException( - CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); - } - - return Expression.MakeBinary( - equality ? ExpressionType.Equal : ExpressionType.NotEqual, - CreateKeyAccessExpression(Unwrap(left), keyProperties), - CreateKeyAccessExpression(Unwrap(right), keyProperties)); - } - - protected override Expression VisitExtension(Expression extensionExpression) - { - Check.NotNull(extensionExpression, nameof(extensionExpression)); - - switch (extensionExpression) - { - // If the expression is an EntityReferenceExpression, simply returns it as all rewriting has already occurred. - // This is necessary when traversing wrapping expressions that have been injected into the lambda for parameters. - case EntityReferenceExpression _: - return extensionExpression; - - case QueryRootExpression queryRootExpression: - return new EntityReferenceExpression(queryRootExpression, queryRootExpression.EntityType); - - default: - return base.VisitExtension(extensionExpression); - } - } - - private Expression CreateKeyAccessExpression( - Expression target, - IReadOnlyList properties) - => properties.Count == 1 - ? CreatePropertyAccessExpression(target, properties[0]) - : Expression.New( - AnonymousObject.AnonymousObjectCtor, - Expression.NewArrayInit( - typeof(object), - properties - .Select(p => Expression.Convert(CreatePropertyAccessExpression(target, p), typeof(object))) - .Cast() - .ToArray())); - - private Expression CreatePropertyAccessExpression(Expression target, IProperty property) - { - // The target is a constant - evaluate the property immediately and return the result - if (target is ConstantExpression constantExpression) - { - return Expression.Constant(property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType.MakeNullable()); - } - - // The target is complex which can be evaluated to Constant. - if (CanEvaluate(target)) - { - var value = Expression.Lambda>(Expression.Convert(target, typeof(object))).Compile().Invoke(); - return Expression.Constant(property.GetGetter().GetClrValue(value), property.ClrType.MakeNullable()); - } - - // If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new - // parameter to be added at runtime, with the value of the property on the base parameter. - if (target is ParameterExpression baseParameterExpression - && baseParameterExpression.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal)) - { - // Generate an expression to get the base parameter from the query context's parameter list, and extract the - // property from that - var lambda = Expression.Lambda( - Expression.Call( - _parameterValueExtractor.MakeGenericMethod(property.ClrType.MakeNullable()), - QueryCompilationContext.QueryContextParameter, - Expression.Constant(baseParameterExpression.Name, typeof(string)), - Expression.Constant(property, typeof(IProperty))), - QueryCompilationContext.QueryContextParameter); - - var newParameterName = - $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}"; - return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); - } - - return target.CreateEFPropertyExpression(property); - } - - private static bool CanEvaluate(Expression expression) - { - switch (expression) - { - case ConstantExpression constantExpression: - return true; - - case NewExpression newExpression: - return newExpression.Arguments.All(e => CanEvaluate(e)); - - case MemberInitExpression memberInitExpression: - return CanEvaluate(memberInitExpression.NewExpression) - && memberInitExpression.Bindings.All( - mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)); - - default: - return false; - } - } - - private static T ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property) - { - var baseParameter = context.ParameterValues[baseParameterName]; - return baseParameter == null ? (T)(object)null : (T)property.GetGetter().GetClrValue(baseParameter); - } - - private static readonly MethodInfo _parameterValueExtractor - = typeof(EntityEqualityRewritingExpressionVisitor) - .GetTypeInfo() - .GetDeclaredMethod(nameof(ParameterValueExtractor)); - - /// - /// Extracts the list parameter with name from and returns a - /// projection to its elements' values. - /// - private static List ParameterListValueExtractor( - QueryContext context, string baseParameterName, IProperty property) - { - var baseListParameter = context.ParameterValues[baseParameterName] as IEnumerable; - if (baseListParameter == null) - { - return null; - } - - var getter = property.GetGetter(); - return baseListParameter.Select(e => (TProperty)getter.GetClrValue(e)).ToList(); - } - - private static readonly MethodInfo _parameterListValueExtractor - = typeof(EntityEqualityRewritingExpressionVisitor) - .GetTypeInfo() - .GetDeclaredMethod(nameof(ParameterListValueExtractor)); - - protected static Expression UnwrapLastNavigation([NotNull] Expression expression) - { - Check.NotNull(expression, nameof(expression)); - - return (expression as MemberExpression)?.Expression - ?? (expression is MethodCallExpression methodCallExpression - && methodCallExpression.Method.IsEFPropertyMethod() - ? methodCallExpression.Arguments[0] - : null); - } - - protected static Expression Unwrap([NotNull] Expression expression) - => expression switch - { - EntityReferenceExpression wrapper => wrapper.Underlying, - LambdaExpression lambda when lambda.Body is EntityReferenceExpression wrapper => - Expression.Lambda( - lambda.Type, - wrapper.Underlying, - lambda.TailCall, - lambda.Parameters), - _ => expression - }; - - protected struct EntityOrDtoType - { - public static EntityOrDtoType FromEntityReferenceExpression(EntityReferenceExpression ere) - => new EntityOrDtoType - { - EntityType = ere.IsEntityType ? ere.EntityType : null, - DtoType = ere.IsDtoType ? ere.DtoType : null - }; - - public static EntityOrDtoType FromDtoType(Dictionary dtoType) - => new EntityOrDtoType { DtoType = dtoType }; - - public bool IsEntityType => EntityType != null; - public bool IsDto => DtoType != null; - - public IEntityType EntityType; - public Dictionary DtoType; - } - - protected class EntityReferenceExpression : Expression - { - public sealed override ExpressionType NodeType => ExpressionType.Extension; - - /// - /// The underlying expression being wrapped. - /// - [NotNull] - public Expression Underlying { get; } - - public override Type Type => Underlying.Type; - - [CanBeNull] - public IEntityType EntityType { get; } - - [CanBeNull] - public INavigation LastNavigation => EntityType == null ? null : _lastNavigation; - - [CanBeNull] - private readonly INavigation _lastNavigation; - - [CanBeNull] - public Dictionary DtoType { get; } - - public bool SubqueryTraversed { get; } - - public bool IsDtoType => DtoType != null; - public bool IsEntityType => EntityType != null; - - public EntityReferenceExpression(Expression underlying, Dictionary dtoType) - { - Underlying = underlying; - DtoType = dtoType; - } - - public EntityReferenceExpression(Expression underlying, IEntityType entityType) - : this(underlying, entityType, subqueryTraversed: false) - { - } - - private EntityReferenceExpression( - Expression underlying, IEntityType entityType, bool subqueryTraversed) - { - Underlying = underlying; - EntityType = entityType; - SubqueryTraversed = subqueryTraversed; - } - - public EntityReferenceExpression( - Expression underlying, - IEntityType entityType, - INavigation lastNavigation, - Dictionary dtoType, - bool subqueryTraversed) - { - Underlying = underlying; - EntityType = entityType; - _lastNavigation = lastNavigation; - DtoType = dtoType; - SubqueryTraversed = subqueryTraversed; - } - - /// - /// Attempts to find as a navigation from the current node, - /// and if successful, returns a new wrapping the - /// given expression. Otherwise returns the given expression without wrapping it. - /// - public virtual Expression TraverseProperty(string propertyName, Expression destinationExpression) - { - if (IsEntityType) - { - return EntityType.FindNavigation(propertyName) is INavigation navigation - ? new EntityReferenceExpression( - destinationExpression, - navigation.TargetEntityType, - navigation, - null, - SubqueryTraversed) - : destinationExpression; - } - - if (IsDtoType) - { - if (DtoType.TryGetValue(propertyName, out var entityOrDto)) - { - return entityOrDto.IsEntityType - ? new EntityReferenceExpression(destinationExpression, entityOrDto.EntityType) - : new EntityReferenceExpression(destinationExpression, entityOrDto.DtoType); - } - - return destinationExpression; - } - - throw new InvalidOperationException(CoreStrings.UnknownEntity("TypeInfo")); - } - - public EntityReferenceExpression Update(Expression newUnderlying) - => new EntityReferenceExpression(newUnderlying, EntityType, null, DtoType, SubqueryTraversed); - - protected override Expression VisitChildren(ExpressionVisitor visitor) - { - Check.NotNull(visitor, nameof(visitor)); - - return Update(visitor.Visit(Underlying)); - } - - public virtual void Print(ExpressionPrinter expressionPrinter) - { - Check.NotNull(expressionPrinter, nameof(expressionPrinter)); - - expressionPrinter.Visit(Underlying); - - if (IsEntityType) - { - expressionPrinter.Append($".EntityType({EntityType})"); - } - else if (IsDtoType) - { - expressionPrinter.Append(".DTO"); - } - - if (SubqueryTraversed) - { - expressionPrinter.Append(".SubqueryTraversed"); - } - } - - public override string ToString() => - $"{Underlying}[{(IsEntityType ? EntityType.ShortName() : "DTO")}{(SubqueryTraversed ? ", Subquery" : "")}]"; - } - } -} diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs index b365edd5959..a5ac5a2985b 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs @@ -9,7 +9,6 @@ using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Query.Internal @@ -93,28 +92,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return base.VisitMethodCall(methodCallExpression); } - protected EntityReference UnwrapEntityReference(Expression expression) - { - switch (expression) - { - case EntityReference entityReference: - return entityReference; - - case NavigationTreeExpression navigationTreeExpression: - return UnwrapEntityReference(navigationTreeExpression.Value); - - case NavigationExpansionExpression navigationExpansionExpression - when navigationExpansionExpression.CardinalityReducingGenericMethodInfo != null: - return UnwrapEntityReference(navigationExpansionExpression.PendingSelector); - - case OwnedNavigationReference ownedNavigationReference: - return ownedNavigationReference.EntityReference; - - default: - return null; - } - } - private Expression TryExpandNavigation(Expression root, MemberIdentity memberIdentity) { var innerExpression = root.UnwrapTypeConversion(out var convertedType); @@ -699,5 +676,188 @@ protected override Expression VisitExtension(Expression extensionExpression) : base.VisitExtension(extensionExpression); } } + + private sealed class RemoveRedundantNavigationComparisonExpressionVisitor : ExpressionVisitor + { + private readonly IDiagnosticsLogger _logger; + + public RemoveRedundantNavigationComparisonExpressionVisitor(IDiagnosticsLogger logger) + { + _logger = logger; + } + + protected override Expression VisitBinary(BinaryExpression binaryExpression) + { + if (binaryExpression.NodeType == ExpressionType.Equal + || binaryExpression.NodeType == ExpressionType.NotEqual) + { + var left = ProcessNavigationPath(binaryExpression.Left); + var right = ProcessNavigationPath(binaryExpression.Right); + + if (TryRemoveNavigationComparison(binaryExpression.NodeType, left, right, out var result)) + { + return result; + } + } + + return base.VisitBinary(binaryExpression); + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + var method = methodCallExpression.Method; + if (method.Name == nameof(object.Equals) + && methodCallExpression.Object != null + && methodCallExpression.Arguments.Count == 1) + { + var left = ProcessNavigationPath(methodCallExpression.Object); + var right = ProcessNavigationPath(methodCallExpression.Arguments[0]); + + if (TryRemoveNavigationComparison(ExpressionType.Equal, left, right, out var result)) + { + return result; + } + } + else if (method.Name == nameof(object.Equals) + && methodCallExpression.Object == null + && methodCallExpression.Arguments.Count == 2) + { + var left = ProcessNavigationPath(methodCallExpression.Arguments[0]); + var right = ProcessNavigationPath(methodCallExpression.Arguments[1]); + + if (TryRemoveNavigationComparison(ExpressionType.Equal, left, right, out var result)) + { + return result; + } + } + + return base.VisitMethodCall(methodCallExpression); + } + + private bool TryRemoveNavigationComparison(ExpressionType nodeType, Expression left, Expression right, out Expression result) + { + result = null; + var leftNavigationData = left as NavigationDataExpression; + var rightNavigationData = right as NavigationDataExpression; + + if (leftNavigationData == null + && rightNavigationData == null) + { + return false; + } + + if (left.IsNullConstantExpression() + || right.IsNullConstantExpression()) + { + var nonNullNavigationData = left.IsNullConstantExpression() + ? rightNavigationData + : leftNavigationData; + + if (nonNullNavigationData.Navigation?.IsCollection == true) + { + _logger.PossibleUnintendedCollectionNavigationNullComparisonWarning(nonNullNavigationData.Navigation); + + result = Expression.MakeBinary( + nodeType, nonNullNavigationData.Inner.Current, Expression.Constant(null, nonNullNavigationData.Inner.Type)); + + return true; + } + } + else if (leftNavigationData != null + && rightNavigationData != null) + { + if (leftNavigationData.Navigation?.IsCollection == true) + { + if (leftNavigationData.Navigation == rightNavigationData.Navigation) + { + _logger.PossibleUnintendedReferenceComparisonWarning(leftNavigationData.Current, rightNavigationData.Current); + + result = Expression.MakeBinary(nodeType, leftNavigationData.Inner.Current, rightNavigationData.Inner.Current); + } + else + { + result = Expression.Constant(nodeType == ExpressionType.NotEqual); + } + + return true; + } + } + + return false; + } + + private Expression ProcessNavigationPath(Expression expression) + { + switch (expression) + { + case MemberExpression memberExpression: + var innerExpression = ProcessNavigationPath(memberExpression.Expression); + if (innerExpression is NavigationDataExpression navigationDataExpression + && navigationDataExpression.EntityType != null) + { + var navigation = navigationDataExpression.EntityType.FindNavigation(memberExpression.Member); + if (navigation != null) + { + return new NavigationDataExpression(expression, navigationDataExpression, navigation); + } + } + + return expression; + + case MethodCallExpression methodCallExpression + when methodCallExpression.TryGetEFPropertyArguments(out var source, out var navigationName): + return expression; + + default: + var convertlessExpression = expression.UnwrapTypeConversion(out var convertedType); + if (UnwrapEntityReference(convertlessExpression) is EntityReference entityReference) + { + var entityType = entityReference.EntityType; + if (convertedType != null) + { + entityType = entityType.GetTypesInHierarchy() + .FirstOrDefault(et => et.ClrType == convertedType); + if (entityType == null) + { + return expression; + } + } + + return new NavigationDataExpression(expression, entityType); + } + + return expression; + } + } + + private sealed class NavigationDataExpression : Expression + { + public NavigationDataExpression(Expression current, IEntityType entityType) + { + Navigation = default; + Current = current; + EntityType = entityType; + } + + public NavigationDataExpression(Expression current, NavigationDataExpression inner, INavigation navigation) + { + Current = current; + Inner = inner; + Navigation = navigation; + if (!navigation.IsCollection) + { + EntityType = navigation.TargetEntityType; + } + } + + public override Type Type => Current.Type; + public override ExpressionType NodeType => ExpressionType.Extension; + + public INavigation Navigation { get; } + public Expression Current { get; } + public NavigationDataExpression Inner { get; } + public IEntityType EntityType { get; } + } + } } } diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index f80e4c47d02..8eeb03b370d 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -51,8 +51,8 @@ private static readonly PropertyInfo _queryContextContextPropertyInfo private readonly SubqueryMemberPushdownExpressionVisitor _subqueryMemberPushdownExpressionVisitor; private readonly ReducingExpressionVisitor _reducingExpressionVisitor; private readonly EntityReferenceOptionalMarkingExpressionVisitor _entityReferenceOptionalMarkingExpressionVisitor; + private readonly RemoveRedundantNavigationComparisonExpressionVisitor _removeRedundantNavigationComparisonExpressionVisitor; private readonly ISet _parameterNames = new HashSet(); - private readonly EntityEqualityRewritingExpressionVisitor _entityEqualityRewritingExpressionVisitor; private readonly ParameterExtractingExpressionVisitor _parameterExtractingExpressionVisitor; private readonly Dictionary _parameterizedQueryFilterPredicateCache @@ -73,7 +73,8 @@ public NavigationExpandingExpressionVisitor( _subqueryMemberPushdownExpressionVisitor = new SubqueryMemberPushdownExpressionVisitor(queryCompilationContext.Model); _reducingExpressionVisitor = new ReducingExpressionVisitor(); _entityReferenceOptionalMarkingExpressionVisitor = new EntityReferenceOptionalMarkingExpressionVisitor(); - _entityEqualityRewritingExpressionVisitor = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext); + _removeRedundantNavigationComparisonExpressionVisitor = new RemoveRedundantNavigationComparisonExpressionVisitor( + queryCompilationContext.Logger); _parameterExtractingExpressionVisitor = new ParameterExtractingExpressionVisitor( evaluatableExpressionFilter, _parameters, @@ -852,8 +853,10 @@ private NavigationExpansionExpression ProcessInclude(NavigationExpansionExpressi filterExpression = Expression.Lambda(prm, prm); } - var arguments = new List(); - arguments.Add(filterExpression.Body); + var arguments = new List + { + filterExpression.Body + }; arguments.AddRange(methodCallExpression.Arguments.Skip(1)); filterExpression = Expression.Lambda( methodCallExpression.Update(methodCallExpression.Object, arguments), @@ -900,8 +903,7 @@ private NavigationExpansionExpression ProcessJoin( ApplyPendingOrderings(innerSource); } - outerKeySelector = ProcessLambdaExpression(outerSource, outerKeySelector); - innerKeySelector = ProcessLambdaExpression(innerSource, innerKeySelector); + (outerKeySelector, innerKeySelector) = ProcessJoinConditions(outerSource, innerSource, outerKeySelector, innerKeySelector); var transparentIdentifierType = TransparentIdentifierFactory.Create( outerSource.SourceElementType, innerSource.SourceElementType); @@ -949,8 +951,7 @@ private NavigationExpansionExpression ProcessLeftJoin( ApplyPendingOrderings(innerSource); } - outerKeySelector = ProcessLambdaExpression(outerSource, outerKeySelector); - innerKeySelector = ProcessLambdaExpression(innerSource, innerKeySelector); + (outerKeySelector, innerKeySelector) = ProcessJoinConditions(outerSource, innerSource, outerKeySelector, innerKeySelector); var transparentIdentifierType = TransparentIdentifierFactory.Create( outerSource.SourceElementType, innerSource.SourceElementType); @@ -1203,6 +1204,63 @@ private void ApplyPendingOrderings(NavigationExpansionExpression source) var lambdaBody = Visit(keySelector); lambdaBody = _pendingSelectorExpandingExpressionVisitor.Visit(lambdaBody); + if (lambdaBody is NavigationTreeExpression navigationTreeExpression + && navigationTreeExpression.Value is EntityReference entityReference) + { + var primaryKeyProperties = entityReference.EntityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties != null) + { + for (var i = 0; i < primaryKeyProperties.Count; i++) + { + var genericMethod = i > 0 + ? GetThenByMethod(orderingMethod) + : orderingMethod; + + var keyPropertyLambda = GenerateLambda( + navigationTreeExpression.CreateEFPropertyExpression(primaryKeyProperties[i], entityReference.IsOptional), + source.CurrentParameter); + + source.UpdateSource( + Expression.Call( + genericMethod.MakeGenericMethod(source.SourceElementType, keyPropertyLambda.ReturnType), + source.Source, + keyPropertyLambda)); + } + + continue; + } + } + + if (lambdaBody is NavigationExpansionExpression navigationExpansionExpression + && navigationExpansionExpression.CardinalityReducingGenericMethodInfo != null + && navigationExpansionExpression.PendingSelector is NavigationTreeExpression subqueryNavigationTreeExpression + && subqueryNavigationTreeExpression.Value is EntityReference subqueryEntityReference) + { + var primaryKeyProperties = subqueryEntityReference.EntityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties != null) + { + for (var i = 0; i < primaryKeyProperties.Count; i++) + { + var genericMethod = i > 0 + ? GetThenByMethod(orderingMethod) + : orderingMethod; + + var keyPropertyLambda = GenerateLambda( + navigationExpansionExpression.CreateEFPropertyExpression( + primaryKeyProperties[i], subqueryEntityReference.IsOptional), + source.CurrentParameter); + + source.UpdateSource( + Expression.Call( + genericMethod.MakeGenericMethod(source.SourceElementType, keyPropertyLambda.ReturnType), + source.Source, + keyPropertyLambda)); + } + + continue; + } + } + var keySelectorLambda = GenerateLambda(lambdaBody, source.CurrentParameter); source.UpdateSource( @@ -1214,6 +1272,47 @@ private void ApplyPendingOrderings(NavigationExpansionExpression source) source.ClearPendingOrderings(); } + + static MethodInfo GetThenByMethod(MethodInfo currentGenericMethod) + => currentGenericMethod == QueryableMethods.OrderBy + ? QueryableMethods.ThenBy + : currentGenericMethod == QueryableMethods.OrderByDescending + ? QueryableMethods.ThenByDescending + : currentGenericMethod; + } + + private (LambdaExpression, LambdaExpression) ProcessJoinConditions( + NavigationExpansionExpression outerSource, + NavigationExpansionExpression innerSource, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector) + { + var outerKeyLambda = RemapLambdaExpression(outerSource, outerKeySelector); + var innerKeyLambda = RemapLambdaExpression(innerSource, innerKeySelector); + + var keyComparison = (BinaryExpression)_removeRedundantNavigationComparisonExpressionVisitor + .Visit(Expression.Equal(outerKeyLambda, innerKeyLambda)); + + outerKeySelector = GenerateLambda(ExpandNavigationsForSource(outerSource, keyComparison.Left), outerSource.CurrentParameter); + innerKeySelector = GenerateLambda(ExpandNavigationsForSource(innerSource, keyComparison.Right), innerSource.CurrentParameter); + + if (outerKeySelector.ReturnType != innerKeySelector.ReturnType) + { + var baseType = outerKeySelector.ReturnType.IsAssignableFrom(innerKeySelector.ReturnType) + ? outerKeySelector.ReturnType + : innerKeySelector.ReturnType; + + outerKeySelector = ChangeReturnType(outerKeySelector, baseType); + innerKeySelector = ChangeReturnType(innerKeySelector, baseType); + } + + return (outerKeySelector, innerKeySelector); + + static LambdaExpression ChangeReturnType(LambdaExpression lambdaExpression, Type type) + { + var delegateType = typeof(Func<,>).MakeGenericType(lambdaExpression.Parameters[0].Type, type); + return Expression.Lambda(delegateType, lambdaExpression.Body, lambdaExpression.Parameters); + } } private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpansionExpression) @@ -1238,8 +1337,7 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa QueryableMethods.Where.MakeGenericMethod(rootEntityType.ClrType), new QueryRootExpression(rootEntityType), filterPredicate); - var rewrittenFilterWrapper = (MethodCallExpression)_entityEqualityRewritingExpressionVisitor.Rewrite(filterWrapper); - filterPredicate = rewrittenFilterWrapper.Arguments[1].UnwrapLambdaFromQuote(); + filterPredicate = filterWrapper.Arguments[1].UnwrapLambdaFromQuote(); _parameterizedQueryFilterPredicateCache[rootEntityType] = filterPredicate; } @@ -1446,6 +1544,7 @@ private NavigationExpansionExpression CreateNavigationExpansionExpression( private Expression ExpandNavigationsForSource(NavigationExpansionExpression source, Expression expression) { + expression = _removeRedundantNavigationComparisonExpressionVisitor.Visit(expression); expression = new ExpandingExpressionVisitor(this, source).Visit(expression); expression = _subqueryMemberPushdownExpressionVisitor.Visit(expression); expression = Visit(expression); @@ -1609,6 +1708,28 @@ private Expression SnapshotExpression(Expression selector) } } + private static EntityReference UnwrapEntityReference(Expression expression) + { + switch (expression) + { + case EntityReference entityReference: + return entityReference; + + case NavigationTreeExpression navigationTreeExpression: + return UnwrapEntityReference(navigationTreeExpression.Value); + + case NavigationExpansionExpression navigationExpansionExpression + when navigationExpansionExpression.CardinalityReducingGenericMethodInfo != null: + return UnwrapEntityReference(navigationExpansionExpression.PendingSelector); + + case OwnedNavigationReference ownedNavigationReference: + return ownedNavigationReference.EntityReference; + + default: + return null; + } + } + private sealed class Parameters : IParameterValues { private readonly IDictionary _parameterValues = new Dictionary(); diff --git a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs index 9e2fa207dbd..aa5c1581773 100644 --- a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs @@ -316,7 +316,7 @@ var compilerPrefixIndex } parameterName - = CompiledQueryCache.CompiledQueryParameterPrefix + = QueryCompilationContext.QueryParameterPrefix + parameterName + "_" + _parameterValues.ParameterValues.Count; diff --git a/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs b/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs index 7eb57326d35..b9717f58637 100644 --- a/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs @@ -256,7 +256,7 @@ private static bool ClientSource(Expression expression) || expression is MemberInitExpression || expression is NewExpression || expression is ParameterExpression parameter - && parameter.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal); + && parameter.Name.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal); private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type queryableType) { diff --git a/src/EFCore/Query/QueryCompilationContext.cs b/src/EFCore/Query/QueryCompilationContext.cs index 5d9a41c02ce..af6206b50c5 100644 --- a/src/EFCore/Query/QueryCompilationContext.cs +++ b/src/EFCore/Query/QueryCompilationContext.cs @@ -16,6 +16,26 @@ namespace Microsoft.EntityFrameworkCore.Query { public class QueryCompilationContext { + /// + /// + /// Prefix for all the query parameters generated during parameter extraction in query pipeline. + /// + /// + /// This property is typically used by database providers (and other extensions). It is generally + /// not used in application code. + /// + /// + public const string QueryParameterPrefix = "__"; + + /// + /// + /// ParameterExpression representing parameter in query expression. + /// + /// + /// This property is typically used by database providers (and other extensions). It is generally + /// not used in application code. + /// + /// public static readonly ParameterExpression QueryContextParameter = Expression.Parameter(typeof(QueryContext), "queryContext"); private readonly IQueryTranslationPreprocessorFactory _queryTranslationPreprocessorFactory; @@ -23,11 +43,6 @@ public class QueryCompilationContext private readonly IQueryTranslationPostprocessorFactory _queryTranslationPostprocessorFactory; private readonly IShapedQueryCompilingExpressionVisitorFactory _shapedQueryCompilingExpressionVisitorFactory; - /// - /// A dictionary mapping parameter names to lambdas that, given a QueryContext, can extract that parameter's value. - /// This is needed for cases where we need to introduce a parameter during the compilation phase (e.g. entity equality rewrites - /// a parameter to an ID property on that parameter). - /// private Dictionary _runtimeParameters; public QueryCompilationContext( diff --git a/src/EFCore/Query/QueryTranslationPreprocessor.cs b/src/EFCore/Query/QueryTranslationPreprocessor.cs index 3b35e7bd29e..fcaed1b632f 100644 --- a/src/EFCore/Query/QueryTranslationPreprocessor.cs +++ b/src/EFCore/Query/QueryTranslationPreprocessor.cs @@ -36,7 +36,6 @@ public virtual Expression Process([NotNull] Expression query) query = new VBToCSharpConvertingExpressionVisitor().Visit(query); query = new AllAnyContainsRewritingExpressionVisitor().Visit(query); query = new NullCheckRemovingExpressionVisitor().Visit(query); - query = new EntityEqualityRewritingExpressionVisitor(QueryCompilationContext).Rewrite(query); query = new SubqueryMemberPushdownExpressionVisitor(QueryCompilationContext.Model).Visit(query); query = new NavigationExpandingExpressionVisitor(this, QueryCompilationContext, Dependencies.EvaluatableExpressionFilter) .Expand(query); diff --git a/src/EFCore/Query/ReplacingExpressionVisitor.cs b/src/EFCore/Query/ReplacingExpressionVisitor.cs index 4c8fc0bc456..c47e63f8a4b 100644 --- a/src/EFCore/Query/ReplacingExpressionVisitor.cs +++ b/src/EFCore/Query/ReplacingExpressionVisitor.cs @@ -6,7 +6,6 @@ using System.Reflection; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Query diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs index 6c1a168bdd1..551870c001a 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs @@ -3844,6 +3844,18 @@ public override Task Entity_equality_orderby_descending_composite_key(bool async return base.Entity_equality_orderby_descending_composite_key(async); } + [ConditionalTheory(Skip = "Issue #17246")] + public override Task Entity_equality_orderby_subquery(bool async) + { + return base.Entity_equality_orderby_subquery(async); + } + + [ConditionalTheory(Skip = "Issue #17246")] + public override Task Entity_equality_orderby_descending_subquery_composite_key(bool async) + { + return base.Entity_equality_orderby_descending_subquery_composite_key(async); + } + [ConditionalTheory(Skip = "Issue #17246")] public override Task Null_Coalesce_Short_Circuit(bool async) { @@ -4079,6 +4091,28 @@ FROM root c WHERE (c[""Discriminator""] = ""OrderDetail"")"); } + public override async Task Entity_equality_with_null_coalesce_client_side(bool async) + { + await base.Entity_equality_with_null_coalesce_client_side(async); + + AssertSql( + @"@__entity_equality_p_0_CustomerID='ALFKI' + +SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = @__entity_equality_p_0_CustomerID))"); + } + + public override async Task Entity_equality_contains_with_list_of_null(bool async) + { + await base.Entity_equality_contains_with_list_of_null(async); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] IN (""ALFKI"") OR (c[""CustomerID""] = null)))"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.InMemory.FunctionalTests/Query/NorthwindAggregateOperatorsQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/NorthwindAggregateOperatorsQueryInMemoryTest.cs index 67ec3f78109..6ffb0a6463e 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/NorthwindAggregateOperatorsQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/NorthwindAggregateOperatorsQueryInMemoryTest.cs @@ -60,5 +60,11 @@ public override Task LastOrDefault_when_no_order_by(bool async) { return base.LastOrDefault_when_no_order_by(async); } + + [ConditionalFact(Skip = "Issue#20023")] + public override void Contains_over_keyless_entity_throws() + { + base.Contains_over_keyless_entity_throws(); + } } } diff --git a/test/EFCore.Specification.Tests/LazyLoadProxyTestBase.cs b/test/EFCore.Specification.Tests/LazyLoadProxyTestBase.cs index 22ef4c27f6e..818766aa6f8 100644 --- a/test/EFCore.Specification.Tests/LazyLoadProxyTestBase.cs +++ b/test/EFCore.Specification.Tests/LazyLoadProxyTestBase.cs @@ -15,6 +15,7 @@ using Microsoft.EntityFrameworkCore.TestUtilities; using Microsoft.Extensions.DependencyInjection; using Xunit; +using Xunit.Sdk; // ReSharper disable InconsistentNaming namespace Microsoft.EntityFrameworkCore @@ -2169,6 +2170,24 @@ public virtual void Top_level_projection_track_entities_before_passing_to_client Assert.NotNull(((dynamic)query).Single); } + [ConditionalTheory] + [InlineData(true)] + [InlineData(false)] + public virtual async Task Entity_equality_with_proxy_parameter(bool async) + { + using var context = CreateContext(lazyLoadingEnabled: true); + var called = context.Set().FirstOrDefault(); + ClearLog(); + + var query = from Child q in context.Set() + where q.Parent == called + select q; + + var result = async ? await query.ToListAsync() : query.ToList(); + + RecordLog(); + } + private static class DtoFactory { public static object CreateDto(Parent parent) diff --git a/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs index e58907bfd72..f69b45002d3 100644 --- a/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs @@ -1930,7 +1930,7 @@ from g2 in grouping.DefaultIfEmpty() expectedIncludes); } - [ConditionalTheory] + [ConditionalTheory(Skip = "Issue#15783")] [MemberData(nameof(IsAsyncData))] public virtual Task Include_on_GroupJoin_SelectMany_DefaultIfEmpty_with_conditional_result(bool async) { @@ -5026,31 +5026,24 @@ from w in grouping.DefaultIfEmpty() [ConditionalTheory] [MemberData(nameof(IsAsyncData))] - public virtual async Task Join_with_complex_key_selector(bool async) + public virtual Task Join_with_complex_key_selector(bool async) { - var message = (await Assert.ThrowsAsync( - () => AssertQuery( - async, - ss => ss.Set() - .Join( - ss.Set().Where(t => t.Note == "Marcus' Tag"), o => true, i => true, (o, i) => new { o, i }) - .GroupJoin( - ss.Set(), - oo => oo.o.Members.FirstOrDefault(v => v.Tag == oo.i), - ii => ii, - (k, g) => new - { - k.o, - k.i, - value = g.OrderBy(gg => gg.FullName).FirstOrDefault() - }) - .Select( - r => new { r.o.Id, TagId = r.i.Id }), - elementSorter: e => (e.Id, e.TagId)))).Message; - - Assert.Equal( - "This query would cause multiple evaluation of a subquery because entity 'Gear' has a composite key. Rewrite your query avoiding the subquery.", - message); + return AssertTranslationFailed(() => AssertQuery( + async, + ss => ss.Set() + .Join(ss.Set().Where(t => t.Note == "Marcus' Tag"), o => true, i => true, (o, i) => new { o, i }) + .GroupJoin( + ss.Set(), + oo => oo.o.Members.FirstOrDefault(v => v.Tag == oo.i), + ii => ii, + (k, g) => new + { + k.o, + k.i, + value = g.OrderBy(gg => gg.FullName).FirstOrDefault() + }) + .Select(r => new { r.o.Id, TagId = r.i.Id }), + elementSorter: e => (e.Id, e.TagId))); } [ConditionalTheory(Skip = "Issue#16314")] diff --git a/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs index 1db1b95206d..f766c0458df 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; using System.Linq.Expressions; @@ -12,6 +13,7 @@ using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.TestModels.Northwind; using Microsoft.EntityFrameworkCore.TestUtilities; +using Microsoft.EntityFrameworkCore.Utilities; using Xunit; #pragma warning disable RCS1202 // Avoid NullReferenceException. @@ -491,6 +493,29 @@ public virtual Task Entity_equality_orderby_descending_composite_key(bool async) assertOrder: true); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_orderby_subquery(bool async) + { + return AssertQuery( + async, + ss => ss.Set().OrderBy(c => c.Orders.FirstOrDefault()), + ss => ss.Set().OrderBy(c => c.Orders.FirstOrDefault() == null ? (int?)null : c.Orders.FirstOrDefault().OrderID), + entryCount: 91, + assertOrder: true); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_orderby_descending_subquery_composite_key(bool async) + { + return AssertQuery( + async, + ss => ss.Set().OrderByDescending(o => o.OrderDetails.FirstOrDefault()), + entryCount: 830, + assertOrder: true); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Queryable_simple(bool async) @@ -5327,7 +5352,7 @@ select g.OrderByDescending(x => x.OrderID).ToList(), protected internal uint ClientEvalSelector(Order order) => order.EmployeeID % 10 ?? 0; - [ConditionalTheory] + [ConditionalTheory(Skip = "Issue#20445")] [MemberData(nameof(IsAsyncData))] public virtual Task Collection_navigation_equal_to_null_for_subquery(bool async) { @@ -5349,7 +5374,7 @@ public virtual Task Dependent_to_principal_navigation_equal_to_null_for_subquery entryCount: 2); } - [ConditionalTheory] + [ConditionalTheory(Skip = "Issue#20445")] [MemberData(nameof(IsAsyncData))] public virtual Task Collection_navigation_equality_rewrite_for_subquery(bool async) { @@ -5763,5 +5788,30 @@ public virtual Task Checked_context_with_case_to_same_nullable_type_does_not_fai ); } } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_with_null_coalesce_client_side(bool async) + { + var a = new Customer { CustomerID = "ALFKI" }; + var b = a; + + return AssertQuery( + async, + ss => ss.Set().Where(c => c == (a ?? b)), + entryCount: 1); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_contains_with_list_of_null(bool async) + { + var customers = new List { null, new Customer { CustomerID = "ALFKI" } }; + + return AssertQuery( + async, + ss => ss.Set().Where(c => customers.Contains(c)), + entryCount: 1); + } } } diff --git a/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs index 0a46cab8c89..dfefc837c10 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs @@ -2119,7 +2119,7 @@ public virtual Task Where_Queryable_ToList_Count_member(bool async) elementAsserter: (e, a) => AssertCollection(e, a)); } - [ConditionalTheory(Skip = "Issue#19431")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Where_Queryable_ToArray_Length_member(bool async) { @@ -2234,7 +2234,7 @@ public virtual Task Where_collection_navigation_ToList_Count_member(bool async) elementAsserter: (e, a) => AssertCollection(e, a)); } - [ConditionalTheory(Skip = "Issue#19431")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Where_collection_navigation_ToArray_Length_member(bool async) { diff --git a/test/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs index c73362b9cc1..42d912258f8 100644 --- a/test/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs @@ -534,7 +534,7 @@ public virtual Task Where_collection_navigation_ToList_Count_member(bool async) elementAsserter: (e, a) => AssertCollection(e, a)); } - [ConditionalTheory(Skip = "Issue#19431")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Where_collection_navigation_ToArray_Length_member(bool async) { diff --git a/test/EFCore.SqlServer.FunctionalTests/LazyLoadProxySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/LazyLoadProxySqlServerTest.cs index 9350bb0e3c5..f31128b0fd6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/LazyLoadProxySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/LazyLoadProxySqlServerTest.cs @@ -446,6 +446,21 @@ FROM [Child] AS [c] ignoreLineEndingDifferences: true); } + public override async Task Entity_equality_with_proxy_parameter(bool async) + { + await base.Entity_equality_with_proxy_parameter(async); + + Assert.Equal( + @"@__entity_equality_called_0_Id='707' (Nullable = true) + +SELECT [c].[Id], [c].[ParentId] +FROM [Child] AS [c] +LEFT JOIN [Parent] AS [p] ON [c].[ParentId] = [p].[Id] +WHERE [p].[Id] = @__entity_equality_called_0_Id", + Sql, + ignoreLineEndingDifferences: true); + } + protected override void ClearLog() => Fixture.TestSqlLoggerFactory.Clear(); protected override void RecordLog() => Sql = Fixture.TestSqlLoggerFactory.Sql; diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs index a7cb439b0d3..c1157cbeeb8 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs @@ -948,7 +948,7 @@ public override async Task Select_null_propagation_negative3(bool async) AssertSql( @"SELECT [g0].[Nickname], CASE - WHEN [g0].[Nickname] IS NOT NULL THEN CASE + WHEN [g0].[Nickname] IS NOT NULL AND [g0].[SquadId] IS NOT NULL THEN CASE WHEN [g0].[LeaderNickname] IS NOT NULL THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END @@ -965,7 +965,7 @@ public override async Task Select_null_propagation_negative4(bool async) AssertSql( @"SELECT CASE - WHEN [g0].[Nickname] IS NOT NULL THEN CAST(1 AS bit) + WHEN [g0].[Nickname] IS NOT NULL AND [g0].[SquadId] IS NOT NULL THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END, [g0].[Nickname] FROM [Gears] AS [g] @@ -979,7 +979,7 @@ public override async Task Select_null_propagation_negative5(bool async) AssertSql( @"SELECT CASE - WHEN [g0].[Nickname] IS NOT NULL THEN CAST(1 AS bit) + WHEN [g0].[Nickname] IS NOT NULL AND [g0].[SquadId] IS NOT NULL THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END, [g0].[Nickname] FROM [Gears] AS [g] @@ -1190,7 +1190,7 @@ public override async Task Select_Where_Navigation_Null(bool async) @"SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[Note] FROM [Tags] AS [t] LEFT JOIN [Gears] AS [g] ON ([t].[GearNickName] = [g].[Nickname]) AND ([t].[GearSquadId] = [g].[SquadId]) -WHERE [g].[Nickname] IS NULL"); +WHERE [g].[Nickname] IS NULL OR [g].[SquadId] IS NULL"); } public override async Task Select_Where_Navigation_Null_Reverse(bool async) @@ -1201,7 +1201,7 @@ public override async Task Select_Where_Navigation_Null_Reverse(bool async) @"SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[Note] FROM [Tags] AS [t] LEFT JOIN [Gears] AS [g] ON ([t].[GearNickName] = [g].[Nickname]) AND ([t].[GearSquadId] = [g].[SquadId]) -WHERE [g].[Nickname] IS NULL"); +WHERE [g].[Nickname] IS NULL OR [g].[SquadId] IS NULL"); } public override async Task Select_Where_Navigation_Scalar_Equals_Navigation_Scalar_Projected(bool async) @@ -3476,7 +3476,7 @@ public override async Task Projecting_nullable_bool_in_conditional_works(bool as AssertSql( @"SELECT CASE - WHEN [g].[Nickname] IS NOT NULL THEN [g].[HasSoulPatch] + WHEN [g].[Nickname] IS NOT NULL AND [g].[SquadId] IS NOT NULL THEN [g].[HasSoulPatch] ELSE CAST(0 AS bit) END AS [Prop] FROM [Tags] AS [t] @@ -5892,7 +5892,7 @@ public override async Task Left_join_projection_using_conditional_tracking(bool AssertSql( @"SELECT CASE - WHEN [g0].[Nickname] IS NULL THEN CAST(1 AS bit) + WHEN [g0].[Nickname] IS NULL OR [g0].[SquadId] IS NULL THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END, [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank], [g0].[Nickname], [g0].[SquadId], [g0].[AssignedCityName], [g0].[CityOfBirthName], [g0].[Discriminator], [g0].[FullName], [g0].[HasSoulPatch], [g0].[LeaderNickname], [g0].[LeaderSquadId], [g0].[Rank] FROM [Gears] AS [g] @@ -6123,7 +6123,7 @@ public override async Task Accessing_property_of_optional_navigation_in_child_pr AssertSql( @"SELECT CASE - WHEN [g].[Nickname] IS NOT NULL THEN CAST(1 AS bit) + WHEN [g].[Nickname] IS NOT NULL AND [g].[SquadId] IS NOT NULL THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END, [t].[Id], [t0].[Nickname], [t0].[Id] FROM [Tags] AS [t] diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs index 75783a7f824..58ad0bd968d 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs @@ -330,6 +330,35 @@ FROM [Order Details] AS [o] ORDER BY [o].[OrderID] DESC, [o].[ProductID] DESC"); } + public override async Task Entity_equality_orderby_subquery(bool async) + { + await base.Entity_equality_orderby_subquery(async); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +ORDER BY ( + SELECT TOP(1) [o].[OrderID] + FROM [Orders] AS [o] + WHERE [c].[CustomerID] = [o].[CustomerID])"); + } + + public override async Task Entity_equality_orderby_descending_subquery_composite_key(bool async) + { + await base.Entity_equality_orderby_descending_subquery_composite_key(async); + + AssertSql( + @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +ORDER BY ( + SELECT TOP(1) [o0].[OrderID] + FROM [Order Details] AS [o0] + WHERE [o].[OrderID] = [o0].[OrderID]) DESC, ( + SELECT TOP(1) [o1].[ProductID] + FROM [Order Details] AS [o1] + WHERE [o].[OrderID] = [o1].[OrderID]) DESC"); + } + public override async Task Default_if_empty_top_level(bool async) { await base.Default_if_empty_top_level(async); @@ -4996,6 +5025,28 @@ public override async Task Checked_context_with_case_to_same_nullable_type_does_ FROM [Order Details] AS [o]"); } + public override async Task Entity_equality_with_null_coalesce_client_side(bool async) + { + await base.Entity_equality_with_null_coalesce_client_side(async); + + AssertSql( + @"@__entity_equality_p_0_CustomerID='ALFKI' (Size = 5) (DbType = StringFixedLength) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] = @__entity_equality_p_0_CustomerID"); + } + + public override async Task Entity_equality_contains_with_list_of_null(bool async) + { + await base.Entity_equality_contains_with_list_of_null(async); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] IN (N'ALFKI')"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs index fd392709b44..e429904f925 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs @@ -1904,7 +1904,15 @@ public override async Task Where_Queryable_ToArray_Length_member(bool async) { await base.Where_Queryable_ToArray_Length_member(async); - AssertSql(" "); + AssertSql( + @"SELECT [c].[CustomerID], [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Customers] AS [c] +LEFT JOIN [Orders] AS [o] ON [c].[CustomerID] = [o].[CustomerID] +WHERE ( + SELECT COUNT(*) + FROM [Orders] AS [o0] + WHERE [o0].[CustomerID] = [c].[CustomerID]) = 0 +ORDER BY [c].[CustomerID], [o].[OrderID]"); } public override async Task Where_collection_navigation_ToList_Count(bool async) @@ -2012,7 +2020,15 @@ public override async Task Where_collection_navigation_ToArray_Length_member(boo { await base.Where_collection_navigation_ToArray_Length_member(async); - AssertSql(" "); + AssertSql( + @"SELECT [o].[OrderID], [o0].[OrderID], [o0].[ProductID], [o0].[Discount], [o0].[Quantity], [o0].[UnitPrice] +FROM [Orders] AS [o] +LEFT JOIN [Order Details] AS [o0] ON [o].[OrderID] = [o0].[OrderID] +WHERE ([o].[OrderID] < 10300) AND (( + SELECT COUNT(*) + FROM [Order Details] AS [o1] + WHERE [o].[OrderID] = [o1].[OrderID]) = 0) +ORDER BY [o].[OrderID], [o0].[OrderID], [o0].[ProductID]"); } private void AssertSql(params string[] expected) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/OwnedQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/OwnedQuerySqlServerTest.cs index 4bc3334d2a1..020c7c643ca 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/OwnedQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/OwnedQuerySqlServerTest.cs @@ -23,7 +23,7 @@ public override async Task Query_with_owned_entity_equality_operator(bool async) @"SELECT [o].[Id], [o].[Discriminator], [o].[Name], [o].[PersonAddress_AddressLine], [o].[PersonAddress_PlaceType], [o].[PersonAddress_ZipCode], [o].[PersonAddress_Country_Name], [o].[PersonAddress_Country_PlanetId], [o].[BranchAddress_BranchName], [o].[BranchAddress_PlaceType], [o].[BranchAddress_Country_Name], [o].[BranchAddress_Country_PlanetId], [o].[LeafAAddress_LeafType], [o].[LeafAAddress_PlaceType], [o].[Id], [o].[LeafAAddress_Country_Name], [o].[LeafAAddress_Country_PlanetId], [t].[Id], [o1].[ClientId], [o1].[Id], [o1].[OrderDate] FROM [OwnedPerson] AS [o] CROSS JOIN ( - SELECT [o0].[Id], [o0].[Discriminator], [o0].[Name] + SELECT [o0].[Id], [o0].[Discriminator], [o0].[Name], [o0].[LeafBAddress_LeafBType], [o0].[LeafBAddress_PlaceType] FROM [OwnedPerson] AS [o0] WHERE [o0].[Discriminator] = N'LeafB' ) AS [t] @@ -556,7 +556,15 @@ public override async Task Where_collection_navigation_ToArray_Length_member(boo { await base.Where_collection_navigation_ToArray_Length_member(async); - AssertSql(" "); + AssertSql( + @"SELECT [o].[Id], [o0].[ClientId], [o0].[Id], [o0].[OrderDate] +FROM [OwnedPerson] AS [o] +LEFT JOIN [Order] AS [o0] ON ([o].[Id] = [o0].[ClientId]) AND ([o].[Id] = [o0].[ClientId]) +WHERE ( + SELECT COUNT(*) + FROM [Order] AS [o1] + WHERE [o].[Id] = [o1].[ClientId]) = 0 +ORDER BY [o].[Id], [o0].[ClientId], [o0].[Id]"); } public override async Task Can_query_on_indexer_properties(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index bd9249a09c8..d7f7da9dc3c 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -6960,7 +6960,13 @@ static void AssertCustomerView( } AssertSql( - @"SELECT [c].[Id], [c].[Name], [c0].[Id], [c0].[CustomerId], [c0].[Name] + @"SELECT [c].[Id], [c].[Name], CASE + WHEN [c0].[Id] IS NOT NULL THEN [c0].[Id] + ELSE NULL +END AS [CustomerMembershipId], CASE + WHEN [c0].[Id] IS NOT NULL THEN [c0].[Name] + ELSE N'' +END AS [CustomerMembershipName] FROM [Customers] AS [c] LEFT JOIN [CustomerMemberships] AS [c0] ON [c].[Id] = [c0].[CustomerId]"); }