diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index c27e6cebbc7..f731d863d8b 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -95,38 +95,18 @@ protected override Expression VisitExtension(Expression node) /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override Expression VisitMember(MemberExpression memberExpression) - { - var innerExpression = Visit(memberExpression.Expression); - - if (TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), out var result)) - { - return result; - } - - return TranslationFailed(memberExpression.Expression, innerExpression) - ? null - : _memberTranslatorProvider.Translate((SqlExpression)innerExpression, memberExpression.Member, memberExpression.Type); - } + => TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result) + ? result + : TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression) + ? null + : _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type); private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression) { - Type convertedType = null; - if (source is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - if (unaryExpression.Type != typeof(object)) - { - convertedType = unaryExpression.Type; - } - - source = unaryExpression.Operand; - } - + source = Visit(source.UnwrapTypeConversion(out var convertedType)); if (source is EntityProjectionExpression entityProjectionExpression) { - if (convertedType != null - && convertedType.IsInterface - && convertedType.IsAssignableFrom(entityProjectionExpression.Type)) + if (convertedType == null) { convertedType = entityProjectionExpression.Type; } @@ -162,13 +142,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp { if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result) + return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) ? result : null; } - var @object = Visit(methodCallExpression.Object); - if (TranslationFailed(methodCallExpression.Object, @object)) + if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject)) { return null; } @@ -176,16 +155,16 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var arguments = new SqlExpression[methodCallExpression.Arguments.Count]; for (var i = 0; i < arguments.Length; i++) { - var argument = Visit(methodCallExpression.Arguments[i]); - if (TranslationFailed(methodCallExpression.Arguments[i], argument)) + var argument = methodCallExpression.Arguments[i]; + if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) { return null; } - arguments[i] = (SqlExpression)argument; + arguments[i] = sqlArgument; } - return _methodCallTranslatorProvider.Translate(_model, (SqlExpression)@object, methodCallExpression.Method, arguments); + return _methodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); } private static Expression TryRemoveImplicitConvert(Expression expression) @@ -240,17 +219,14 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) left = Visit(left); right = Visit(right); - if (TranslationFailed(binaryExpression.Left, left) - || TranslationFailed(binaryExpression.Right, right)) - { - return null; - } - - return _sqlExpressionFactory.MakeBinary( - binaryExpression.NodeType, - (SqlExpression)left, - (SqlExpression)right, - null); + return TranslationFailed(binaryExpression.Left, left, out var sqlLeft) + || TranslationFailed(binaryExpression.Right, right, out var sqlRight) + ? null + : _sqlExpressionFactory.MakeBinary( + binaryExpression.NodeType, + sqlLeft, + sqlRight, + null); } /// @@ -265,14 +241,11 @@ protected override Expression VisitConditional(ConditionalExpression conditional var ifTrue = Visit(conditionalExpression.IfTrue); var ifFalse = Visit(conditionalExpression.IfFalse); - if (TranslationFailed(conditionalExpression.Test, test) - || TranslationFailed(conditionalExpression.IfTrue, ifTrue) - || TranslationFailed(conditionalExpression.IfFalse, ifFalse)) - { - return null; - } - - return _sqlExpressionFactory.Condition((SqlExpression)test, (SqlExpression)ifTrue, (SqlExpression)ifFalse); + return TranslationFailed(conditionalExpression.Test, test, out var sqlTest) + || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) + || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse) + ? null + : _sqlExpressionFactory.Condition(sqlTest, sqlIfTrue, sqlIfFalse); } /// @@ -285,18 +258,11 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) { var operand = Visit(unaryExpression.Operand); - if (operand is EntityProjectionExpression) - { - return unaryExpression.Update(operand); - } - - if (TranslationFailed(unaryExpression.Operand, operand)) + if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) { return null; } - var sqlOperand = (SqlExpression)operand; - switch (unaryExpression.NodeType) { case ExpressionType.Not: @@ -334,7 +300,9 @@ private SqlConstantExpression GetConstantOrNull(Expression expression) private static bool CanEvaluate(Expression expression) { +#pragma warning disable IDE0066 // Convert switch statement to expression switch (expression) +#pragma warning restore IDE0066 // Convert switch statement to expression { case ConstantExpression constantExpression: return true; @@ -450,7 +418,17 @@ protected override Expression VisitExtension(Expression extensionExpression) } [DebuggerStepThrough] - private bool TranslationFailed(Expression original, Expression translation) - => original != null && !(translation is SqlExpression); + private bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation) + { + if (original != null + && !(translation is SqlExpression)) + { + castTranslation = null; + return true; + } + + castTranslation = translation as SqlExpression; + return false; + } } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index faabef15fa7..648f78182af 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -125,6 +125,15 @@ protected override Expression VisitConditional(ConditionalExpression conditional protected override Expression VisitMember(MemberExpression memberExpression) { + if (TryBindMember( + memberExpression.Expression, + MemberIdentity.Create(memberExpression.Member), + memberExpression.Type, + out var result)) + { + return result; + } + var innerExpression = Visit(memberExpression.Expression); if (memberExpression.Expression != null && innerExpression == null) @@ -132,15 +141,6 @@ protected override Expression VisitMember(MemberExpression memberExpression) return null; } - if ((innerExpression is EntityProjectionExpression - || (innerExpression is UnaryExpression innerUnaryExpression - && innerUnaryExpression.NodeType == ExpressionType.Convert - && innerUnaryExpression.Operand is EntityProjectionExpression)) - && TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), memberExpression.Type, out var result)) - { - return result; - } - var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression); if (innerExpression != null && innerExpression.Type.IsNullableType() @@ -164,24 +164,12 @@ static bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string mem private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result) { + source = Visit(source.UnwrapTypeConversion(out var convertedType)); result = null; - Type convertedType = null; - if (source is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - source = unaryExpression.Operand; - if (unaryExpression.Type != typeof(object)) - { - convertedType = unaryExpression.Type; - } - } - if (source is EntityProjectionExpression entityProjection) { var entityType = entityProjection.EntityType; - if (convertedType != null - && !(convertedType.IsInterface - && convertedType.IsAssignableFrom(entityType.ClrType))) + if (convertedType != null) { entityType = entityType.GetRootType().GetDerivedTypesInclusive() .FirstOrDefault(et => et.ClrType == convertedType); @@ -284,7 +272,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result)) + if (TryBindMember(source, MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result)) { return result; } @@ -396,16 +384,11 @@ MethodInfo GetMethod() { Expression result; var innerExpression = ((NewArrayExpression)newValueBufferExpression.Arguments[0]).Expressions[0]; - if (innerExpression is UnaryExpression unaryExpression + result = innerExpression is UnaryExpression unaryExpression && innerExpression.NodeType == ExpressionType.Convert - && innerExpression.Type == typeof(object)) - { - result = unaryExpression.Operand; - } - else - { - result = innerExpression; - } + && innerExpression.Type == typeof(object) + ? unaryExpression.Operand + : innerExpression; return result.Type == methodCallExpression.Type ? result diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 15739249725..56caadfc873 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -828,19 +828,8 @@ protected override Expression VisitMember(MemberExpression memberExpression) { var innerExpression = Visit(memberExpression.Expression); - if (innerExpression is EntityShaperExpression - || (innerExpression is UnaryExpression innerUnaryExpression - && innerUnaryExpression.NodeType == ExpressionType.Convert - && innerUnaryExpression.Operand is EntityShaperExpression)) - { - var collectionNavigation = Expand(innerExpression, MemberIdentity.Create(memberExpression.Member)); - if (collectionNavigation != null) - { - return collectionNavigation; - } - } - - return memberExpression.Update(innerExpression); + return TryExpand(innerExpression, MemberIdentity.Create(memberExpression.Member)) + ?? memberExpression.Update(innerExpression); } protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) @@ -848,19 +837,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var navigationName)) { source = Visit(source); - if (source is EntityShaperExpression - || (source is UnaryExpression innerUnaryExpression - && innerUnaryExpression.NodeType == ExpressionType.Convert - && innerUnaryExpression.Operand is EntityShaperExpression)) - { - var collectionNavigation = Expand(source, MemberIdentity.Create(navigationName)); - if (collectionNavigation != null) - { - return collectionNavigation; - } - } - return methodCallExpression.Update(null, new[] { source, methodCallExpression.Arguments[1] }); + return TryExpand(source, MemberIdentity.Create(navigationName)) + ?? methodCallExpression.Update(null, new[] { source, methodCallExpression.Arguments[1] }); } return base.VisitMethodCall(methodCallExpression); @@ -871,19 +850,9 @@ protected override Expression VisitExtension(Expression extensionExpression) ? extensionExpression : base.VisitExtension(extensionExpression); - private Expression Expand(Expression source, MemberIdentity member) + private Expression TryExpand(Expression source, MemberIdentity member) { - Type convertedType = null; - if (source is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - source = unaryExpression.Operand; - if (unaryExpression.Type != typeof(object)) - { - convertedType = unaryExpression.Type; - } - } - + source = source.UnwrapTypeConversion(out var convertedType); if (!(source is EntityShaperExpression entityShaperExpression)) { return null; @@ -1016,17 +985,7 @@ private ShapedQueryExpression TranslateScalarAggregate( return null; } - MethodInfo getMethod() - => methodName switch - { - nameof(Enumerable.Average) => EnumerableMethods.GetAverageWithSelector(selector.ReturnType), - nameof(Enumerable.Max) => EnumerableMethods.GetMaxWithSelector(selector.ReturnType), - nameof(Enumerable.Min) => EnumerableMethods.GetMinWithSelector(selector.ReturnType), - nameof(Enumerable.Sum) => EnumerableMethods.GetSumWithSelector(selector.ReturnType), - _ => throw new InvalidOperationException("Invalid Aggregate Operator encountered."), - }; - - var method = getMethod(); + var method = GetMethod(); method = method.GetGenericArguments().Length == 2 ? method.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) : method.MakeGenericMethod(typeof(ValueBuffer)); @@ -1040,6 +999,16 @@ MethodInfo getMethod() source.ShaperExpression = inMemoryQueryExpression.GetSingleScalarProjection(); return source; + + MethodInfo GetMethod() + => methodName switch + { + nameof(Enumerable.Average) => EnumerableMethods.GetAverageWithSelector(selector.ReturnType), + nameof(Enumerable.Max) => EnumerableMethods.GetMaxWithSelector(selector.ReturnType), + nameof(Enumerable.Min) => EnumerableMethods.GetMinWithSelector(selector.ReturnType), + nameof(Enumerable.Sum) => EnumerableMethods.GetSumWithSelector(selector.ReturnType), + _ => throw new InvalidOperationException("Invalid Aggregate Operator encountered."), + }; } private ShapedQueryExpression TranslateSingleResultOperator( diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index ddf9f1dbd8e..d9c615f99d7 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -986,19 +986,8 @@ protected override Expression VisitMember(MemberExpression memberExpression) { var innerExpression = Visit(memberExpression.Expression); - if (innerExpression is EntityShaperExpression - || (innerExpression is UnaryExpression innerUnaryExpression - && innerUnaryExpression.NodeType == ExpressionType.Convert - && innerUnaryExpression.Operand is EntityShaperExpression)) - { - var collectionNavigation = Expand(innerExpression, MemberIdentity.Create(memberExpression.Member)); - if (collectionNavigation != null) - { - return collectionNavigation; - } - } - - return memberExpression.Update(innerExpression); + return TryExpand(innerExpression, MemberIdentity.Create(memberExpression.Member)) + ?? memberExpression.Update(innerExpression); } protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) @@ -1006,19 +995,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var navigationName)) { source = Visit(source); - if (source is EntityShaperExpression - || (source is UnaryExpression innerUnaryExpression - && innerUnaryExpression.NodeType == ExpressionType.Convert - && innerUnaryExpression.Operand is EntityShaperExpression)) - { - var collectionNavigation = Expand(source, MemberIdentity.Create(navigationName)); - if (collectionNavigation != null) - { - return collectionNavigation; - } - } - return methodCallExpression.Update(null, new[] { source, methodCallExpression.Arguments[1] }); + return TryExpand(source, MemberIdentity.Create(navigationName)) + ?? methodCallExpression.Update(null, new[] { source, methodCallExpression.Arguments[1] }); } return base.VisitMethodCall(methodCallExpression); @@ -1029,19 +1008,9 @@ protected override Expression VisitExtension(Expression extensionExpression) ? extensionExpression : base.VisitExtension(extensionExpression); - private Expression Expand(Expression source, MemberIdentity member) + private Expression TryExpand(Expression source, MemberIdentity member) { - Type convertedType = null; - if (source is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - source = unaryExpression.Operand; - if (unaryExpression.Type != typeof(object)) - { - convertedType = unaryExpression.Type; - } - } - + source = source.UnwrapTypeConversion(out var convertedType); if (!(source is EntityShaperExpression entityShaperExpression)) { return null; diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index d3d1b07e291..0a4016ddecc 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -186,43 +186,20 @@ protected override Expression VisitExtension(Expression node) } protected override Expression VisitMember(MemberExpression memberExpression) - { - var innerExpression = Visit(memberExpression.Expression); - - if ((innerExpression is EntityProjectionExpression - || (innerExpression is UnaryExpression innerUnaryExpression - && innerUnaryExpression.NodeType == ExpressionType.Convert - && innerUnaryExpression.Operand is EntityProjectionExpression)) - && TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), out var result)) - { - return result; - } - - return TranslationFailed(memberExpression.Expression, innerExpression, out var sqlInnerExpression) - ? null - : Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type); - } + => TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result) + ? result + : TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression) + ? null + : Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type); private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression) { + source = Visit(source.UnwrapTypeConversion(out var convertedType)); expression = null; - Type convertedType = null; - if (source is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - source = unaryExpression.Operand; - if (unaryExpression.Type != typeof(object)) - { - convertedType = unaryExpression.Type; - } - } - if (source is EntityProjectionExpression entityProjectionExpression) { var entityType = entityProjectionExpression.EntityType; - if (convertedType != null - && !(convertedType.IsInterface - && convertedType.IsAssignableFrom(entityType.ClrType))) + if (convertedType != null) { entityType = entityType.GetRootType().GetDerivedTypesInclusive() .FirstOrDefault(et => et.ClrType == convertedType); @@ -321,7 +298,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result)) + if (TryBindMember(source, MemberIdentity.Create(propertyName), out var result)) { return result; } @@ -389,7 +366,9 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) return null; } +#pragma warning disable IDE0046 // Convert to conditional expression if (subquery.Tables.Count == 0 +#pragma warning restore IDE0046 // Convert to conditional expression && methodCallExpression.Method.IsGenericMethod && methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo genericMethod && (genericMethod == QueryableMethods.AnyWithoutPredicate @@ -457,18 +436,6 @@ private static Expression TryRemoveImplicitConvert(Expression expression) private Expression ConvertAnonymousObjectEqualityComparison(BinaryExpression binaryExpression) { - static Expression removeObjectConvert(Expression expression) - { - if (expression is UnaryExpression unaryExpression - && expression.Type == typeof(object) - && expression.NodeType == ExpressionType.Convert) - { - return unaryExpression.Operand; - } - - return expression; - } - var leftExpressions = ((NewArrayExpression)((NewExpression)binaryExpression.Left).Arguments[0]).Expressions; var rightExpressions = ((NewArrayExpression)((NewExpression)binaryExpression.Right).Arguments[0]).Expressions; @@ -476,8 +443,8 @@ static Expression removeObjectConvert(Expression expression) rightExpressions, (l, r) => { - l = removeObjectConvert(l); - r = removeObjectConvert(r); + l = RemoveObjectConvert(l); + r = RemoveObjectConvert(r); if (l.Type.IsNullableType()) { r = r.Type.IsNullableType() ? r : Expression.Convert(r, l.Type); @@ -490,6 +457,13 @@ static Expression removeObjectConvert(Expression expression) return Expression.Equal(l, r); }) .Aggregate((a, b) => Expression.AndAlso(a, b)); + + static Expression RemoveObjectConvert(Expression expression) + => expression is UnaryExpression unaryExpression + && expression.Type == typeof(object) + && expression.NodeType == ExpressionType.Convert + ? unaryExpression.Operand + : expression; } protected override Expression VisitBinary(BinaryExpression binaryExpression) @@ -503,17 +477,14 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) var left = TryRemoveImplicitConvert(binaryExpression.Left); var right = TryRemoveImplicitConvert(binaryExpression.Right); - if (TranslationFailed(binaryExpression.Left, Visit(left), out var sqlLeft) - || TranslationFailed(binaryExpression.Right, Visit(right), out var sqlRight)) - { - return null; - } - - return _sqlExpressionFactory.MakeBinary( - binaryExpression.NodeType, - sqlLeft, - sqlRight, - null); + return TranslationFailed(binaryExpression.Left, Visit(left), out var sqlLeft) + || TranslationFailed(binaryExpression.Right, Visit(right), out var sqlRight) + ? null + : _sqlExpressionFactory.MakeBinary( + binaryExpression.NodeType, + sqlLeft, + sqlRight, + null); } private SqlConstantExpression GetConstantOrNull(Expression expression) @@ -529,7 +500,9 @@ private SqlConstantExpression GetConstantOrNull(Expression expression) private static bool CanEvaluate(Expression expression) { +#pragma warning disable IDE0066 // Convert switch statement to expression switch (expression) +#pragma warning restore IDE0066 // Convert switch statement to expression { case ConstantExpression constantExpression: return true; @@ -596,25 +569,17 @@ protected override Expression VisitConditional(ConditionalExpression conditional var ifTrue = Visit(conditionalExpression.IfTrue); var ifFalse = Visit(conditionalExpression.IfFalse); - if (TranslationFailed(conditionalExpression.Test, test, out var sqlTest) + return TranslationFailed(conditionalExpression.Test, test, out var sqlTest) || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) - || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse)) - { - return null; - } - - return _sqlExpressionFactory.Case(new[] { new CaseWhenClause(sqlTest, sqlIfTrue) }, sqlIfFalse); + || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse) + ? null + : _sqlExpressionFactory.Case(new[] { new CaseWhenClause(sqlTest, sqlIfTrue) }, sqlIfFalse); } protected override Expression VisitUnary(UnaryExpression unaryExpression) { var operand = Visit(unaryExpression.Operand); - if (operand is EntityProjectionExpression) - { - return unaryExpression.Update(operand); - } - if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) { return null; diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs index 0aee93fa1ab..c1effc87527 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs @@ -98,25 +98,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp private Expression TryExpandNavigation(Expression root, MemberIdentity memberIdentity) { - Type convertedType = null; - var innerExpression = root; - if (innerExpression is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - innerExpression = unaryExpression.Operand; - if (unaryExpression.Type != typeof(object) - && unaryExpression.Type != innerExpression.Type) - { - convertedType = unaryExpression.Type; - } - } - + var innerExpression = root.UnwrapTypeConversion(out var convertedType); if (UnwrapEntityReference(innerExpression) is EntityReference entityReference) { var entityType = entityReference.EntityType; - if (convertedType != null - && !(convertedType.IsInterface - && convertedType.IsAssignableFrom(entityType.ClrType))) + if (convertedType != null) { entityType = entityType.GetTypesInHierarchy() .FirstOrDefault(et => et.ClrType == convertedType); diff --git a/src/Shared/ExpressionExtensions.cs b/src/Shared/ExpressionExtensions.cs index 2b85af4de31..a8b3197b0e4 100644 --- a/src/Shared/ExpressionExtensions.cs +++ b/src/Shared/ExpressionExtensions.cs @@ -13,5 +13,22 @@ public static LambdaExpression UnwrapLambdaFromQuote(this Expression expression) => (LambdaExpression)(expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote ? unary.Operand : expression); + + public static Expression UnwrapTypeConversion(this Expression expression, out Type convertedType) + { + convertedType = null; + while (expression is UnaryExpression unaryExpression + && unaryExpression.NodeType == ExpressionType.Convert) + { + expression = unaryExpression.Operand; + if (unaryExpression.Type != typeof(object) // Ignore object conversion + && !unaryExpression.Type.IsAssignableFrom(expression.Type)) // Ignore casting to base type/interface + { + convertedType = unaryExpression.Type; + } + } + + return expression; + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index e0425971d57..f11991e2c51 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -1439,7 +1439,7 @@ join eRoot in ctx.Entities.Include(e => e.Children) on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() - // ReSharper disable once ConstantNullCoalescingCondition + // ReSharper disable once ConstantNullCoalescingCondition select new { One = 1, Coalesce = eRootJoined ?? (eVersion ?? eRootJoined) }; var result = query.ToList(); @@ -1460,7 +1460,7 @@ join eRoot in ctx.Entities on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() - // ReSharper disable once ConstantNullCoalescingCondition + // ReSharper disable once ConstantNullCoalescingCondition select new { One = eRootJoined, @@ -1486,7 +1486,7 @@ join eRoot in ctx.Entities on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() - // ReSharper disable once MergeConditionalExpression + // ReSharper disable once MergeConditionalExpression #pragma warning disable IDE0029 // Use coalesce expression select eRootJoined != null ? eRootJoined : eVersion; #pragma warning restore IDE0029 // Use coalesce expression @@ -5316,7 +5316,8 @@ private SqlServerTestStore CreateDatabase15684() context.Products.Add( new Product15684 { - Name = "Apple", Category = new Category15684 { Name = "Fruit", Status = CategoryStatus15684.Active } + Name = "Apple", + Category = new Category15684 { Name = "Fruit", Status = CategoryStatus15684.Active } }); context.Products.Add(new Product15684 { Name = "Bike" }); @@ -5908,8 +5909,14 @@ public virtual void Expression_tree_constructed_via_interface_works_17276() var query = List17276(context.RemovableEntities); AssertSql( - @"SELECT [r].[Id], [r].[IsRemoved], [r].[Removed], [r].[RemovedByUser] + @"SELECT [r].[Id], [r].[IsRemoved], [r].[Removed], [r].[RemovedByUser], [t].[Id], [t].[OwnedEntity_OwnedValue] FROM [RemovableEntities] AS [r] +LEFT JOIN ( + SELECT [r0].[Id], [r0].[OwnedEntity_OwnedValue], [r1].[Id] AS [Id0] + FROM [RemovableEntities] AS [r0] + INNER JOIN [RemovableEntities] AS [r1] ON [r0].[Id] = [r1].[Id] + WHERE [r0].[OwnedEntity_OwnedValue] IS NOT NULL +) AS [t] ON [r].[Id] = [t].[Id] WHERE [r].[IsRemoved] <> CAST(1 AS bit)"); } } @@ -5935,6 +5942,31 @@ FROM [Parents] AS [p] } } + [ConditionalFact] + public virtual void Expression_tree_constructed_via_interface_for_owned_navigation_works_17505() + { + using (CreateDatabase17276()) + { + using (var context = new MyContext17276(_options)) + { + var query = context.RemovableEntities + .Where(p => EF.Property(EF.Property(p, "OwnedEntity"), "OwnedValue") == "Abc") + .ToList(); + + AssertSql( + @"SELECT [r].[Id], [r].[IsRemoved], [r].[Removed], [r].[RemovedByUser], [t].[Id], [t].[OwnedEntity_OwnedValue] +FROM [RemovableEntities] AS [r] +LEFT JOIN ( + SELECT [r0].[Id], [r0].[OwnedEntity_OwnedValue], [r1].[Id] AS [Id0] + FROM [RemovableEntities] AS [r0] + INNER JOIN [RemovableEntities] AS [r1] ON [r0].[Id] = [r1].[Id] + WHERE [r0].[OwnedEntity_OwnedValue] IS NOT NULL +) AS [t] ON [r].[Id] = [t].[Id] +WHERE ([t].[OwnedEntity_OwnedValue] = N'Abc') AND [t].[OwnedEntity_OwnedValue] IS NOT NULL"); + } + } + } + [ConditionalFact] public virtual void Expression_tree_constructed_via_interface_works_16759() { @@ -6001,6 +6033,7 @@ private class RemovableEntity17276 : IRemovable17276 public bool IsRemoved { get; set; } public string RemovedByUser { get; set; } public DateTime? Removed { get; set; } + public OwnedEntity OwnedEntity { get; set; } } private class Parent17276 : IHasId17276 @@ -6009,11 +6042,22 @@ private class Parent17276 : IHasId17276 public RemovableEntity17276 RemovableEntity { get; set; } } + [Owned] + private class OwnedEntity : IOwned + { + public string OwnedValue { get; set; } + } + private interface IHasId17276 { T Id { get; } } + private interface IOwned + { + public string OwnedValue { get; } + } + private class Specification17276 where T : IHasId17276 { @@ -7024,7 +7068,7 @@ private class DbGame #region Issue13517 - [Fact] + [ConditionalFact] public void Query_filter_with_pk_fk_optimization_bug_13517() { using var _ = CreateDatabase13517(); @@ -7068,27 +7112,27 @@ private SqlServerTestStore CreateDatabase13517() ClearLog(); }); - public class BugEntity13517 + private class BugEntity13517 { public int Id { get; set; } public int? RefEntityId { get; set; } public BugRefEntity13517 RefEntity { get; set; } } - public class BugRefEntity13517 + private class BugRefEntity13517 { public int Id { get; set; } public bool Public { get; set; } } - public class BugEntityDto13517 + private class BugEntityDto13517 { public int Id { get; set; } public int? RefEntityId { get; set; } public BugRefEntityDto13517 RefEntity { get; set; } } - public class BugRefEntityDto13517 + private class BugRefEntityDto13517 { public int Id { get; set; } public bool Public { get; set; } @@ -7112,6 +7156,107 @@ public BugContext13517(DbContextOptions options) #endregion + #region Issue17794 + + [ConditionalFact] + public void Double_convert_interface_created_expression_tree() + { + using var _ = CreateDatabase17794(); + using var context = new BugContext17794(_options); + + var expression = HasAction17794(OfferActions17794.Accepted); + var query = context.Offers.Where(expression).Count(); + + Assert.Equal(1, query); + + AssertSql( + @"@__action_0='1' + +SELECT COUNT(*) +FROM [Offers] AS [o] +WHERE EXISTS ( + SELECT 1 + FROM [OfferActions] AS [o0] + WHERE ([o].[Id] = [o0].[OfferId]) AND ([o0].[Action] = @__action_0))"); + } + + private SqlServerTestStore CreateDatabase17794() + => CreateTestStore( + () => new BugContext17794(_options), + context => + { + context.Add(new Offer17794 + { + Actions = new List + { + new OfferAction17794 + { + Action = OfferActions17794.Accepted + } + } + }); + + context.SaveChanges(); + + ClearLog(); + }); + + private static Expression> HasAction17794(OfferActions17794 action) + where T : IOffer17794 + { + Expression> predicate = oa => oa.Action == action; + + return v => v.Actions.AsQueryable().Any(predicate); + } + + private interface IOffer17794 + { + ICollection Actions { get; set; } + } + + private class Offer17794 : IOffer17794 + { + public int Id { get; set; } + + public ICollection Actions { get; set; } + } + + private enum OfferActions17794 : int + { + Accepted = 1, + Declined = 2 + } + + private class OfferAction17794 + { + public int Id { get; set; } + + [Required] + public Offer17794 Offer { get; set; } + public int OfferId { get; set; } + + [Required] + public OfferActions17794 Action { get; set; } + } + + private class BugContext17794 : DbContext + { + public DbSet Offers { get; set; } + public DbSet OfferActions { get; set; } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + + } + + public BugContext17794(DbContextOptions options) + : base(options) + { + } + } + + #endregion + private DbContextOptions _options; private SqlServerTestStore CreateTestStore(